From c0dbf1f3fff10ef9d1a50e14a6188e04506251b6 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Wed, 25 Mar 2026 20:19:14 +0000
Subject: [PATCH] use COMPUTE_DTYPE-aware cast in Muon polar express step

The bf16 cast is intentional for speed on Hopper+ GPUs, but should be
skipped on other platforms rather than blindly applied. fp16 is unstable
here due to its limited exponent range, and fp32 platforms don't benefit
from the cast. Now: bf16 when COMPUTE_DTYPE is bf16, no cast otherwise.

Inspired by PR #667.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 nanochat/optim.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/nanochat/optim.py b/nanochat/optim.py
index 0ee2e27..56e85e1 100644
--- a/nanochat/optim.py
+++ b/nanochat/optim.py
@@ -10,6 +10,7 @@ Further contributions from @karpathy and @chrisjmccormick.
 import torch
 import torch.distributed as dist
 from torch import Tensor
+from nanochat.common import COMPUTE_DTYPE
 
 # -----------------------------------------------------------------------------
 """
@@ -112,7 +113,8 @@ def muon_step_fused(
     g = stacked_grads.lerp_(momentum_buffer, momentum)
 
     # Polar express
-    X = g.bfloat16()
+    # Cast to bf16 for speed when available; skip cast otherwise (fp16 is unstable here due to limited exponent range)
+    X = g.bfloat16() if COMPUTE_DTYPE == torch.bfloat16 else g
     X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
     if g.size(-2) > g.size(-1): # Tall matrix
         for a, b, c in polar_express_coeffs[:ns_steps]: