mirror of https://github.com/karpathy/nanochat.git
synced 2026-05-12 10:50:16 +00:00
Merge 45aa6e2de2 into dc54a1a307
This commit is contained in: commit 1d232b3ec4
@@ -36,16 +36,23 @@ def adamw_step_fused(
     All in one compiled graph to eliminate Python overhead between ops.
     The 0-D CPU tensors avoid recompilation when hyperparameter values change.
     """
+    # Cast scalar hyperparams to p.dtype. nanochat stores some params (wte,
+    # value_embeds) at COMPUTE_DTYPE (bf16/fp16) to save embedding memory; the
+    # shared scalar tensors here are fp32. CUDA implicitly promotes mixed-dtype
+    # operands but MPS hard-fails ("mps.multiply requires same element type"),
+    # so we cast once up front. No-op when p is already fp32.
+    dtype = p.dtype
+    lr_d, wd_d, beta1_d, beta2_d, eps_d = lr_t.to(dtype), wd_t.to(dtype), beta1_t.to(dtype), beta2_t.to(dtype), eps_t.to(dtype)
     # Weight decay (decoupled, applied before the update)
-    p.mul_(1 - lr_t * wd_t)
+    p.mul_(1 - lr_d * wd_d)
     # Update running averages (lerp_ is cleaner and fuses well)
-    exp_avg.lerp_(grad, 1 - beta1_t)
-    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
-    # Bias corrections
+    exp_avg.lerp_(grad, 1 - beta1_d)
+    exp_avg_sq.lerp_(grad.square(), 1 - beta2_d)
+    # Bias corrections (in scalar fp32, then cast back to dtype below)
     bias1 = 1 - beta1_t ** step_t
     bias2 = 1 - beta2_t ** step_t
     # Compute update and apply
-    denom = (exp_avg_sq / bias2).sqrt() + eps_t
+    denom = (exp_avg_sq / bias2.to(dtype)).sqrt() + eps_d
     step_size = lr_t / bias1
     p.add_(exp_avg / denom, alpha=-step_size)
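For reference, a minimal sketch of the pattern this hunk adds (a hypothetical standalone snippet, not the actual nanochat code; the function name decoupled_wd_step is made up): hyperparameters stay as shared 0-D fp32 tensors so torch.compile does not recompile when their values change, and they are cast to the parameter dtype once at the top of the step so every subsequent op is same-dtype, which MPS requires and CUDA/fp32 tolerates unchanged.

import torch

def decoupled_wd_step(p: torch.Tensor, lr_t: torch.Tensor, wd_t: torch.Tensor) -> None:
    # lr_t / wd_t are shared 0-D fp32 scalars; p may be bf16 (e.g. wte, value_embeds).
    lr_d = lr_t.to(p.dtype)   # no-op when p is already fp32
    wd_d = wd_t.to(p.dtype)
    p.mul_(1 - lr_d * wd_d)   # every operand now shares p.dtype: fine on MPS, unchanged on CUDA

p = torch.randn(8, 8, dtype=torch.bfloat16)
decoupled_wd_step(p, torch.tensor(0.01), torch.tensor(0.1))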
@@ -126,7 +133,13 @@ def muon_step_fused(
         A = X @ X.mT
         B = b * A + c * (A @ A)
         X = a * X + B @ X
-    g = X
+    # Cast g back to the parameter dtype: the polar express loop above
+    # intentionally runs in bf16 for speed (X = g.bfloat16()), but the rest
+    # of the function (variance reduction, cautious update) needs g to match
+    # stacked_params and second_momentum_buffer dtypes. CUDA implicitly
+    # promotes mixed-dtype operands; MPS hard-fails. No-op when X.dtype
+    # already matches stacked_params.dtype.
+    g = X.to(stacked_params.dtype)

     # Variance reduction
     beta2 = beta2_t.to(g.dtype)
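The same fix in isolation, as a hedged sketch with made-up names (orthogonalize_bf16) and placeholder iteration coefficients rather than nanochat's actual polar express schedule: the orthogonalization loop deliberately runs in bf16, and the result is cast back so the rest of the step sees the buffers' dtype.

import torch

def orthogonalize_bf16(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Newton-Schulz-style iteration run in bf16 for speed; a, b, c are placeholder coefficients.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = g / (g.norm() + 1e-7)        # normalize in g's dtype to keep the iteration tame
    X = X.bfloat16()                  # drop to bf16 for the matmul-heavy loop
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.to(g.dtype)              # cast back so later ops stay same-dtype (MPS requirement)

p_grad = torch.randn(16, 16)          # fp32 "parameter-shaped" gradient
update = orthogonalize_bf16(p_grad)
assert update.dtype == p_grad.dtype   # the result matches the input dtype again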
tests/test_optim_bf16.py (new file, 71 lines)
@@ -0,0 +1,71 @@
+"""
+Regression tests for mixed-dtype scalar / parameter handling in optim.py.
+
+These cover the MPS Metal Graph compiler crashes seen with
+NANOCHAT_DTYPE=bfloat16: scalar hyperparams (fp32) being multiplied with
+bf16 params (wte, value_embeds) failed with "mps.multiply requires same
+element type". CUDA implicitly promotes mixed-dtype operands; MPS doesn't.
+"""
+import pytest
+import torch
+
+from nanochat.optim import adamw_step_fused
+
+
+def _device():
+    if torch.backends.mps.is_available():
+        return torch.device("mps")
+    return torch.device("cpu")
+
+
+def _scalars():
+    """0-D fp32 scalar tensors matching what MuonAdamW.__init__ creates."""
+    return [
+        torch.tensor(1.0, dtype=torch.float32),    # step
+        torch.tensor(0.01, dtype=torch.float32),   # lr
+        torch.tensor(0.9, dtype=torch.float32),    # beta1
+        torch.tensor(0.999, dtype=torch.float32),  # beta2
+        torch.tensor(1e-8, dtype=torch.float32),   # eps
+        torch.tensor(0.01, dtype=torch.float32),   # wd
+    ]
+
+
+def _sync(device):
+    if device.type == "mps":
+        torch.mps.synchronize()
+
+
+def _run_adamw(p, grad):
+    exp_avg = torch.zeros_like(p)
+    exp_avg_sq = torch.zeros_like(p)
+    p_before = p.clone()
+    adamw_step_fused(p, grad, exp_avg, exp_avg_sq, *_scalars())
+    _sync(p.device)
+    return p_before
+
+
+def test_adamw_step_fused_bf16_param_with_fp32_scalars():
+    """Regression: adamw_step_fused must not crash when p is bf16 but the
+    scalar hyperparams are fp32. This is the standard nanochat config —
+    wte and value_embeds are cast to COMPUTE_DTYPE (bf16) to save memory,
+    while MuonAdamW's shared scalar tensors remain fp32."""
+    device = _device()
+    torch.manual_seed(0)
+    p = torch.randn(64, 32, dtype=torch.bfloat16, device=device)
+    grad = torch.randn_like(p)
+    p_before = _run_adamw(p, grad)
+    assert torch.isfinite(p).all(), "bf16 update produced non-finite values"
+    assert not torch.equal(p, p_before), "weight did not change after step"
+
+
+def test_adamw_step_fused_fp32_param_unchanged():
+    """The fp32 path must still work and produce a sensible update —
+    the dtype-cast patch should be a no-op when p is already fp32."""
+    device = _device()
+    torch.manual_seed(0)
+    p = torch.randn(64, 32, dtype=torch.float32, device=device)
+    grad = torch.randn_like(p)
+    p_before = _run_adamw(p, grad)
+    assert torch.isfinite(p).all()
+    delta = (p - p_before).norm().item()
+    assert 0 < delta < 10, f"unreasonable update magnitude: {delta}"
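Both tests pick their device through _device(): they exercise the MPS backend when it is available and quietly fall back to CPU otherwise, so the regression still gets run on non-Apple machines. They need nothing beyond a standard pytest invocation, e.g. pytest tests/test_optim_bf16.py -q.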