mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-26 22:55:16 +00:00
use COMPUTE_DTYPE-aware cast in Muon polar express step
The bf16 cast is intentional for speed on Hopper+ GPUs, but should be skipped on other platforms rather than blindly applied. fp16 is unstable here due to its limited exponent range, and fp32 platforms don't benefit from the cast. Now: bf16 when COMPUTE_DTYPE is bf16, no cast otherwise. Inspired by PR #667. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
4e1694cc95
commit
c0dbf1f3ff
|
|
@ -10,6 +10,7 @@ Further contributions from @karpathy and @chrisjmccormick.
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from nanochat.common import COMPUTE_DTYPE
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
"""
|
||||
|
|
@ -112,7 +113,8 @@ def muon_step_fused(
|
|||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
# Cast to bf16 for speed when available; skip cast otherwise (fp16 is unstable here due to limited exponent range)
|
||||
X = g.bfloat16() if COMPUTE_DTYPE == torch.bfloat16 else g
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
|
||||
if g.size(-2) > g.size(-1): # Tall matrix
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user