diff --git a/nanochat/optim.py b/nanochat/optim.py index 56e85e14..8ff5a9e1 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -36,16 +36,23 @@ def adamw_step_fused( All in one compiled graph to eliminate Python overhead between ops. The 0-D CPU tensors avoid recompilation when hyperparameter values change. """ + # Cast scalar hyperparams to p.dtype. nanochat stores some params (wte, + # value_embeds) at COMPUTE_DTYPE (bf16/fp16) to save embedding memory; the + # shared scalar tensors here are fp32. CUDA implicitly promotes mixed-dtype + # operands but MPS hard-fails ("mps.multiply requires same element type"), + # so we cast once up front. No-op when p is already fp32. + dtype = p.dtype + lr_d, wd_d, beta1_d, beta2_d, eps_d = lr_t.to(dtype), wd_t.to(dtype), beta1_t.to(dtype), beta2_t.to(dtype), eps_t.to(dtype) # Weight decay (decoupled, applied before the update) - p.mul_(1 - lr_t * wd_t) + p.mul_(1 - lr_d * wd_d) # Update running averages (lerp_ is cleaner and fuses well) - exp_avg.lerp_(grad, 1 - beta1_t) - exp_avg_sq.lerp_(grad.square(), 1 - beta2_t) - # Bias corrections + exp_avg.lerp_(grad, 1 - beta1_d) + exp_avg_sq.lerp_(grad.square(), 1 - beta2_d) + # Bias corrections (in scalar fp32, then cast back to dtype below) bias1 = 1 - beta1_t ** step_t bias2 = 1 - beta2_t ** step_t # Compute update and apply - denom = (exp_avg_sq / bias2).sqrt() + eps_t + denom = (exp_avg_sq / bias2.to(dtype)).sqrt() + eps_d step_size = lr_t / bias1 p.add_(exp_avg / denom, alpha=-step_size) @@ -126,7 +133,13 @@ def muon_step_fused( A = X @ X.mT B = b * A + c * (A @ A) X = a * X + B @ X - g = X + # Cast g back to the parameter dtype: the polar express loop above + # intentionally runs in bf16 for speed (X = g.bfloat16()), but the rest + # of the function (variance reduction, cautious update) needs g to match + # stacked_params and second_momentum_buffer dtypes. CUDA implicitly + # promotes mixed-dtype operands; MPS hard-fails. No-op when X.dtype + # already matches stacked_params.dtype. + g = X.to(stacked_params.dtype) # Variance reduction beta2 = beta2_t.to(g.dtype)