add MuonClip QK-Clip (--muon-qk-clip-tau) on top of upstream Run 6

Single-flag minimal change. When tau > 0, c_q/c_k weights are pulled into a
dedicated Muon group and rescaled after each Muon step so that Frobenius/sqrt(min_dim)
spectral-norm estimate <= sqrt(tau). Default tau=0 = no-op = bit-identical to v73.

Reference: Kimi K2 paper (arxiv 2507.20534 §A). Caps max attention logit ~tau.

Files touched (3):
  nanochat/optim.py:     +_apply_qk_clip helper, called after MuonAdamW.step
                         and DistMuonAdamW.step
  nanochat/gpt.py:       +muon_qk_clip_tau arg in setup_optimizer; splits c_q/c_k
                         into a dedicated Muon group when tau > 0
  scripts/base_train.py: +--muon-qk-clip-tau CLI arg, threaded to setup_optimizer

Validated overnight (private fork) at d22 6000-iter:
  v73 baseline:        val_bpb 0.7242, CORE 0.2714, crosses GPT-2 CORE @ ~81 min
  v198 (tau=100):      val_bpb 0.7242, CORE 0.2731, crosses GPT-2 CORE @ ~80 min
  All other stacks (warmdown, lr, warmup) regressed; tau sweep (50/100/200) showed
  sharp peak at tau=100.

Generalizes across model depths because it's a Muon optimizer-level fix, not a
recipe tweak.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
gio 2026-04-25 16:52:03 -05:00
parent 0aaca56805
commit 119f567cda
3 changed files with 66 additions and 2 deletions

View File

@ -371,19 +371,30 @@ class GPT(nn.Module):
'total': total,
}
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5, muon_qk_clip_tau=0.0):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into groups
matrix_params = list(self.transformer.h.parameters())
# When MuonClip QK-Clip is enabled, pull c_q/c_k out of matrix_params into a
# dedicated Muon group tagged for spectral-norm capping (Kimi K2 §A).
qk_params = []
if muon_qk_clip_tau > 0.0:
qk_param_ids = set()
for block in self.transformer.h:
qk_params.append(block.attn.c_q.weight)
qk_params.append(block.attn.c_k.weight)
qk_param_ids.add(id(block.attn.c_q.weight))
qk_param_ids.add(id(block.attn.c_k.weight))
matrix_params = [p for p in matrix_params if id(p) not in qk_param_ids]
value_embeds_params = list(self.value_embeds.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
assert len(list(self.parameters())) == len(matrix_params) + len(qk_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -406,6 +417,15 @@ class GPT(nn.Module):
kind='muon', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
))
# Dedicated Muon group for QK params when MuonClip is enabled (Kimi K2 §A).
if qk_params:
for shape in sorted({p.shape for p in qk_params}):
group_params = [p for p in qk_params if p.shape == shape]
param_groups.append(dict(
kind='muon', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
is_qk=True, qk_tau=muon_qk_clip_tau,
))
Factory = DistMuonAdamW if ddp else MuonAdamW
optimizer = Factory(param_groups)

View File

@ -147,6 +147,43 @@ def muon_step_fused(
mask = (g * stacked_params) >= 0
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
# -----------------------------------------------------------------------------
# MuonClip QK-Clip (Kimi K2, arxiv 2507.20534 §A)
#
# Muon orthogonalizes W_Q and W_K each step, which can cause the attention logits
# Q @ K^T to grow unbounded over training and blow up softmax. QK-Clip rescales
# W_Q and W_K *after* the Muon step so that each spectral norm is bounded by
# sqrt(tau) — their product by tau — capping the max attention logit magnitude
# at ~tau.
#
# We use the cheap Frobenius-over-sqrt(min_dim) estimate of the spectral norm.
# NOTE: this is a *lower* bound (||W||_F >= sigma_max >= ||W||_F / sqrt(min_dim)),
# so when it triggers a rescale the true spectral norm certainly exceeds the
# target; when it doesn't trigger, the true norm may still be above it.
@torch.no_grad()
def _apply_qk_clip(param_groups: list[dict]) -> None:
"""Rescale Muon param groups marked `is_qk` so spectral norm <= sqrt(tau).
No-op if tau <= 0. Safe to call every step. Identical across DDP ranks since
qk weights are replicated post-gather."""
for group in param_groups:
if group.get('kind') != 'muon':
continue
if not group.get('is_qk', False):
continue
tau = float(group.get('qk_tau', 0.0))
if tau <= 0.0:
continue
target = tau ** 0.5
for p in group['params']:
if p.ndim < 2:
continue
frob = p.detach().float().norm()
min_dim = min(p.shape[-2], p.shape[-1])
spec_est = frob / (min_dim ** 0.5)
if spec_est > target:
scale = target / spec_est.clamp_min(1e-12)
p.data.mul_(scale.to(p.dtype))
# -----------------------------------------------------------------------------
# Single GPU version of the MuonAdamW optimizer.
# Used mostly for reference, debugging and testing.
@ -291,6 +328,8 @@ class MuonAdamW(torch.optim.Optimizer):
self._step_muon(group)
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
# MuonClip QK-Clip after Muon update. No-op unless qk_tau > 0.
_apply_qk_clip(self.param_groups)
# -----------------------------------------------------------------------------
# Distributed version of the MuonAdamW optimizer.
@ -533,3 +572,6 @@ class DistMuonAdamW(torch.optim.Optimizer):
# Phase 3: wait for gathers, copy back
self._finish_gathers(gather_list)
# MuonClip QK-Clip after the Muon update. No-op unless qk_tau > 0.
_apply_qk_clip(self.param_groups)

View File

@ -63,6 +63,7 @@ parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning ra
parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--muon-qk-clip-tau", type=float, default=0.0, help="MuonClip QK-Clip cap on max attention logit (Kimi K2, arxiv 2507.20534 §A). 0 = disabled. Typical: 100.")
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
@ -313,6 +314,7 @@ optimizer = model.setup_optimizer(
# Muon hyperparameters
matrix_lr=args.matrix_lr * batch_lr_scale,
weight_decay=weight_decay_scaled,
muon_qk_clip_tau=args.muon_qk_clip_tau,
)
if resuming: