nanochat/scripts
Jason Kneen 16c37b7d1d fix: MPS/CPU compatibility for training and inference on Mac
Resolves all crashes and silent errors when running on Apple Silicon (MPS)
or CPU after merging upstream/master. Tested on torch 2.2.2 and 2.9.1.

nanochat/engine.py
- Replace signal.SIGALRM timeout with concurrent.futures.ThreadPoolExecutor
  so use_calculator() works from FastAPI worker threads (signal.SIGALRM
  handlers can only be installed from the main thread on Unix, so it is
  silently broken in any threaded web server)
- Guard torch.cuda.synchronize() behind device_type == 'cuda'
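
A minimal sketch of the thread-based timeout described above; the names
run_with_timeout and the single-worker pool are illustrative, not the exact
helper in engine.py:

```python
import concurrent.futures
import time

# A shared 1-worker pool; submitting to it works from any thread,
# unlike signal.SIGALRM, which only works on the Unix main thread.
_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

def run_with_timeout(fn, args=(), timeout=3.0):
    future = _pool.submit(fn, *args)
    try:
        return future.result(timeout=timeout)
    except concurrent.futures.TimeoutError:
        # Caveat: the worker thread keeps running after a timeout;
        # the timeout only unblocks the caller.
        return None
```

One design caveat worth noting: a timed-out task still occupies the worker
until it finishes, so a long-running stuck call can delay subsequent
submissions to the same pool.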

nanochat/gpt.py
- Extract init_rotary_embeddings() from init_weights() so checkpoint
  loading can restore non-persistent cos/sin buffers without
  re-randomising all weights
- Cast rotary cos/sin to bfloat16 on CUDA only (MPS bfloat16 requires
  torch>=2.4; float32 used on MPS/CPU)
- Update forward() dtype assertion to match device
- Add F.rms_norm fallback for torch<2.4 (rms_norm added in 2.4)
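
The F.rms_norm fallback could look like the sketch below, assuming an
unweighted RMSNorm and an illustrative eps value (both assumptions, not
taken from the diff):

```python
import torch
import torch.nn.functional as F

def rms_norm(x, eps=1e-6):
    # torch >= 2.4 ships F.rms_norm; on older torch (e.g. 2.2.2 on Mac)
    # fall back to the equivalent manual computation in float32.
    if hasattr(F, "rms_norm"):
        return F.rms_norm(x, (x.size(-1),), eps=eps)
    rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return (x.float() * rms).to(x.dtype)
```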

nanochat/optim.py
- _cuda_compile(): skip torch.compile(fullgraph=True) on MPS/CPU;
  return function unchanged so eager execution is used
- adamw_step_fused / muon_step_fused: move 0-D CPU scalar tensors to
  parameter device at start of function (cross-device ops crash in
  eager mode on MPS)
- muon_step_fused: use bfloat16 in polar express on CUDA only;
  fall back to float32 on MPS/CPU
- _step_muon: replace torch._foreach_copy_() with p.copy_(s) loop on
  non-CUDA (_foreach_copy_ not implemented on MPS in torch<2.4)

nanochat/flash_attention.py
- Probe SDPA for enable_gqa support at import time (added in torch 2.5;
  inspect.signature raises on C builtins in older Python/torch)
- Fall back to manual KV head repetition via repeat_interleave when
  enable_gqa is unavailable
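
The probe-and-fallback pattern can be sketched like this; the function
names and tensor layout (B, H, T, D) are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

def _probe_enable_gqa():
    # inspect.signature can raise on C builtins, so probe by actually
    # calling SDPA with enable_gqa on tiny tensors; any failure is
    # treated as "kwarg unsupported" (torch < 2.5).
    try:
        q = torch.zeros(1, 2, 1, 4)
        kv = torch.zeros(1, 1, 1, 4)
        F.scaled_dot_product_attention(q, kv, kv, enable_gqa=True)
        return True
    except (TypeError, RuntimeError):
        return False

HAS_ENABLE_GQA = _probe_enable_gqa()

def gqa_attention(q, k, v):
    # q: (B, Hq, T, D); k, v: (B, Hkv, T, D) with Hq a multiple of Hkv.
    if HAS_ENABLE_GQA:
        return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    # Fallback: materialise repeated KV heads so plain SDPA sees
    # matching head counts on every device.
    rep = q.size(1) // k.size(1)
    k = k.repeat_interleave(rep, dim=1)
    v = v.repeat_interleave(rep, dim=1)
    return F.scaled_dot_product_attention(q, k, v)
```

Probing once at import time keeps the per-call path branch-cheap, at the
cost of a tiny throwaway SDPA call on startup.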

nanochat/checkpoint_manager.py
- Call model.init_rotary_embeddings() instead of model.init_weights()
  after load_state_dict() to restore non-persistent rotary buffers
  without clobbering loaded weights

scripts/base_train.py
- Guard torch.compile(model) behind device_type == 'cuda'
- Set mfu = None on non-CUDA instead of computing 0/inf = 0.00%
- Handle mfu is None in end-of-run report
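
The None-handling in the report could be as simple as the sketch below
(format_mfu is a hypothetical helper name, not taken from the script):

```python
def format_mfu(mfu):
    # With no published peak-FLOPS figure for MPS/CPU, mfu is None rather
    # than a misleading 0.00%; render it accordingly in the run report.
    return "n/a" if mfu is None else f"{100 * mfu:.2f}%"
```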

tests/test_mps_compat.py (new)
- 16 tests covering every fix; all pass on MPS (torch 2.2.2 and 2.9.1)
2026-02-22 15:50:11 +00:00
base_eval.py small touchups to the eval script, re-order items etc, cosmetic 2026-02-03 21:03:42 +00:00
base_train.py fix: MPS/CPU compatibility for training and inference on Mac 2026-02-22 15:50:11 +00:00
chat_cli.py remove leftover mid references (#491) 2026-02-02 08:33:46 -08:00
chat_eval.py remove leftover mid references (#491) 2026-02-02 08:33:46 -08:00
chat_rl.py remove leftover mid references (#491) 2026-02-02 08:33:46 -08:00
chat_sft.py tune the data mixture a bit, load optimizer by default when SFT. These were confirmed to be best settings from sweeps of sft 2026-02-18 15:49:18 +00:00
chat_web.py remove leftover mid references (#491) 2026-02-02 08:33:46 -08:00
tok_eval.py initial commit 2025-10-13 06:49:24 -07:00
tok_train.py quick fix to not OOM main speedrun script 2026-01-26 22:31:42 +00:00