diff --git a/nanochat/engine.py b/nanochat/engine.py
index 7f05eb4..a1ba24c 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -306,8 +306,8 @@ if __name__ == "__main__":
     """
     import time
     # init compute
-    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
     device_type = autodetect_device_type()
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
     autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
 
     # load the model and tokenizer