Covers the MPS Metal Graph compiler crash that motivated the fix:
adamw_step_fused crashed when p was bf16 (the standard nanochat
config for wte/value_embeds) but the optimizer's shared scalar
hyperparameters were fp32. Two tests:
- test_adamw_step_fused_bf16_param_with_fp32_scalars: smoke test
on bf16 path, verifies no crash and a finite weight update.
- test_adamw_step_fused_fp32_param_unchanged: confirms fp32 path
still produces a sensible update (the dtype-cast patch is a
no-op when source dtype matches target).
Both tests run on CPU (default) or MPS (when available). Muon's
mixed-dtype path is gated on the COMPUTE_DTYPE module constant
(set from NANOCHAT_DTYPE env var at import time), which is
awkward to exercise in a unit test without a subprocess; the Muon
fix is covered by manual end-to-end testing on M2 + bf16 instead.
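For reference, a hedged sketch of the bf16 smoke test's shape (not the repo's exact code:
the real test calls adamw_step_fused from optim.py; the quoted crash-site line stands in
for that call here):
```python
import torch

def _device() -> str:
    return "mps" if torch.backends.mps.is_available() else "cpu"  # CPU default, MPS when available

def test_adamw_step_fused_bf16_param_with_fp32_scalars_sketch():
    device = _device()
    p = torch.randn(16, 16, dtype=torch.bfloat16, device=device)  # bf16 param, e.g. wte
    lr = torch.tensor(0.02, dtype=torch.float32, device=device)   # optimizer's shared fp32 scalars
    wd = torch.tensor(0.1, dtype=torch.float32, device=device)
    lr_t, wd_t = lr.to(p.dtype), wd.to(p.dtype)                   # the cast under test
    p.mul_(1 - lr_t * wd_t)                                       # the former first crash site
    assert torch.isfinite(p.float()).all()                        # finite weight update, no crash
```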
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When NANOCHAT_DTYPE=bfloat16 is set on Apple Silicon (MPS), the
optimizer step crashes with:
'mps.multiply' op requires the same element type for all operands
CUDA implicitly promotes mixed-dtype operands; MPS hard-fails.
Two minimal fixes in optim.py:
1. adamw_step_fused: cast scalar hyperparams (lr, wd, betas, eps)
to p.dtype at the top of the function, then use the cast versions.
nanochat stores some params (wte, value_embeds) at COMPUTE_DTYPE
to save memory but the optimizer's shared scalar tensors are fp32.
First crash site was `p.mul_(1 - lr_t * wd_t)` on a bf16 wte.
2. muon_step_fused: cast g back to stacked_params.dtype after the
polar express loop. The loop intentionally runs in bf16 for speed
(`X = g.bfloat16()`); without the cast back, the subsequent
variance reduction mixes bf16 g with fp32 second_momentum_buffer.
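A minimal sketch of where the two casts sit (function and variable names beyond lr_t/wd_t
are illustrative; the real functions in optim.py carry the full AdamW/Muon state):
```python
import torch

# (1) adamw_step_fused: cast the shared fp32 scalar hyperparams to p.dtype at the top,
#     then use only the cast versions. Betas/eps and the momentum buffers are omitted here.
def adamw_decay_sketch(p: torch.Tensor, lr: torch.Tensor, wd: torch.Tensor) -> None:
    lr_t, wd_t = lr.to(p.dtype), wd.to(p.dtype)  # no-op when p is already fp32
    p.mul_(1 - lr_t * wd_t)                      # formerly the first MPS crash site

# (2) muon_step_fused: the polar express loop runs in bf16; cast its result back to the
#     params' dtype before it meets the fp32 second_momentum_buffer.
def muon_cast_back_sketch(g: torch.Tensor, stacked_params: torch.Tensor) -> torch.Tensor:
    X = g.bfloat16()                     # intentional: bf16 for speed inside the loop
    # ... polar express iterations on X ...
    return X.to(stacked_params.dtype)    # keeps the variance-reduction math single-dtype
```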
Verified end-to-end on torch 2.9.1 (the upstream pin) on M2 MPS:
forward + backward + optimizer.step all pass under bf16. The fp32 path
is unchanged (loss values identical to master to 6 decimal places).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When swapping Float8Linear out for Linear in the disable_fp8 context
manager, constructing the replacement with device=fp8_module.weight.device
allocates new tensors on the GPU, causing an unnecessary VRAM spike
(~1GB for large models).
This fix constructs the replacement on device='meta' to avoid physical
memory allocation, then swaps in the existing weight tensor reference,
eliminating the VRAM spike during the evaluation phase.
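A hedged sketch of the meta-device swap (attribute handling is assumed; the actual code
lives in the disable_fp8 context manager):
```python
import torch.nn as nn

def plain_linear_from_fp8(fp8_module: nn.Module) -> nn.Linear:
    """Sketch: build an nn.Linear that reuses fp8_module's tensors without new GPU allocation."""
    out_features, in_features = fp8_module.weight.shape
    has_bias = getattr(fp8_module, "bias", None) is not None
    # device='meta' creates shape/dtype-only tensors, so no VRAM is touched here.
    linear = nn.Linear(in_features, out_features, bias=has_bias,
                       device="meta", dtype=fp8_module.weight.dtype)
    # Swap in references to the existing parameters instead of copying them.
    linear.weight = fp8_module.weight
    if has_bias:
        linear.bias = fp8_module.bias
    return linear
```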
Fixes issue #592
Co-authored-by: RoomWithOutRoof <roomwithoutroof@sparklab.ai>
The bf16 cast is intentional for speed on Hopper+ GPUs, but should be
skipped on other platforms rather than blindly applied. fp16 is unstable
here due to its limited exponent range, and fp32 platforms don't benefit
from the cast. Now: bf16 when COMPUTE_DTYPE is bf16, no cast otherwise.
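The gate, sketched under the assumption that COMPUTE_DTYPE is the module constant described
earlier (set from NANOCHAT_DTYPE at import time):
```python
import torch

COMPUTE_DTYPE = torch.bfloat16  # module-level constant, set from NANOCHAT_DTYPE at import time

def maybe_cast_for_polar_express(g: torch.Tensor) -> torch.Tensor:
    # bf16 only when the run is already bf16; fp32 runs skip the cast, and fp16 is
    # never used here (too little exponent range).
    if COMPUTE_DTYPE == torch.bfloat16:
        return g.bfloat16()
    return g
```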
Inspired by PR #667.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
New architectural features:
- Smear: mix previous token embedding into current position via learned
gate, providing cheap bigram-like info (works in training + KV cache)
- Backout: subtract learned fraction of mid-layer residual before logit
projection to remove low-level features
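One possible wiring of the two features, purely as illustration (the gate/fraction
parameterization and the mid-layer tap point are assumptions, not the exact implementation):
```python
import torch
import torch.nn as nn

class SmearBackoutSketch(nn.Module):
    """Illustrative only: shows where Smear and Backout act, not nanochat's exact code."""
    def __init__(self):
        super().__init__()
        self.smear_gate = nn.Parameter(torch.zeros(()))    # learned scalar gate (assumed form)
        self.backout_frac = nn.Parameter(torch.zeros(()))  # learned fraction (assumed form)

    def smear(self, emb: torch.Tensor) -> torch.Tensor:
        # Mix the previous token's embedding into the current position (cheap bigram-like info).
        # At decode time the previous embedding is already available, so this works with a KV cache too.
        prev = torch.cat([emb[:, :1], emb[:, :-1]], dim=1)  # shift right by one position
        return emb + torch.sigmoid(self.smear_gate) * prev

    def backout(self, resid_final: torch.Tensor, resid_mid: torch.Tensor) -> torch.Tensor:
        # Subtract a learned fraction of a mid-layer residual before the logit projection,
        # stripping low-level features from what the head sees.
        return resid_final - self.backout_frac * resid_mid
```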
Hyperparameter tuning:
- Muon momentum warmdown 0.97→0.90 during LR warmdown phase
- Non-uniform per-layer init: resid_lambdas 1.15→1.05, x0_lambdas 0.20→0.05
- c_fc init scale 0.4x, QK norm scale 1.2, sliding window seq_len/4
- Speedrun data:params ratio reduced to 8
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* printing the step count
* adding reply-only loss for chat (see the sketch after this list)
* using the mask produced by the tokeniser's render_conversation function
* undoing some changes
* putting back a comment that was accidentally removed; no functionality change
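A sketch of what the reply-only loss looks like with a per-token mask like the one
render_conversation produces (assuming 1 marks assistant-reply tokens; the exact mask
semantics are the tokeniser's):
```python
import torch
import torch.nn.functional as F

def reply_only_loss(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # logits: (B, T, V); targets: (B, T); mask: (B, T), 1 on assistant-reply tokens, 0 elsewhere.
    per_token = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                targets.reshape(-1), reduction="none")
    mask = mask.reshape(-1).float()
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)  # average over reply tokens only
```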