mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-21 10:54:15 +00:00
fix: pass device_type to compute_init in engine.__main__ (#451)
When running engine.py directly on non-GPU devices (CPU, MPS), compute_init() needs the device_type parameter to initialize correctly. This fixes failures on machines without CUDA support.
This commit is contained in:
parent
63bb5831e2
commit
6a477eedbd
|
|
@ -306,8 +306,8 @@ if __name__ == "__main__":
|
|||
"""
|
||||
import time
|
||||
# init compute
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# load the model and tokenizer
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user