fix: pass device_type to compute_init in engine.__main__ (#451)

When running engine.py directly on non-GPU devices (CPU, MPS),
compute_init() needs the device_type parameter to initialize correctly.
This fixes failures on machines without CUDA support.
This commit is contained in:
xiayan0118 2026-01-19 17:19:51 -08:00 committed by GitHub
parent 63bb5831e2
commit 6a477eedbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -306,8 +306,8 @@ if __name__ == "__main__":
"""
import time
# init compute
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# load the model and tokenizer