use COMPUTE_DTYPE-aware cast in Muon polar express step

The bf16 cast is intentional for speed on Hopper+ GPUs, but should be
skipped on other platforms rather than blindly applied. fp16 is unstable
here due to its limited exponent range, and fp32 platforms don't benefit
from the cast. Now: bf16 when COMPUTE_DTYPE is bf16, no cast otherwise.

Inspired by PR #667.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Andrej Karpathy 2026-03-25 20:19:14 +00:00
parent 4e1694cc95
commit c0dbf1f3ff

@@ -10,6 +10,7 @@ Further contributions from @karpathy and @chrisjmccormick.
 import torch
 import torch.distributed as dist
 from torch import Tensor
+from nanochat.common import COMPUTE_DTYPE
 # -----------------------------------------------------------------------------
 """
@@ -112,7 +113,8 @@ def muon_step_fused(
     g = stacked_grads.lerp_(momentum_buffer, momentum)
     # Polar express
-    X = g.bfloat16()
+    # Cast to bf16 for speed when available; skip cast otherwise (fp16 is unstable here due to limited exponent range)
+    X = g.bfloat16() if COMPUTE_DTYPE == torch.bfloat16 else g
     X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
     if g.size(-2) > g.size(-1): # Tall matrix
         for a, b, c in polar_express_coeffs[:ns_steps]:
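
To make the exponent-range argument from the commit message concrete, here is a small pure-Python sketch (the `max_finite` helper is illustrative, not part of nanochat) computing the largest finite value of an IEEE-style float format. fp16's 5 exponent bits cap it at 65504, which gradient-matrix norms can easily exceed; bf16 keeps fp32's 8 exponent bits and so keeps the full ~3.4e38 dynamic range, only giving up mantissa precision.

```python
def max_finite(mantissa_bits: int, exponent_bits: int) -> float:
    """Largest finite value of an IEEE-style binary float format."""
    bias = 2 ** (exponent_bits - 1) - 1
    # the all-ones exponent is reserved for inf/nan, so max usable
    # biased exponent is 2**exponent_bits - 2
    max_exp = (2 ** exponent_bits - 2) - bias
    return (2 - 2 ** -mantissa_bits) * 2.0 ** max_exp

print(max_finite(10, 5))   # fp16: 65504.0 -- easy to overflow
print(max_finite(7, 8))    # bf16: ~3.39e38, same dynamic range as fp32
print(max_finite(23, 8))   # fp32: ~3.40e38
```

This is why the cast is skipped rather than downgraded to fp16 when bf16 is unavailable: fp16 trades away exactly the exponent range this normalization step relies on.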