Resolves all crashes and silent errors when running on Apple Silicon (MPS)
or CPU after merging upstream/master. Tested on torch 2.2.2 and 2.9.1.
nanochat/engine.py
- Replace the signal.SIGALRM timeout with a concurrent.futures.ThreadPoolExecutor
so use_calculator() works from FastAPI worker threads (SIGALRM handlers can
only be installed from the Unix main thread, so the old timeout was silently
broken in any threaded web server)
- Guard torch.cuda.synchronize() behind device_type == 'cuda'
nanochat/gpt.py
- Extract init_rotary_embeddings() from init_weights() so checkpoint
loading can restore non-persistent cos/sin buffers without
re-randomising all weights
- Cast rotary cos/sin to bfloat16 on CUDA only (MPS bfloat16 requires
torch>=2.4; float32 used on MPS/CPU)
- Update forward() dtype assertion to match device
- Add F.rms_norm fallback for torch<2.4 (rms_norm added in 2.4)
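The rms_norm fallback amounts to this kind of version-gated wrapper (a minimal sketch assuming an unweighted norm; the function name is illustrative, not the actual gpt.py code):

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # F.rms_norm landed in torch 2.4; on older versions fall back to the
    # arithmetically equivalent manual formula.
    if hasattr(F, "rms_norm"):
        return F.rms_norm(x, (x.size(-1),), eps=eps)
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
```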
nanochat/optim.py
- _cuda_compile(): skip torch.compile(fullgraph=True) on MPS/CPU;
return function unchanged so eager execution is used
- adamw_step_fused / muon_step_fused: move 0-D CPU scalar tensors to
parameter device at start of function (cross-device ops crash in
eager mode on MPS)
- muon_step_fused: use bfloat16 in polar express on CUDA only;
fall back to float32 on MPS/CPU
- _step_muon: replace torch._foreach_copy_() with p.copy_(s) loop on
non-CUDA (_foreach_copy_ not implemented on MPS in torch<2.4)
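The two eager-mode fixes follow this shape (hypothetical helper names, sketched for illustration rather than copied from optim.py):

```python
import torch

def to_param_device(t: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # 0-D CPU scalar tensors (e.g. a step counter or lr tensor) crash
    # cross-device arithmetic in eager mode on MPS; move them up front.
    return t if t.device == p.device else t.to(p.device)

def copy_params_(params, srcs, device_type: str):
    # torch._foreach_copy_ is a fused fast path; on MPS (torch < 2.4)
    # it is not implemented, so fall back to a plain per-tensor loop.
    if device_type == "cuda":
        torch._foreach_copy_(params, srcs)
    else:
        for p, s in zip(params, srcs):
            p.copy_(s)
```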
nanochat/flash_attention.py
- Probe SDPA for enable_gqa support at import time (the kwarg was added in
torch 2.5; inspect.signature() raises on C-implemented builtins in older
Python/torch, so the signature cannot be introspected directly)
- Fall back to manual KV head repetition via repeat_interleave when
enable_gqa is unavailable
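A minimal sketch of the probe-and-fallback, assuming (B, H, T, D) tensor layout; the function names are illustrative:

```python
import torch
import torch.nn.functional as F

def _probe_enable_gqa() -> bool:
    # Probe with a tiny real call instead of inspect.signature():
    # TypeError means the kwarg is unknown (torch < 2.5).
    try:
        q = torch.zeros(1, 2, 1, 4)
        kv = torch.zeros(1, 1, 1, 4)
        F.scaled_dot_product_attention(q, kv, kv, enable_gqa=True)
        return True
    except TypeError:
        return False

HAS_ENABLE_GQA = _probe_enable_gqa()

def sdpa_gqa(q, k, v):
    if HAS_ENABLE_GQA:
        return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    # Manual GQA: repeat each KV head to cover its group of query heads,
    # (B, n_kv_heads, T, D) -> (B, n_q_heads, T, D).
    n_rep = q.size(1) // k.size(1)
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    return F.scaled_dot_product_attention(q, k, v)
```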
nanochat/checkpoint_manager.py
- Call model.init_rotary_embeddings() instead of model.init_weights()
after load_state_dict() to restore non-persistent rotary buffers
without clobbering loaded weights
scripts/base_train.py
- Guard torch.compile(model) behind device_type == 'cuda'
- Set mfu = None on non-CUDA instead of dividing by an unknown peak-FLOPs
figure and reporting a meaningless 0.00%
- Handle mfu is None in end-of-run report
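The None handling in the report reduces to a guard like this (a sketch assuming mfu is a fraction; the helper name is hypothetical):

```python
def format_mfu(mfu):
    # Peak-FLOPs figures are only known for CUDA GPUs; elsewhere mfu is
    # None and the report prints a placeholder instead of a bogus 0.00%.
    return f"{mfu * 100:.2f}%" if mfu is not None else "n/a"
```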
tests/test_mps_compat.py (new)
- 16 tests covering every fix; all pass on MPS (torch 2.2.2 and 2.9.1)
Store quantized input/weight and their inverse scales in _Float8Matmul ctx to avoid re-quantization in backward and reduce saved-activation memory without changing numerics.