diff --git a/nanochat/optim.py b/nanochat/optim.py index 0ee2e27..56e85e1 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -10,6 +10,7 @@ Further contributions from @karpathy and @chrisjmccormick. import torch import torch.distributed as dist from torch import Tensor +from nanochat.common import COMPUTE_DTYPE # ----------------------------------------------------------------------------- """ @@ -112,7 +113,8 @@ def muon_step_fused( g = stacked_grads.lerp_(momentum_buffer, momentum) # Polar express - X = g.bfloat16() + # Cast to bf16 for speed when available; skip cast otherwise (fp16 is unstable here due to limited exponent range) + X = g.bfloat16() if COMPUTE_DTYPE == torch.bfloat16 else g X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6) if g.size(-2) > g.size(-1): # Tall matrix for a, b, c in polar_express_coeffs[:ns_steps]: