Resolves all crashes and silent errors when running on Apple Silicon (MPS)
or CPU after merging upstream/master. Tested on torch 2.2.2 and 2.9.1.
nanochat/engine.py
- Replace signal.SIGALRM timeout with concurrent.futures.ThreadPoolExecutor
so use_calculator() works from FastAPI worker threads (SIGALRM only
works in the Unix main thread, so it is silently broken in any
threaded web server)
- Guard torch.cuda.synchronize() behind device_type == 'cuda'
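The thread-based timeout can be sketched as follows; `run_with_timeout` is a hypothetical helper name, not the actual code in engine.py:

```python
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeout

def run_with_timeout(fn, args=(), timeout_s=1.0):
    """Run fn(*args) in a worker thread, returning None on timeout.

    Unlike signal.SIGALRM (Unix, main thread only), this works when the
    caller is itself a worker thread, e.g. a FastAPI request handler.
    """
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(fn, *args)
        try:
            return future.result(timeout=timeout_s)
        except FutureTimeout:
            return None  # treat a timeout as "no result", like the old alarm path
    finally:
        # Don't block on a stuck worker; the thread is simply abandoned.
        executor.shutdown(wait=False)
```

Note the trade-off: the timed-out worker thread keeps running to completion in the background, whereas SIGALRM interrupted it; for a short calculator eval that is acceptable.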
nanochat/gpt.py
- Extract init_rotary_embeddings() from init_weights() so checkpoint
loading can restore non-persistent cos/sin buffers without
re-randomising all weights
- Cast rotary cos/sin to bfloat16 on CUDA only (MPS bfloat16 requires
torch>=2.4; float32 used on MPS/CPU)
- Update forward() dtype assertion to match device
- Add F.rms_norm fallback for torch<2.4 (rms_norm added in 2.4)
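For torch<2.4 the fallback has to compute RMSNorm by hand (in torch terms, `x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)`). A minimal pure-Python illustration of the same formula, not the actual nanochat code:

```python
import math

def rms_norm(xs, eps=1e-6):
    # y_i = x_i / sqrt(mean_j(x_j^2) + eps), matching F.rms_norm
    # (no weight) over the last dimension.
    mean_sq = sum(x * x for x in xs) / len(xs)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [x * inv_rms for x in xs]
```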
nanochat/optim.py
- _cuda_compile(): skip torch.compile(fullgraph=True) on MPS/CPU;
return function unchanged so eager execution is used
- adamw_step_fused / muon_step_fused: move 0-D CPU scalar tensors to
parameter device at start of function (cross-device ops crash in
eager mode on MPS)
- muon_step_fused: use bfloat16 in polar express on CUDA only;
fall back to float32 on MPS/CPU
- _step_muon: replace torch._foreach_copy_() with p.copy_(s) loop on
non-CUDA (_foreach_copy_ not implemented on MPS in torch<2.4)
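The compile guard in _cuda_compile() amounts to something like this sketch (`maybe_compile` is an illustrative name; the CUDA branch assumes torch is importable):

```python
def maybe_compile(fn, device_type):
    # torch.compile(fullgraph=True) is only used on CUDA; on MPS/CPU
    # return the function unchanged so the optimizer step runs eagerly.
    if device_type != "cuda":
        return fn
    import torch  # only needed on the CUDA path
    return torch.compile(fn, fullgraph=True)
```

Running eagerly is also why the scalar-device fix matters: in eager mode on MPS, ops mixing 0-D CPU scalar tensors with MPS parameters can crash, so the scalars are moved to the parameter device once at function entry.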
nanochat/flash_attention.py
- Probe SDPA for enable_gqa support at import time (added in torch 2.5;
inspect.signature raises on C builtins in older Python/torch)
- Fall back to manual KV head repetition via repeat_interleave when
enable_gqa is unavailable
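The probe and fallback can be sketched like this (`kwarg_supported` and `repeat_kv_heads` are hypothetical names; the real module checks `F.scaled_dot_product_attention` on tensors):

```python
import inspect

def kwarg_supported(fn, name):
    # True/False when the signature is introspectable; None when it is
    # not (inspect.signature raises ValueError on many C builtins), in
    # which case the caller conservatively takes the fallback path.
    try:
        return name in inspect.signature(fn).parameters
    except (ValueError, TypeError):
        return None

def repeat_kv_heads(heads, n_rep):
    # Manual GQA fallback: expand each KV head n_rep times to match the
    # query heads, like k.repeat_interleave(n_rep, dim=1) on a tensor.
    return [h for h in heads for _ in range(n_rep)]
```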
nanochat/checkpoint_manager.py
- Call model.init_rotary_embeddings() instead of model.init_weights()
after load_state_dict() to restore non-persistent rotary buffers
without clobbering loaded weights
scripts/base_train.py
- Guard torch.compile(model) behind device_type == 'cuda'
- Set mfu = None on non-CUDA instead of computing 0/inf = 0.00%
- Handle mfu is None in end-of-run report
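The reporting change reduces to a guard like the following (`mfu_str` is an illustrative helper; the script formats inline):

```python
def mfu_str(mfu):
    # mfu is None off-CUDA, where no peak-FLOPs figure is available,
    # so report "n/a" instead of a meaningless 0.00%.
    return "n/a" if mfu is None else f"{mfu * 100:.2f}%"
```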
tests/test_mps_compat.py (new)
- 16 tests covering every fix; all pass on MPS (torch 2.2.2 and 2.9.1)
SDPA fallback now respects sliding window during single-token KV-cache
decode by slicing K/V to the last (window + 1) tokens.
Also simplifies mask construction for chunked inference so the
sliding window is applied correctly in that path as well.
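During single-token decode, the positions attendable under a sliding window are just the tail of the KV cache, so a slice replaces the mask. A hypothetical sketch over a plain list standing in for the cache:

```python
def window_tail(cache, window):
    # With one new query token and a sliding window of size `window`,
    # only the last (window + 1) cached positions can be attended
    # (window past tokens plus the current one), so slice K/V to that
    # tail instead of building an attention mask.
    return cache[-(window + 1):]
```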
Fixes #452
Co-Authored-By: Kartik Vashishta <kartikv776@gmail.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* test: add engine generation tests for expected invariants
- test_seed_reproducibility
- test_temperature_zero_determinism
- test_max_tokens_respected
- test_num_samples_count
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Fix temperature test
* add test for seed variation in sampling
Add test for seed variation in sampling with temperature > 0.
* Rename test for clarity
* Shorten assert msg
---------
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Previously, when generating multiple samples (num_samples > 1), the first
token after prefill was sampled once and broadcast to all rows, causing
all samples to start identically. Now the prefill logits are expanded to
num_samples and sampled independently for each row.
Also simplified the generation loop by moving the forward pass to the end
of the loop, eliminating the first_iteration flag and if/else branching.
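The fix can be illustrated with a pure-Python stand-in for multinomial sampling (the names and probability-vector interface are illustrative, not the engine's API):

```python
import random

def sample_index(probs, rng):
    # One multinomial draw over a probability vector.
    r, acc = rng.random(), 0.0
    for i, p in enumerate(probs):
        acc += p
        if r < acc:
            return i
    return len(probs) - 1

def first_tokens(prefill_probs, num_samples, seed=0):
    # Expand the single prefill distribution to num_samples rows and
    # sample each row independently, so samples can diverge at the very
    # first generated token instead of sharing one broadcast draw.
    rng = random.Random(seed)
    return [sample_index(prefill_probs, rng) for _ in range(num_samples)]
```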
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Performance varies by machine and load, making hard assertions flaky.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>