From bd5232cafccf50f9f0c7516670e5208cb2ee11a4 Mon Sep 17 00:00:00 2001 From: Matt Parrett Date: Thu, 30 Apr 2026 20:07:01 -0700 Subject: [PATCH 1/2] optim: fix bf16 mixed-precision crashes on MPS When NANOCHAT_DTYPE=bfloat16 is set on Apple Silicon (MPS), the optimizer step crashes with: 'mps.multiply' op requires the same element type for all operands CUDA implicitly promotes mixed-dtype operands; MPS hard-fails. Two minimal fixes in optim.py: 1. adamw_step_fused: cast scalar hyperparams (lr, wd, betas, eps) to p.dtype at the top of the function, then use the cast versions. nanochat stores some params (wte, value_embeds) at COMPUTE_DTYPE to save memory but the optimizer's shared scalar tensors are fp32. First crash site was `p.mul_(1 - lr_t * wd_t)` on a bf16 wte. 2. muon_step_fused: cast g back to stacked_params.dtype after the polar express loop. The loop intentionally runs in bf16 for speed (`X = g.bfloat16()`); without the cast back, the subsequent variance reduction mixes bf16 g with fp32 second_momentum_buffer. Verified end-to-end on torch 2.9.1 (the upstream pin) on M2 MPS: forward + backward + optimizer.step all pass under bf16. Fp32 path is unchanged (identical loss values to master to 6 decimals). Co-Authored-By: Claude Opus 4.7 (1M context) --- nanochat/optim.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/nanochat/optim.py b/nanochat/optim.py index 56e85e14..8ff5a9e1 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -36,16 +36,23 @@ def adamw_step_fused( All in one compiled graph to eliminate Python overhead between ops. The 0-D CPU tensors avoid recompilation when hyperparameter values change. """ + # Cast scalar hyperparams to p.dtype. nanochat stores some params (wte, + # value_embeds) at COMPUTE_DTYPE (bf16/fp16) to save embedding memory; the + # shared scalar tensors here are fp32. CUDA implicitly promotes mixed-dtype + # operands but MPS hard-fails ("mps.multiply requires same element type"), + # so we cast once up front. No-op when p is already fp32. + dtype = p.dtype + lr_d, wd_d, beta1_d, beta2_d, eps_d = lr_t.to(dtype), wd_t.to(dtype), beta1_t.to(dtype), beta2_t.to(dtype), eps_t.to(dtype) # Weight decay (decoupled, applied before the update) - p.mul_(1 - lr_t * wd_t) + p.mul_(1 - lr_d * wd_d) # Update running averages (lerp_ is cleaner and fuses well) - exp_avg.lerp_(grad, 1 - beta1_t) - exp_avg_sq.lerp_(grad.square(), 1 - beta2_t) - # Bias corrections + exp_avg.lerp_(grad, 1 - beta1_d) + exp_avg_sq.lerp_(grad.square(), 1 - beta2_d) + # Bias corrections (in scalar fp32, then cast back to dtype below) bias1 = 1 - beta1_t ** step_t bias2 = 1 - beta2_t ** step_t # Compute update and apply - denom = (exp_avg_sq / bias2).sqrt() + eps_t + denom = (exp_avg_sq / bias2.to(dtype)).sqrt() + eps_d step_size = lr_t / bias1 p.add_(exp_avg / denom, alpha=-step_size) @@ -126,7 +133,13 @@ def muon_step_fused( A = X @ X.mT B = b * A + c * (A @ A) X = a * X + B @ X - g = X + # Cast g back to the parameter dtype: the polar express loop above + # intentionally runs in bf16 for speed (X = g.bfloat16()), but the rest + # of the function (variance reduction, cautious update) needs g to match + # stacked_params and second_momentum_buffer dtypes. CUDA implicitly + # promotes mixed-dtype operands; MPS hard-fails. No-op when X.dtype + # already matches stacked_params.dtype. + g = X.to(stacked_params.dtype) # Variance reduction beta2 = beta2_t.to(g.dtype) From 45aa6e2de2b506e9f5f412c85aa33ab372bcc9ff Mon Sep 17 00:00:00 2001 From: Matt Parrett Date: Thu, 30 Apr 2026 20:19:36 -0700 Subject: [PATCH 2/2] tests: regression test for adamw_step_fused with bf16 params + fp32 scalars Covers the MPS Metal Graph compiler crash that motivated the fix: adamw_step_fused crashed when p was bf16 (the standard nanochat config for wte/value_embeds) but the optimizer's shared scalar hyperparameters were fp32. Two tests: - test_adamw_step_fused_bf16_param_with_fp32_scalars: smoke test on bf16 path, verifies no crash and a finite weight update. - test_adamw_step_fused_fp32_param_unchanged: confirms fp32 path still produces a sensible update (the dtype-cast patch is a no-op when source dtype matches target). Both tests run on CPU (default) or MPS (when available). Muon's mixed-dtype path is gated on the COMPUTE_DTYPE module constant (set from NANOCHAT_DTYPE env var at import time), which is awkward to exercise in a unit test without subprocess; the muon fix is covered by manual end-to-end testing on M2 + bf16 instead. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_optim_bf16.py | 71 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 tests/test_optim_bf16.py diff --git a/tests/test_optim_bf16.py b/tests/test_optim_bf16.py new file mode 100644 index 00000000..dd63fdf8 --- /dev/null +++ b/tests/test_optim_bf16.py @@ -0,0 +1,71 @@ +""" +Regression tests for mixed-dtype scalar / parameter handling in optim.py. + +These cover the MPS Metal Graph compiler crashes seen with +NANOCHAT_DTYPE=bfloat16: scalar hyperparams (fp32) being multiplied with +bf16 params (wte, value_embeds) failed with "mps.multiply requires same +element type". CUDA implicitly promotes mixed-dtype operands; MPS doesn't. +""" +import pytest +import torch + +from nanochat.optim import adamw_step_fused + + +def _device(): + if torch.backends.mps.is_available(): + return torch.device("mps") + return torch.device("cpu") + + +def _scalars(): + """0-D fp32 scalar tensors matching what MuonAdamW.__init__ creates.""" + return [ + torch.tensor(1.0, dtype=torch.float32), # step + torch.tensor(0.01, dtype=torch.float32), # lr + torch.tensor(0.9, dtype=torch.float32), # beta1 + torch.tensor(0.999, dtype=torch.float32), # beta2 + torch.tensor(1e-8, dtype=torch.float32), # eps + torch.tensor(0.01, dtype=torch.float32), # wd + ] + + +def _sync(device): + if device.type == "mps": + torch.mps.synchronize() + + +def _run_adamw(p, grad): + exp_avg = torch.zeros_like(p) + exp_avg_sq = torch.zeros_like(p) + p_before = p.clone() + adamw_step_fused(p, grad, exp_avg, exp_avg_sq, *_scalars()) + _sync(p.device) + return p_before + + +def test_adamw_step_fused_bf16_param_with_fp32_scalars(): + """Regression: adamw_step_fused must not crash when p is bf16 but the + scalar hyperparams are fp32. This is the standard nanochat config — + wte and value_embeds are cast to COMPUTE_DTYPE (bf16) to save memory, + while MuonAdamW's shared scalar tensors remain fp32.""" + device = _device() + torch.manual_seed(0) + p = torch.randn(64, 32, dtype=torch.bfloat16, device=device) + grad = torch.randn_like(p) + p_before = _run_adamw(p, grad) + assert torch.isfinite(p).all(), "bf16 update produced non-finite values" + assert not torch.equal(p, p_before), "weight did not change after step" + + +def test_adamw_step_fused_fp32_param_unchanged(): + """The fp32 path must still work and produce a sensible update — + the dtype-cast patch should be a no-op when p is already fp32.""" + device = _device() + torch.manual_seed(0) + p = torch.randn(64, 32, dtype=torch.float32, device=device) + grad = torch.randn_like(p) + p_before = _run_adamw(p, grad) + assert torch.isfinite(p).all() + delta = (p - p_before).norm().item() + assert 0 < delta < 10, f"unreasonable update magnitude: {delta}"