add MuonClip QK-Clip (--muon-qk-clip-tau) on top of upstream Run 6
Single-flag, minimal change. When tau > 0, the c_q/c_k weights are pulled into a
dedicated Muon group and rescaled after each Muon step so that the
Frobenius/sqrt(min_dim) spectral-norm estimate stays <= sqrt(tau). The default
tau=0 is a no-op, bit-identical to v73.
Reference: Kimi K2 paper (arxiv 2507.20534 §A). Caps the max attention logit at ~tau.
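In isolation, the rescale rule looks like this (a minimal sketch; the name
qk_clip_sketch is illustrative, the real helper added by this diff is
_apply_qk_clip below):

    import torch

    @torch.no_grad()
    def qk_clip_sketch(W: torch.Tensor, tau: float) -> None:
        # Cheap spectral-norm estimate: ||W||_F / sqrt(min(fan_out, fan_in)).
        spec_est = W.float().norm() / (min(W.shape[-2], W.shape[-1]) ** 0.5)
        target = tau ** 0.5
        if spec_est > target:
            # Shrink in place so the estimate lands back at sqrt(tau).
            W.mul_((target / spec_est).to(W.dtype))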
Files touched (3):
nanochat/optim.py: +_apply_qk_clip helper, called after MuonAdamW.step
and DistMuonAdamW.step
nanochat/gpt.py: +muon_qk_clip_tau arg in setup_optimizer; splits c_q/c_k
into a dedicated Muon group when tau > 0
scripts/base_train.py: +--muon-qk-clip-tau CLI arg, threaded to setup_optimizer
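Example invocation (illustrative; every flag other than --muon-qk-clip-tau is
assumed from the d22 run naming and nanochat's usual torchrun launch):

    torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=22 --muon-qk-clip-tau=100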
Validated overnight (private fork) at d22, 6000 iters:
v73 baseline: val_bpb 0.7242, CORE 0.2714, crosses GPT-2 CORE @ ~81 min
v198 (tau=100): val_bpb 0.7242, CORE 0.2731, crosses GPT-2 CORE @ ~80 min
All other candidate changes (warmdown, lr, warmup schedules) regressed; the tau
sweep (50/100/200) showed a sharp peak at tau=100.
Generalizes across model depths because it's a Muon optimizer-level fix, not a
recipe tweak.
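A quick way to check the tau=0 no-op claim at the optimizer level (sketch; uses
the helper exactly as introduced in this diff, with a hand-built group dict
mirroring the param_groups construction in setup_optimizer):

    import torch
    from nanochat.optim import _apply_qk_clip

    p = torch.randn(768, 768)
    ref = p.clone()
    _apply_qk_clip([dict(kind='muon', params=[p], is_qk=True, qk_tau=0.0)])
    assert torch.equal(p, ref)  # tau=0: weights untouched, bit-identical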
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
parent 0aaca56805
commit 119f567cda

nanochat/gpt.py
@@ -371,19 +371,30 @@ class GPT(nn.Module):
             'total': total,
         }
 
-    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
+    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5, muon_qk_clip_tau=0.0):
         model_dim = self.config.n_embd
         ddp, rank, local_rank, world_size = get_dist_info()
 
         # Separate out all parameters into groups
         matrix_params = list(self.transformer.h.parameters())
+        # When MuonClip QK-Clip is enabled, pull c_q/c_k out of matrix_params into a
+        # dedicated Muon group tagged for spectral-norm capping (Kimi K2 §A).
+        qk_params = []
+        if muon_qk_clip_tau > 0.0:
+            qk_param_ids = set()
+            for block in self.transformer.h:
+                qk_params.append(block.attn.c_q.weight)
+                qk_params.append(block.attn.c_k.weight)
+                qk_param_ids.add(id(block.attn.c_q.weight))
+                qk_param_ids.add(id(block.attn.c_k.weight))
+            matrix_params = [p for p in matrix_params if id(p) not in qk_param_ids]
         value_embeds_params = list(self.value_embeds.parameters())
         embedding_params = list(self.transformer.wte.parameters())
         lm_head_params = list(self.lm_head.parameters())
         resid_params = [self.resid_lambdas]
         x0_params = [self.x0_lambdas]
         smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
+        assert len(list(self.parameters())) == len(matrix_params) + len(qk_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
 
         # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
@@ -406,6 +417,15 @@ class GPT(nn.Module):
                 kind='muon', params=group_params, lr=matrix_lr,
                 momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
             ))
+        # Dedicated Muon group for QK params when MuonClip is enabled (Kimi K2 §A).
+        if qk_params:
+            for shape in sorted({p.shape for p in qk_params}):
+                group_params = [p for p in qk_params if p.shape == shape]
+                param_groups.append(dict(
+                    kind='muon', params=group_params, lr=matrix_lr,
+                    momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
+                    is_qk=True, qk_tau=muon_qk_clip_tau,
+                ))
 
         Factory = DistMuonAdamW if ddp else MuonAdamW
         optimizer = Factory(param_groups)

nanochat/optim.py
@@ -147,6 +147,43 @@ def muon_step_fused(
     mask = (g * stacked_params) >= 0
     stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
 
+# -----------------------------------------------------------------------------
+# MuonClip QK-Clip (Kimi K2, arxiv 2507.20534 §A)
+#
+# Muon orthogonalizes the W_Q and W_K updates each step, which can cause the
+# attention logits Q @ K^T to grow unbounded over training and blow up the
+# softmax. QK-Clip rescales W_Q and W_K *after* the Muon step so each weight's
+# spectral norm is <= sqrt(tau), capping the max attention logit magnitude at ~tau.
+#
+# We use the cheap Frobenius-over-sqrt(min_dim) estimate of the spectral norm.
+# It is a lower bound in general, so any triggered rescale is warranted; it is
+# close to tight when singular values are roughly equal, as Muon encourages.
+
+@torch.no_grad()
+def _apply_qk_clip(param_groups: list[dict]) -> None:
+    """Rescale Muon param groups marked `is_qk` so spectral norm <= sqrt(tau).
+    No-op if tau <= 0. Safe to call every step. Identical across DDP ranks since
+    qk weights are replicated post-gather."""
+    for group in param_groups:
+        if group.get('kind') != 'muon':
+            continue
+        if not group.get('is_qk', False):
+            continue
+        tau = float(group.get('qk_tau', 0.0))
+        if tau <= 0.0:
+            continue
+        target = tau ** 0.5
+        for p in group['params']:
+            if p.ndim < 2:
+                continue
+            frob = p.detach().float().norm()
+            min_dim = min(p.shape[-2], p.shape[-1])
+            spec_est = frob / (min_dim ** 0.5)
+            if spec_est > target:
+                scale = target / spec_est.clamp_min(1e-12)
+                p.data.mul_(scale.to(p.dtype))
+
+
 # -----------------------------------------------------------------------------
 # Single GPU version of the MuonAdamW optimizer.
 # Used mostly for reference, debugging and testing.
@@ -291,6 +328,8 @@ class MuonAdamW(torch.optim.Optimizer):
                 self._step_muon(group)
             else:
                 raise ValueError(f"Unknown optimizer kind: {group['kind']}")
+        # MuonClip QK-Clip after the Muon update. No-op unless qk_tau > 0.
+        _apply_qk_clip(self.param_groups)
 
 # -----------------------------------------------------------------------------
 # Distributed version of the MuonAdamW optimizer.
@@ -533,3 +572,6 @@ class DistMuonAdamW(torch.optim.Optimizer):
 
         # Phase 3: wait for gathers, copy back
         self._finish_gathers(gather_list)
+
+        # MuonClip QK-Clip after the Muon update. No-op unless qk_tau > 0.
+        _apply_qk_clip(self.param_groups)

scripts/base_train.py
@@ -63,6 +63,7 @@ parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning ra
 parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
 parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
 parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--muon-qk-clip-tau", type=float, default=0.0, help="MuonClip QK-Clip cap on max attention logit (Kimi K2, arxiv 2507.20534 §A). 0 = disabled. Typical: 100.")
 parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
 parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
 parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
@@ -313,6 +314,7 @@ optimizer = model.setup_optimizer(
     # Muon hyperparameters
     matrix_lr=args.matrix_lr * batch_lr_scale,
     weight_decay=weight_decay_scaled,
+    muon_qk_clip_tau=args.muon_qk_clip_tau,
 )
 
 if resuming: