mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-03 22:25:27 +00:00
Resolves all crashes and silent errors when running on Apple Silicon (MPS) or CPU after merging upstream/master. Tested on torch 2.2.2 and 2.9.1.

**nanochat/engine.py**
- Replace the signal.SIGALRM timeout with a concurrent.futures.ThreadPoolExecutor so use_calculator() works from FastAPI worker threads (SIGALRM fires only in the main thread on Unix, so it is silently broken in any threaded web server)
- Guard torch.cuda.synchronize() behind device_type == 'cuda'

**nanochat/gpt.py**
- Extract init_rotary_embeddings() from init_weights() so checkpoint loading can restore the non-persistent cos/sin buffers without re-randomising all weights
- Cast the rotary cos/sin buffers to bfloat16 on CUDA only (bfloat16 on MPS requires torch>=2.4; float32 is used on MPS/CPU)
- Update the forward() dtype assertion to match the device
- Add an F.rms_norm fallback for torch<2.4 (rms_norm was added in 2.4)

**nanochat/optim.py**
- _cuda_compile(): skip torch.compile(fullgraph=True) on MPS/CPU and return the function unchanged, so eager execution is used
- adamw_step_fused / muon_step_fused: move 0-D CPU scalar tensors to the parameter device at the start of the function (cross-device ops crash in eager mode on MPS)
- muon_step_fused: use bfloat16 in the polar-express iteration on CUDA only; fall back to float32 on MPS/CPU
- _step_muon: replace torch._foreach_copy_() with a p.copy_(s) loop on non-CUDA devices (_foreach_copy_ is not implemented on MPS in torch<2.4)

**nanochat/flash_attention.py**
- Probe SDPA for enable_gqa support at import time (the argument was added in torch 2.5; inspect.signature raises on C builtins in older Python/torch)
- Fall back to manual KV head repetition via repeat_interleave when enable_gqa is unavailable

**nanochat/checkpoint_manager.py**
- Call model.init_rotary_embeddings() instead of model.init_weights() after load_state_dict() to restore the non-persistent rotary buffers without clobbering the loaded weights

**scripts/base_train.py**
- Guard torch.compile(model) behind device_type == 'cuda'
- Set mfu = None on non-CUDA instead of computing 0/inf = 0.00%
- Handle mfu is None in the end-of-run report
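The SIGALRM-to-executor change can be sketched as below. This is a minimal illustration of the pattern, not nanochat's actual API: `run_with_timeout` and its parameters are names invented here. The point is that a future's `result(timeout=...)` works from any thread, whereas `signal.alarm` never fires outside the Unix main thread.

```python
import concurrent.futures


def run_with_timeout(fn, args=(), timeout_s=3.0):
    """Run fn(*args) with a wall-clock timeout (illustrative sketch).

    signal.SIGALRM only works in the main thread of a Unix process, so it
    silently never fires inside FastAPI/uvicorn worker threads. A future
    with a timeout works from any thread.
    """
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = pool.submit(fn, *args)
    try:
        return future.result(timeout=timeout_s)
    except concurrent.futures.TimeoutError:
        return None  # treat a timeout as "no result", e.g. a failed calculation
    finally:
        # wait=False so a hung worker does not block the caller on shutdown
        pool.shutdown(wait=False)
```

One caveat of this approach: the timed-out thread keeps running in the background until its call returns, so the wrapped function should be side-effect free.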
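The F.rms_norm fallback computes x / sqrt(mean(x²) + eps), optionally scaled by a learned weight. A pure-Python reference of that formula (a stand-in to show the math; the real fallback would operate on tensors, e.g. via `x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)`):

```python
import math


def rms_norm_ref(xs, weight=None, eps=1e-6):
    """Reference RMSNorm over a 1-D list: x / sqrt(mean(x^2) + eps)."""
    mean_sq = sum(x * x for x in xs) / len(xs)
    inv = 1.0 / math.sqrt(mean_sq + eps)
    if weight is None:
        return [x * inv for x in xs]
    # elementwise learned scale, as in F.rms_norm's weight argument
    return [x * inv * w for x, w in zip(xs, weight)]
```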
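The compile guards in optim.py and base_train.py share one shape: compile on CUDA, pass the function through unchanged everywhere else. A hedged sketch (the helper name `maybe_compile` is invented here; torch is imported lazily so the non-CUDA path has no torch dependency at all):

```python
def maybe_compile(fn, device_type):
    """Compile fn only on CUDA; return it unchanged on MPS/CPU.

    torch.compile(fullgraph=True) is not reliably supported on MPS/CPU for
    these kernels, so eager execution is used there.
    """
    if device_type == "cuda":
        import torch
        return torch.compile(fn, fullgraph=True)
    return fn
```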
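The import-time SDPA probe reduces to a generic "does this callable accept this keyword?" check that must tolerate `inspect.signature` raising on C builtins. A sketch of that probe (the helper name is illustrative; nanochat would apply it to `torch.nn.functional.scaled_dot_product_attention` with `"enable_gqa"`):

```python
import inspect


def supports_kwarg(fn, name):
    """Return True if fn's signature exposes a parameter called name.

    inspect.signature raises ValueError/TypeError on some C builtins
    (including SDPA in older torch), in which case we conservatively
    report no support and take the fallback path.
    """
    try:
        return name in inspect.signature(fn).parameters
    except (ValueError, TypeError):
        return False
```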
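The enable_gqa fallback expands grouped KV heads so every query head sees its own KV copy. The real code does this with `tensor.repeat_interleave(n_rep, dim=head_dim)`; the list-based sketch below only illustrates the head-index mapping (each KV head repeated n_rep consecutive times):

```python
def repeat_kv_heads(kv_heads, n_rep):
    """Repeat each KV head n_rep times, mirroring repeat_interleave
    along the head dimension (pure-Python illustration, not the tensor op)."""
    return [h for h in kv_heads for _ in range(n_rep)]
```

With 2 KV heads and 6 query heads, n_rep = 3 and head order is preserved group by group.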
**tests/test_mps_compat.py** (new)
- 16 tests covering every fix; all pass on MPS (torch 2.2.2 and 2.9.1)
Files in scripts/:
- base_eval.py
- base_train.py
- chat_cli.py
- chat_eval.py
- chat_rl.py
- chat_sft.py
- chat_web.py
- tok_eval.py
- tok_train.py