mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-08 16:59:59 +00:00
Covers the MPS Metal Graph compiler crash that motivated the fix: adamw_step_fused crashed when p was bf16 (the standard nanochat config for wte/value_embeds) but the optimizer's shared scalar hyperparameters were fp32.

Two tests:
- test_adamw_step_fused_bf16_param_with_fp32_scalars: smoke test on the bf16 path; verifies no crash and a finite weight update.
- test_adamw_step_fused_fp32_param_unchanged: confirms the fp32 path still produces a sensible update (the dtype-cast patch is a no-op when the source dtype matches the target).

Both tests run on CPU (default) or MPS (when available). Muon's mixed-dtype path is gated on the COMPUTE_DTYPE module constant (set from the NANOCHAT_DTYPE env var at import time), which is awkward to exercise in a unit test without a subprocess; the Muon fix is covered by manual end-to-end testing on M2 + bf16 instead.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
"""
|
|
Regression tests for mixed-dtype scalar / parameter handling in optim.py.

These cover the MPS Metal Graph compiler crashes seen with
NANOCHAT_DTYPE=bfloat16: scalar hyperparams (fp32) being multiplied with
bf16 params (wte, value_embeds) failed with "mps.multiply requires same
element type". CUDA implicitly promotes mixed-dtype operands; MPS doesn't.
|
|
"""
|
|
import pytest
|
|
import torch
|
|
|
|
from nanochat.optim import adamw_step_fused
|
|
|
|
|
|
def _device() -> torch.device:
    """Pick the device the tests in this module run on.

    Returns:
        ``torch.device("mps")`` when ``torch.backends.mps.is_available()``
        reports True, otherwise ``torch.device("cpu")``.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
|
|
|
|
|
|
def _scalars() -> list[torch.Tensor]:
    """0-D fp32 scalar tensors matching what MuonAdamW.__init__ creates.

    Returns:
        A new list of six 0-D ``torch.float32`` tensors, in the order
        ``step, lr, beta1, beta2, eps, wd``. No ``device`` argument is
        passed, so they are allocated on torch's default device regardless
        of where the parameter under test lives.
    """
    return [
        torch.tensor(1.0, dtype=torch.float32),    # step
        torch.tensor(0.01, dtype=torch.float32),   # lr
        torch.tensor(0.9, dtype=torch.float32),    # beta1
        torch.tensor(0.999, dtype=torch.float32),  # beta2
        torch.tensor(1e-8, dtype=torch.float32),   # eps
        torch.tensor(0.01, dtype=torch.float32),   # wd
    ]
|
|
|
|
|
|
def _sync(device: torch.device) -> None:
    """Call ``torch.mps.synchronize()`` when ``device`` is an MPS device.

    Does nothing for any other device type. Used after the optimizer step so
    that the assertions run against completed work.

    Args:
        device: The device the parameter under test lives on.
    """
    if device.type == "mps":
        torch.mps.synchronize()
|
|
|
|
|
|
def _run_adamw(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Run one ``adamw_step_fused`` step on ``p`` and return a pre-step copy.

    Both moment buffers start as zeros shaped like ``p``, and the
    hyperparameters come from ``_scalars()``. The callers compare ``p``
    against the returned clone, so they rely on ``adamw_step_fused`` updating
    ``p`` in place.

    Args:
        p: Parameter tensor handed to the optimizer step.
        grad: Gradient tensor for ``p``.

    Returns:
        A clone of ``p`` taken before the step was applied.
    """
    exp_avg = torch.zeros_like(p)
    exp_avg_sq = torch.zeros_like(p)
    p_before = p.clone()
    adamw_step_fused(p, grad, exp_avg, exp_avg_sq, *_scalars())
    # Make sure any queued MPS work has finished before the caller inspects p.
    _sync(p.device)
    return p_before
|
|
|
|
|
|
def test_adamw_step_fused_bf16_param_with_fp32_scalars():
    """Regression: adamw_step_fused must not crash when p is bf16 but the
    scalar hyperparams are fp32. This is the standard nanochat config —
    wte and value_embeds are cast to COMPUTE_DTYPE (bf16) to save memory,
    while MuonAdamW's shared scalar tensors remain fp32."""
    device = _device()
    torch.manual_seed(0)  # fixed seed so the random param/grad are reproducible
    p = torch.randn(64, 32, dtype=torch.bfloat16, device=device)
    grad = torch.randn_like(p)
    p_before = _run_adamw(p, grad)
    # Smoke checks only: the step ran, stayed finite, and actually moved p.
    assert torch.isfinite(p).all(), "bf16 update produced non-finite values"
    assert not torch.equal(p, p_before), "weight did not change after step"
|
|
|
|
|
|
def test_adamw_step_fused_fp32_param_unchanged():
    """The fp32 path must still work and produce a sensible update —
    the dtype-cast patch should be a no-op when p is already fp32."""
    device = _device()
    torch.manual_seed(0)  # fixed seed so the random param/grad are reproducible
    p = torch.randn(64, 32, dtype=torch.float32, device=device)
    grad = torch.randn_like(p)
    p_before = _run_adamw(p, grad)
    assert torch.isfinite(p).all()
    # The L2 norm of the change must be non-zero (p moved) yet bounded
    # (the step did not blow up).
    delta = (p - p_before).norm().item()
    assert 0 < delta < 10, f"unreasonable update magnitude: {delta}"
|