diff --git a/tests/test_optim_bf16.py b/tests/test_optim_bf16.py
new file mode 100644
index 00000000..dd63fdf8
--- /dev/null
+++ b/tests/test_optim_bf16.py
@@ -0,0 +1,71 @@
+"""
+Regression tests for mixed-dtype scalar / parameter handling in optim.py.
+
+These cover the MPS Metal Graph compiler crashes seen with
+NANOCHAT_DTYPE=bfloat16: scalar hyperparams (fp32) being multiplied with
+bf16 params (wte, value_embeds) failed with "mps.multiply requires same
+element type". CUDA implicitly promotes mixed-dtype operands; MPS doesn't.
+"""
+import pytest
+import torch
+
+from nanochat.optim import adamw_step_fused
+
+
+def _device():
+    if torch.backends.mps.is_available():
+        return torch.device("mps")
+    return torch.device("cpu")
+
+
+def _scalars():
+    """0-D fp32 scalar tensors matching what MuonAdamW.__init__ creates."""
+    return [
+        torch.tensor(1.0, dtype=torch.float32),  # step
+        torch.tensor(0.01, dtype=torch.float32),  # lr
+        torch.tensor(0.9, dtype=torch.float32),  # beta1
+        torch.tensor(0.999, dtype=torch.float32),  # beta2
+        torch.tensor(1e-8, dtype=torch.float32),  # eps
+        torch.tensor(0.01, dtype=torch.float32),  # wd
+    ]
+
+
+def _sync(device):
+    if device.type == "mps":
+        torch.mps.synchronize()
+
+
+def _run_adamw(p, grad):
+    exp_avg = torch.zeros_like(p)
+    exp_avg_sq = torch.zeros_like(p)
+    p_before = p.clone()
+    adamw_step_fused(p, grad, exp_avg, exp_avg_sq, *_scalars())
+    _sync(p.device)
+    return p_before
+
+
+def test_adamw_step_fused_bf16_param_with_fp32_scalars():
+    """Regression: adamw_step_fused must not crash when p is bf16 but the
+    scalar hyperparams are fp32. This is the standard nanochat config —
+    wte and value_embeds are cast to COMPUTE_DTYPE (bf16) to save memory,
+    while MuonAdamW's shared scalar tensors remain fp32."""
+    device = _device()
+    torch.manual_seed(0)
+    p = torch.randn(64, 32, dtype=torch.bfloat16, device=device)
+    grad = torch.randn_like(p)
+    p_before = _run_adamw(p, grad)
+    assert torch.isfinite(p).all(), "bf16 update produced non-finite values"
+    assert not torch.equal(p, p_before), "weight did not change after step"
+
+
+def test_adamw_step_fused_fp32_param_unchanged():
+    """The fp32 path must still work and produce a sensible update —
+    the dtype-cast patch should be a no-op when p is already fp32."""
+    device = _device()
+    torch.manual_seed(0)
+    p = torch.randn(64, 32, dtype=torch.float32, device=device)
+    grad = torch.randn_like(p)
+    p_before = _run_adamw(p, grad)
+    assert torch.isfinite(p).all()
+    delta = (p - p_before).norm().item()
+    assert 0 < delta < 10, f"unreasonable update magnitude: {delta}"