nanochat/tests
Jason Kneen 16c37b7d1d fix: MPS/CPU compatibility for training and inference on Mac
Resolves all crashes and silent errors when running on Apple Silicon (MPS)
or CPU after merging upstream/master. Tested on torch 2.2.2 and 2.9.1.

nanochat/engine.py
- Replace signal.SIGALRM timeout with concurrent.futures.ThreadPoolExecutor
  so use_calculator() works from FastAPI worker threads (SIGALRM fires only
  in the Unix main thread, so it is silently broken in any threaded web
  server)
- Guard torch.cuda.synchronize() behind device_type == 'cuda'

nanochat/gpt.py
- Extract init_rotary_embeddings() from init_weights() so checkpoint
  loading can restore non-persistent cos/sin buffers without
  re-randomising all weights
- Cast rotary cos/sin to bfloat16 on CUDA only (MPS bfloat16 requires
  torch>=2.4; float32 used on MPS/CPU)
- Update forward() dtype assertion to match device
- Add F.rms_norm fallback for torch<2.4 (rms_norm added in 2.4)
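A minimal sketch of the version guard for the rms_norm fallback; the wrapper name and eps default here are illustrative:

```python
import torch
import torch.nn.functional as F

if hasattr(F, "rms_norm"):
    # torch >= 2.4 ships a fused F.rms_norm
    def rms_norm(x, weight, eps=1e-6):
        return F.rms_norm(x, (x.size(-1),), weight=weight, eps=eps)
else:
    # manual fallback for torch < 2.4: x / sqrt(mean(x^2) + eps) * weight
    def rms_norm(x, weight, eps=1e-6):
        scale = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
        return (x.float() * scale).type_as(x) * weight
```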

nanochat/optim.py
- _cuda_compile(): skip torch.compile(fullgraph=True) on MPS/CPU;
  return function unchanged so eager execution is used
- adamw_step_fused / muon_step_fused: move 0-D CPU scalar tensors to
  parameter device at start of function (cross-device ops crash in
  eager mode on MPS)
- muon_step_fused: use bfloat16 in polar express on CUDA only;
  fall back to float32 on MPS/CPU
- _step_muon: replace torch._foreach_copy_() with p.copy_(s) loop on
  non-CUDA (_foreach_copy_ not implemented on MPS in torch<2.4)
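The optim.py changes share one pattern: branch on device_type and keep the fused/compiled path CUDA-only. A rough sketch, with helper names (maybe_compile, copy_params, to_param_device) invented for illustration:

```python
import torch

def maybe_compile(fn, device_type):
    # torch.compile(fullgraph=True) targets the CUDA/Inductor path;
    # on MPS/CPU return the function unchanged so it runs eagerly.
    if device_type == "cuda":
        return torch.compile(fn, fullgraph=True)
    return fn

def to_param_device(scalar, p):
    # 0-D CPU scalar tensors (e.g. an lr schedule value) must be moved
    # onto the parameter's device before eager elementwise ops on MPS.
    return scalar.to(p.device) if torch.is_tensor(scalar) else scalar

@torch.no_grad()
def copy_params(params, updates, device_type):
    # torch._foreach_copy_ is not implemented on MPS before torch 2.4,
    # so fall back to a plain per-tensor copy loop off CUDA.
    if device_type == "cuda":
        torch._foreach_copy_(params, updates)
    else:
        for p, s in zip(params, updates):
            p.copy_(s)
```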

nanochat/flash_attention.py
- Probe SDPA for enable_gqa support at import time (added in torch 2.5;
  inspect.signature raises on C builtins in older Python/torch)
- Fall back to manual KV head repetition via repeat_interleave when
  enable_gqa is unavailable
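The probe-then-fallback idea can be sketched like this, assuming the probe is done with a tiny dummy SDPA call rather than inspect.signature (function names here are illustrative):

```python
import torch
import torch.nn.functional as F

def sdpa_supports_enable_gqa():
    # inspect.signature raises on this C-implemented builtin in older
    # Python/torch, so probe with a tiny dummy call instead.
    q = torch.zeros(1, 2, 1, 4)   # 2 query heads
    kv = torch.zeros(1, 1, 1, 4)  # 1 KV head -> grouped-query shape
    try:
        F.scaled_dot_product_attention(q, kv, kv, enable_gqa=True)
        return True
    except TypeError:
        return False

_HAS_GQA = sdpa_supports_enable_gqa()

def gqa_attention(q, k, v):
    if _HAS_GQA:  # torch >= 2.5
        return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    # fallback: repeat each KV head to match the query head count
    n_rep = q.size(1) // k.size(1)
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    return F.scaled_dot_product_attention(q, k, v)
```

Both branches are mathematically equivalent; enable_gqa just avoids materialising the repeated KV tensors.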

nanochat/checkpoint_manager.py
- Call model.init_rotary_embeddings() instead of model.init_weights()
  after load_state_dict() to restore non-persistent rotary buffers
  without clobbering loaded weights

scripts/base_train.py
- Guard torch.compile(model) behind device_type == 'cuda'
- Set mfu = None on non-CUDA instead of computing 0/inf = 0.00%
- Handle mfu is None in end-of-run report
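A minimal sketch of the None-aware MFU reporting; the function names and the peak_flops default (an H100-class bf16 figure) are placeholders, not values from base_train.py:

```python
def mfu_percent(flops_done, dt, device_type, peak_flops=989e12):
    # Peak-FLOPs figures only exist for CUDA GPUs; on MPS/CPU report
    # None instead of dividing by a bogus peak and printing 0.00%.
    if device_type != "cuda":
        return None
    return 100 * flops_done / dt / peak_flops

def format_mfu(mfu):
    return "n/a" if mfu is None else f"{mfu:.2f}%"
```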

tests/test_mps_compat.py (new)
- 16 tests covering every fix; all pass on MPS (torch 2.2.2 and 2.9.1)
2026-02-22 15:50:11 +00:00
test_attention_fallback.py Fix SDPA KV-cache decode to respect sliding window (#456) 2026-01-30 17:32:12 +00:00
test_engine.py Fix MockModel's device definition (#535) 2026-02-17 16:03:46 -08:00
test_mps_compat.py fix: MPS/CPU compatibility for training and inference on Mac 2026-02-22 15:50:11 +00:00