Mirror of https://github.com/karpathy/nanochat.git, synced 2026-05-08 00:39:50 +00:00
optim: fix bf16 mixed-precision crashes on MPS
When NANOCHAT_DTYPE=bfloat16 is set on Apple Silicon (MPS), the
optimizer step crashes with:
'mps.multiply' op requires the same element type for all operands
CUDA implicitly promotes mixed-dtype operands; MPS hard-fails.
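The promotion asymmetry can be reproduced off-device with a minimal sketch (the names lr_t/wd_t mirror optim.py's shared 0-D scalar tensors; the shape of p is illustrative). On CPU and CUDA, a 0-D fp32 tensor combined with a dimensioned bf16 tensor silently resolves to bfloat16; on MPS the same expression raises instead:

```python
import torch

p = torch.ones(4, dtype=torch.bfloat16)  # stands in for a bf16 wte param
lr_t = torch.tensor(0.01)                # shared 0-D fp32 scalar, as in optim.py
wd_t = torch.tensor(0.1)

# On CPU/CUDA the dimensioned bf16 tensor wins type promotion, so the result
# is bfloat16; on an MPS tensor this same expression is the crash site.
out = p * (1 - lr_t * wd_t)
print(out.dtype)  # torch.bfloat16
```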
Two minimal fixes in optim.py:
1. adamw_step_fused: cast scalar hyperparams (lr, wd, betas, eps)
to p.dtype at the top of the function, then use the cast versions.
nanochat stores some params (wte, value_embeds) at COMPUTE_DTYPE
to save memory but the optimizer's shared scalar tensors are fp32.
First crash site was `p.mul_(1 - lr_t * wd_t)` on a bf16 wte.
2. muon_step_fused: cast g back to stacked_params.dtype after the
polar express loop. The loop intentionally runs in bf16 for speed
(`X = g.bfloat16()`); without the cast back, the subsequent
variance reduction mixes bf16 g with fp32 second_momentum_buffer.
Verified end-to-end on torch 2.9.1 (the upstream pin) on an M2 MPS device:
forward + backward + optimizer.step all pass under bf16. The fp32 path
is unchanged (loss values identical to master to 6 decimal places).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
parent 0aaca56805
commit bd5232cafc
@@ -36,16 +36,23 @@ def adamw_step_fused(
     All in one compiled graph to eliminate Python overhead between ops.
     The 0-D CPU tensors avoid recompilation when hyperparameter values change.
     """
+    # Cast scalar hyperparams to p.dtype. nanochat stores some params (wte,
+    # value_embeds) at COMPUTE_DTYPE (bf16/fp16) to save embedding memory; the
+    # shared scalar tensors here are fp32. CUDA implicitly promotes mixed-dtype
+    # operands but MPS hard-fails ("mps.multiply requires same element type"),
+    # so we cast once up front. No-op when p is already fp32.
+    dtype = p.dtype
+    lr_d, wd_d, beta1_d, beta2_d, eps_d = lr_t.to(dtype), wd_t.to(dtype), beta1_t.to(dtype), beta2_t.to(dtype), eps_t.to(dtype)
     # Weight decay (decoupled, applied before the update)
-    p.mul_(1 - lr_t * wd_t)
+    p.mul_(1 - lr_d * wd_d)
     # Update running averages (lerp_ is cleaner and fuses well)
-    exp_avg.lerp_(grad, 1 - beta1_t)
-    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
-    # Bias corrections
+    exp_avg.lerp_(grad, 1 - beta1_d)
+    exp_avg_sq.lerp_(grad.square(), 1 - beta2_d)
+    # Bias corrections (in scalar fp32, then cast back to dtype below)
     bias1 = 1 - beta1_t ** step_t
     bias2 = 1 - beta2_t ** step_t
     # Compute update and apply
-    denom = (exp_avg_sq / bias2).sqrt() + eps_t
+    denom = (exp_avg_sq / bias2.to(dtype)).sqrt() + eps_d
     step_size = lr_t / bias1
     p.add_(exp_avg / denom, alpha=-step_size)
 
@@ -126,7 +133,13 @@ def muon_step_fused(
         A = X @ X.mT
         B = b * A + c * (A @ A)
         X = a * X + B @ X
-    g = X
+    # Cast g back to the parameter dtype: the polar express loop above
+    # intentionally runs in bf16 for speed (X = g.bfloat16()), but the rest
+    # of the function (variance reduction, cautious update) needs g to match
+    # stacked_params and second_momentum_buffer dtypes. CUDA implicitly
+    # promotes mixed-dtype operands; MPS hard-fails. No-op when X.dtype
+    # already matches stacked_params.dtype.
+    g = X.to(stacked_params.dtype)
 
     # Variance reduction
+    beta2 = beta2_t.to(g.dtype)
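Fix 2 follows the same shape and can also be smoke-tested on CPU. A minimal sketch with illustrative shapes and hypothetical loop coefficients (the real polar express coefficients and iteration count live in optim.py): run the orthogonalization loop in bf16, then cast the result back to the parameter dtype before it mixes with the fp32 buffers.

```python
import torch

stacked_params = torch.randn(2, 4, 4)  # fp32, like the real stacked params
g = torch.randn(2, 4, 4)

X = g.bfloat16()          # the loop intentionally runs in bf16 for speed
for _ in range(2):        # stand-in for the polar express iterations
    A = X @ X.mT
    X = 0.5 * X + 0.1 * (A @ X)  # hypothetical coefficients, not nanochat's

g = X.to(stacked_params.dtype)  # the cast back that fixes the MPS crash

# With matching dtypes, the fp32 second-momentum update no longer mixes
# bf16 and fp32 operands.
second_momentum_buffer = torch.zeros_like(stacked_params)
second_momentum_buffer.lerp_(g.square(), 0.05)
```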