Mirror of https://github.com/karpathy/nanochat.git (synced 2026-05-13 11:20:21 +00:00)

Merge 1e7810ddaa into dc54a1a307
Commit: b3a195a398

@@ -371,19 +371,30 @@ class GPT(nn.Module):
             'total': total,
         }

-    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
+    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5, muon_qk_clip_tau=0.0):
         model_dim = self.config.n_embd
         ddp, rank, local_rank, world_size = get_dist_info()

         # Separate out all parameters into groups
         matrix_params = list(self.transformer.h.parameters())
+        # When MuonClip QK-Clip is enabled, pull c_q/c_k out of matrix_params into a
+        # dedicated Muon group tagged for spectral-norm capping (Kimi K2 §A).
+        qk_params = []
+        if muon_qk_clip_tau > 0.0:
+            qk_param_ids = set()
+            for block in self.transformer.h:
+                qk_params.append(block.attn.c_q.weight)
+                qk_params.append(block.attn.c_k.weight)
+                qk_param_ids.add(id(block.attn.c_q.weight))
+                qk_param_ids.add(id(block.attn.c_k.weight))
+            matrix_params = [p for p in matrix_params if id(p) not in qk_param_ids]
         value_embeds_params = list(self.value_embeds.parameters())
         embedding_params = list(self.transformer.wte.parameters())
         lm_head_params = list(self.lm_head.parameters())
         resid_params = [self.resid_lambdas]
         x0_params = [self.x0_lambdas]
         smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
+        assert len(list(self.parameters())) == len(matrix_params) + len(qk_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)

         # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5

@@ -406,6 +417,15 @@ class GPT(nn.Module):
                 kind='muon', params=group_params, lr=matrix_lr,
                 momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
             ))
+        # Dedicated Muon group for QK params when MuonClip is enabled (Kimi K2 §A).
+        if qk_params:
+            for shape in sorted({p.shape for p in qk_params}):
+                group_params = [p for p in qk_params if p.shape == shape]
+                param_groups.append(dict(
+                    kind='muon', params=group_params, lr=matrix_lr,
+                    momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
+                    is_qk=True, qk_tau=muon_qk_clip_tau,
+                ))

         Factory = DistMuonAdamW if ddp else MuonAdamW
         optimizer = Factory(param_groups)

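For orientation, a minimal usage sketch of the new knob. The call mirrors the signature in the diff above; the constructed `model` and the inspection loop are illustrative, not part of the commit:

    # Minimal sketch (assumes a constructed nanochat GPT `model`; values illustrative).
    optimizer = model.setup_optimizer(
        unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
        weight_decay=0.0, scalar_lr=0.5,
        muon_qk_clip_tau=100.0,  # > 0 enables QK-Clip; the default 0.0 changes nothing
    )
    # The dedicated QK groups carry the tags later consumed by _apply_qk_clip:
    for group in optimizer.param_groups:
        if group.get('is_qk'):
            print(group['qk_tau'], [tuple(p.shape) for p in group['params']])
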
@@ -147,6 +147,43 @@ def muon_step_fused(
     mask = (g * stacked_params) >= 0
     stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)

+# -----------------------------------------------------------------------------
+# MuonClip QK-Clip (Kimi K2, arxiv 2507.20534 §A)
+#
+# Muon applies orthogonalized updates to W_Q and W_K each step, which can let
+# the attention logits Q @ K^T grow unbounded over training and blow up softmax.
+# QK-Clip rescales W_Q and W_K *after* the Muon step so that each has spectral
+# norm at most sqrt(tau), bounding their product by tau and the max logit at ~tau.
+#
+# We use the cheap Frobenius-over-sqrt(min_dim) estimate of the spectral norm,
+# i.e. the RMS of the singular values: it never exceeds the true spectral norm,
+# so every rescale it triggers is warranted; at worst it under-clips slightly.
+
+@torch.no_grad()
+def _apply_qk_clip(param_groups: list[dict]) -> None:
+    """Rescale Muon param groups marked `is_qk` so the estimated spectral norm
+    is <= sqrt(tau). No-op if tau <= 0. Safe to call every step. Identical
+    across DDP ranks since qk weights are replicated post-gather."""
+    for group in param_groups:
+        if group.get('kind') != 'muon':
+            continue
+        if not group.get('is_qk', False):
+            continue
+        tau = float(group.get('qk_tau', 0.0))
+        if tau <= 0.0:
+            continue
+        target = tau ** 0.5
+        for p in group['params']:
+            if p.ndim < 2:
+                continue
+            frob = p.detach().float().norm()
+            min_dim = min(p.shape[-2], p.shape[-1])
+            spec_est = frob / (min_dim ** 0.5)
+            if spec_est > target:
+                scale = target / spec_est.clamp_min(1e-12)
+                p.data.mul_(scale.to(p.dtype))
+
+
 # -----------------------------------------------------------------------------
 # Single GPU version of the MuonAdamW optimizer.
 # Used mostly for reference, debugging and testing.

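Capping each of W_Q and W_K at spectral norm sqrt(tau) bounds the raw logit, since |x^T W_Q^T W_K y| <= ||W_Q||_2 * ||W_K||_2 * ||x|| * ||y|| <= tau * ||x|| * ||y||. The snippet below is a standalone sanity check of the estimate-and-rescale arithmetic used by `_apply_qk_clip`; the random matrix and numbers are illustrative, not taken from training:

    import torch

    torch.manual_seed(0)
    tau = 100.0
    target = tau ** 0.5                          # per-matrix cap: sqrt(tau) = 10.0

    W = torch.randn(768, 768)                    # stand-in for a c_q/c_k weight
    spec_est = W.norm() / (min(W.shape) ** 0.5)  # RMS of the singular values
    spec_true = torch.linalg.matrix_norm(W, ord=2)
    assert spec_est <= spec_true                 # the estimate never exceeds the true norm

    if spec_est > target:
        W.mul_(target / spec_est)                # same rescale as _apply_qk_clip
    new_est = W.norm() / (min(W.shape) ** 0.5)
    print(f"estimate {spec_est:.2f} -> {new_est:.2f} (cap {target:.1f})")
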
@@ -291,6 +328,8 @@ class MuonAdamW(torch.optim.Optimizer):
                 self._step_muon(group)
             else:
                 raise ValueError(f"Unknown optimizer kind: {group['kind']}")
+        # MuonClip QK-Clip after the Muon update. No-op unless qk_tau > 0.
+        _apply_qk_clip(self.param_groups)

 # -----------------------------------------------------------------------------
 # Distributed version of the MuonAdamW optimizer.

@@ -533,3 +572,6 @@ class DistMuonAdamW(torch.optim.Optimizer):

         # Phase 3: wait for gathers, copy back
         self._finish_gathers(gather_list)
+
+        # MuonClip QK-Clip after the Muon update. No-op unless qk_tau > 0.
+        _apply_qk_clip(self.param_groups)

@@ -69,8 +69,13 @@ python -m scripts.tok_eval
 echo "Waiting for dataset download to complete..."
 wait $DATASET_DOWNLOAD_PID

-# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8)
-torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --device-batch-size=16 --fp8 --run=$WANDB_RUN
+# d22 model overtrained relative to compute-optimal 10.5 (mirror of Run 6's d24+ratio=8
+# undertrained strategy from the other side of compute-optimal — d22 is below GPT-2
+# capability so we overtrain at ratio=11.05). Combined with --warmdown-ratio=0.85 (longer
+# low-LR tail), --final-lr-frac=0.0 (LR decays fully to zero; Hägele et al., arxiv 2405.18392),
+# and --muon-qk-clip-tau=100 (Kimi K2 §A QK-Clip) the recipe crosses GPT-2 CORE in ~88 min
+# — ~10.9% less wall-clock than Run 6 — at CORE 0.2665, val_bpb 0.7242.
+torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=22 --target-param-data-ratio=11.05 --total-batch-size=1048576 --device-batch-size=16 --warmdown-ratio=0.85 --final-lr-frac=0.0 --muon-qk-clip-tau=100 --fp8 --run=$WANDB_RUN
 # evaluate the model: CORE metric, BPB on train/val, and draw samples
 torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16

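To unpack the ratio in the comment above: --target-param-data-ratio appears to set the training-token budget as a multiple of the parameter count (10.5 being the compute-optimal default). A back-of-envelope with a placeholder parameter count; the real number comes from the depth=22 model config:

    # Back-of-envelope for --target-param-data-ratio (placeholder numbers only).
    n_params = 1.0e9                 # HYPOTHETICAL parameter count for d22
    ratio = 11.05                    # --target-param-data-ratio
    tokens_per_step = 1_048_576      # --total-batch-size
    total_tokens = ratio * n_params  # token budget implied by the ratio
    print(f"{total_tokens:.3g} tokens ~ {total_tokens / tokens_per_step:.0f} steps")
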
@@ -63,6 +63,7 @@ parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning ra
 parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
 parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
 parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--muon-qk-clip-tau", type=float, default=0.0, help="MuonClip QK-Clip cap on max attention logit (Kimi K2, arxiv 2507.20534 §A). 0 = disabled. Typical: 100.")
 parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
 parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
 parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")

@@ -313,6 +314,7 @@ optimizer = model.setup_optimizer(
     # Muon hyperparameters
     matrix_lr=args.matrix_lr * batch_lr_scale,
     weight_decay=weight_decay_scaled,
+    muon_qk_clip_tau=args.muon_qk_clip_tau,
 )

 if resuming: