Add muP (Maximal Update Parametrization) considerations for the Muon optimizer

This commit is contained in:
Amrit Bulusu 2026-03-12 00:45:48 -04:00
parent c7ba252142
commit f07ca1b268
4 changed files with 1442 additions and 11 deletions

View File

@ -37,6 +37,9 @@ class GPTConfig:
# Characters: L=long (full context), S=short (half context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "SSSL"
# muP (Maximal Update Parametrization): set > 0 to enable. Value is the base/proxy width.
# Enables: non-zero c_proj init scaled as 1/sqrt(m_d), output logit scaling by base_width/n_embd.
mup_base_width: int = 0
def norm(x):
@ -203,7 +206,12 @@ class GPT(nn.Module):
# Embedding and unembedding
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
lm_head_std = 0.001
if self.config.mup_base_width > 0:
# muP: scale lm_head init by 1/sqrt(m_d) so raw logit magnitude is O(1) across widths.
# Without this, |logit| ~ 0.001 * sqrt(n_embd) grows as sqrt(width).
lm_head_std *= (self.config.mup_base_width / self.config.n_embd) ** 0.5
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std)
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
n_embd = self.config.n_embd
@ -212,9 +220,15 @@ class GPT(nn.Module):
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
torch.nn.init.zeros_(block.mlp.c_proj.weight)
if self.config.mup_base_width > 0:
# muP: output projections use same scale as hidden weights (std = sigma_base/sqrt(m_d))
# Zero init causes attn/FFN outputs to vanish as width increases with muP LR scaling
torch.nn.init.uniform_(block.attn.c_proj.weight, -s, s)
torch.nn.init.uniform_(block.mlp.c_proj.weight, -s, s)
else:
torch.nn.init.zeros_(block.attn.c_proj.weight) # SP: projections are zero
torch.nn.init.zeros_(block.mlp.c_proj.weight)
# Per-layer scalars
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
@ -345,7 +359,7 @@ class GPT(nn.Module):
'total': total,
}
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5, use_mup=False, base_width=256, muon_lr_exponent=0.0):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
@ -358,16 +372,41 @@ class GPT(nn.Module):
x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
# Compute LR scaling factors based on mode
if use_mup:
# muP for mixed Muon+AdamW optimizer:
#
# Key insight: the output logit scaling (logits *= base/width in forward()) already
# propagates a base/width factor into the lm_head gradient. Adam normalizes gradient
# magnitude via its second moment, but applying ANOTHER base/width to the LR compounds
# the scaling, making lm_head updates O(base²/width²) instead of O(1). So we do NOT
# apply width-dependent LR scaling to the output layer — the logit scaling alone suffices.
#
# For Muon (hidden weights): Polar Express orthogonalization normalizes ||update||_F ≈ 1
# regardless of width, making the update already O(1). No width-dependent LR scaling
# is needed (empirically confirmed: exponent 0 and 1 give identical transfer behavior).
#
# The muon_lr_exponent parameter is kept for experimentation but defaults to 0.
width_ratio = base_width / model_dim # e.g., 128/1024 = 0.125
emb_lr_scale = 1.0 # Embeddings: NO width scaling (standard muP)
hidden_lr_scale = width_ratio ** muon_lr_exponent # Hidden (Muon): default 0 = no scaling
output_lr_scale = 1.0 # Output (AdamW): NO LR scaling (logit scaling in forward suffices)
self.config.mup_base_width = base_width # enables output logit scaling in forward()
print0(f"muP scaling: base_width={base_width}, model_dim={model_dim}, width_ratio={width_ratio:.6f}, muon_lr_exp={muon_lr_exponent}")
else:
# Standard (SP): scale AdamW params by 1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
emb_lr_scale = dmodel_lr_scale
hidden_lr_scale = 1.0 # Muon params: no scaling in SP mode
output_lr_scale = dmodel_lr_scale
print0(f"Standard scaling: dmodel_lr_scale={dmodel_lr_scale:.6f}")
# Build param_groups with all required fields explicit
param_groups = [
# AdamW groups (embeddings, lm_head, scalars)
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * output_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=embedding_params, lr=embedding_lr * emb_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * emb_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
]
@ -375,7 +414,7 @@ class GPT(nn.Module):
for shape in sorted({p.shape for p in matrix_params}):
group_params = [p for p in matrix_params if p.shape == shape]
param_groups.append(dict(
kind='muon', params=group_params, lr=matrix_lr,
kind='muon', params=group_params, lr=matrix_lr * hidden_lr_scale,
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
))
@ -409,6 +448,11 @@ class GPT(nn.Module):
# Forward the lm_head (compute logits)
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
if self.config.mup_base_width > 0:
# muP: scale output logits by base_width/n_embd (= 1/m_d).
# Without this, logits grow with width because the lm_head dot product sums over n_embd terms.
# 1/sqrt(m_d) only corrects at init; 1/m_d is required for all training steps (see Eleuther blog Fig 8-9).
logits = logits * (self.config.mup_base_width / self.config.n_embd)
logits = logits[..., :self.config.vocab_size] # slice to remove padding
logits = logits.float() # switch to fp32 for logit softcap and loss computation
logits = softcap * torch.tanh(logits / softcap) # squash the logits

View File

@ -68,6 +68,8 @@ parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 f
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--use-mup", action="store_true", help="use muP (Maximal Update Parameterization) LR scaling")
parser.add_argument("--base-width", type=int, default=256, help="base width for muP LR scaling (LRs tuned at this width)")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
@ -309,6 +311,9 @@ optimizer = model.setup_optimizer(
# Muon hyperparameters
matrix_lr=args.matrix_lr * batch_lr_scale,
weight_decay=weight_decay_scaled,
# muP scaling
use_mup=args.use_mup,
base_width=args.base_width,
)
if resuming:

710
scripts/mup_coord_check.py Normal file
View File

@ -0,0 +1,710 @@
"""
muP Coordinate Check for nanochat
This script validates muP implementation by checking that activation magnitudes
are independent of model width. Based on EleutherAI's nanoGPT-mup and Microsoft's
mup library.
Reference: https://blog.eleuther.ai/mutransfer/
Reference: Yang et al., "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot
Hyperparameter Transfer" (arXiv:2203.03466), Sections B.1 and F.
Usage:
python -m scripts.mup_coord_check --widths 128,256,512,1024 --steps 10
python -m scripts.mup_coord_check --use-mup --widths 128,256,512,1024
python -m scripts.mup_coord_check --compare --detailed
python -m scripts.mup_coord_check --compare --muon-lr-exponent 0.5
"""
import argparse
import torch
import torch._dynamo
torch._dynamo.config.disable = True
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import os
from nanochat.gpt import GPT, GPTConfig
def load_batch(batch_size: int, seq_len: int, device: torch.device):
    """Fetch one (x, y, vocab_size) batch for the coordinate check.

    Tries the real nanochat tokenizer/dataloader first; if that machinery is
    unavailable (e.g. no dataset on disk), falls back to deterministic random
    token data so the script still runs end to end.
    """
    try:
        from nanochat.tokenizer import get_tokenizer
        from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
        tok = get_tokenizer()
        n_vocab = tok.get_vocab_size()
        stream = tokenizing_distributed_data_loader_bos_bestfit(
            tok, batch_size, seq_len, split="train", device=device,
        )
        inputs, targets = next(stream)
        print(f"Loaded real training data (vocab_size={n_vocab})")
        return inputs, targets, n_vocab
    except Exception as e:
        print(f"Could not load training data ({e}), using random tokens")
        n_vocab = 32768
        # Seeded generator keeps the fallback batch reproducible across runs.
        gen = torch.Generator(device=device)
        gen.manual_seed(42)
        inputs = torch.randint(0, n_vocab, (batch_size, seq_len), device=device, generator=gen)
        # Next-token targets: shift left by one; last position has no target.
        targets = torch.roll(inputs, -1, dims=1)
        targets[:, -1] = -1
        return inputs, targets, n_vocab
@dataclass
class CoordCheckConfig:
    """Hyperparameters for a single coordinate-check run."""
    # Model widths (n_embd values) to sweep; each gets a freshly trained model.
    widths: List[int]
    num_steps: int = 10
    batch_size: int = 4
    seq_len: int = 128
    vocab_size: int = 32768
    n_layer: int = 2
    seed: int = 42
    # If True, apply muP LR scaling; otherwise standard parameterization (SP).
    use_mup: bool = False
    # Proxy width at which the base LRs below were tuned.
    base_width: int = 128
    # Learning rates (tuned at base_width)
    matrix_lr: float = 0.02
    embedding_lr: float = 0.2
    unembedding_lr: float = 0.004
    # Detailed diagnostics (gradient norms, update norms, attention logits)
    detailed: bool = False
    # Muon LR exponent: 1.0 = base/width (standard muP), 0.5 = sqrt(base/width)
    # Paper Section C.1: Frobenius-normalizing optimizers may need exponent 0.5
    muon_lr_exponent: float = 0.0
class ActivationRecorder:
    """Records activation statistics during forward pass using hooks.

    Forward hooks append a per-call scalar (mean |activation|) into
    ``self.stats``; ``get_step_stats()`` averages and resets them once per
    training step.
    """

    def __init__(self, detailed: bool = False):
        # layer name -> list of per-call statistics for the current step
        self.stats: Dict[str, List[float]] = defaultdict(list)
        self.hooks = []
        self.detailed = detailed
        # Most recent c_q output, consumed by the paired k_hook (detailed mode).
        self._last_q = None

    def _get_stat(self, tensor: torch.Tensor) -> float:
        """Compute mean absolute value (l1 norm per element)."""
        if tensor is None:
            return 0.0
        # Fix: the original had a dead `dtype == torch.bool` branch returning
        # the identical expression; .float() already upcasts bool uniformly.
        return tensor.float().abs().mean().item()

    def _make_hook(self, name: str):
        """Create a forward hook that records output statistics."""
        def hook(module, input, output):
            # Some modules return tuples; stat the primary tensor only.
            if isinstance(output, tuple):
                output = output[0]
            if output is not None and isinstance(output, torch.Tensor):
                self.stats[name].append(self._get_stat(output))
        return hook

    def _make_attn_logit_hook(self, name: str, n_head: int, n_kv_head: int, head_dim: int):
        """Create a hook on c_k that computes pre-softmax attention logit magnitudes.

        We hook onto c_k's forward, then use the most recent c_q output to compute
        q @ k^T / sqrt(d) for a single batch element to measure attention logit scale.
        """
        # We'll store q output and compute logits when k is available
        self._last_q = None
        def q_hook(module, input, output):
            self._last_q = output.detach()
        def k_hook(module, input, output):
            if self._last_q is None:
                return
            q = self._last_q
            k = output.detach()
            B, T, _ = q.shape
            # First batch element only; reshape flat projections into heads.
            q = q[0:1].view(1, T, n_head, head_dim)
            k = k[0:1].view(1, T, n_kv_head, head_dim)
            # Apply QK norm (same as model)
            q = F.rms_norm(q, (q.size(-1),))
            k = F.rms_norm(k, (k.size(-1),))
            # Expand k for GQA
            if n_head != n_kv_head:
                k = k.repeat_interleave(n_head // n_kv_head, dim=2)
            # Compute logits: q @ k^T / sqrt(d) — just for first few positions
            T_sub = min(T, 32)
            q_sub = q[:, :T_sub].transpose(1, 2)  # (1, H, T_sub, D)
            k_sub = k[:, :T_sub].transpose(1, 2)  # (1, H, T_sub, D)
            logits = torch.matmul(q_sub, k_sub.transpose(-2, -1)) / (head_dim ** 0.5)
            self.stats[name].append(logits.float().abs().mean().item())
            self._last_q = None
        return q_hook, k_hook

    def register_hooks(self, model: "GPT") -> None:
        """Register forward hooks on key layers.

        The annotation is a string forward reference so this class can be
        defined even when nanochat.gpt is not importable.
        """
        # Embedding
        h = model.transformer.wte.register_forward_hook(self._make_hook('word embedding'))
        self.hooks.append(h)
        # Each transformer block
        for i, block in enumerate(model.transformer.h):
            # Attention output
            h = block.attn.c_proj.register_forward_hook(self._make_hook(f'attn output.{i}'))
            self.hooks.append(h)
            # MLP output
            h = block.mlp.c_proj.register_forward_hook(self._make_hook(f'FFN output.{i}'))
            self.hooks.append(h)
            # Detailed: attention logit magnitudes
            if self.detailed:
                n_head = block.attn.n_head
                n_kv_head = block.attn.n_kv_head
                head_dim = block.attn.head_dim
                q_hook, k_hook = self._make_attn_logit_hook(
                    f'attn logits.{i}', n_head, n_kv_head, head_dim)
                h1 = block.attn.c_q.register_forward_hook(q_hook)
                h2 = block.attn.c_k.register_forward_hook(k_hook)
                self.hooks.extend([h1, h2])
        # LM head
        h = model.lm_head.register_forward_hook(self._make_hook('output logits'))
        self.hooks.append(h)

    def remove_hooks(self) -> None:
        """Remove all registered hooks."""
        for h in self.hooks:
            h.remove()
        self.hooks = []

    def get_step_stats(self) -> Dict[str, float]:
        """Get mean stats for the current step and reset."""
        step_stats = {}
        for name, values in self.stats.items():
            if values:
                step_stats[name] = np.mean(values)
        # Reset accumulator for the next step.
        self.stats = defaultdict(list)
        return step_stats
def create_model(width: int, config: CoordCheckConfig, device: torch.device, mup_base_width: int = 0) -> Tuple[GPT, GPTConfig]:
    """Build a GPT of (approximately) the requested width on `device`.

    The width is rounded down to a whole number of 64-dim heads (minimum one),
    so the returned config's n_embd may differ slightly from `width`. The model
    is constructed on the meta device, materialized on `device`, then initialized.
    """
    head_dim = 64
    num_heads = max(1, width // head_dim)
    effective_width = num_heads * head_dim
    cfg = GPTConfig(
        sequence_len=config.seq_len,
        vocab_size=config.vocab_size,
        n_layer=config.n_layer,
        n_head=num_heads,
        n_kv_head=num_heads,
        n_embd=effective_width,
        window_pattern="L",
        mup_base_width=mup_base_width,
    )
    # Meta-device construction allocates no memory until to_empty().
    with torch.device('meta'):
        net = GPT(cfg)
    net.to_empty(device=device)
    net.init_weights()
    return net, cfg
def setup_optimizer_mup(model: GPT, config: CoordCheckConfig, width: int):
    """Build the optimizer with muP LR scaling via the model's native use_mup flag.

    NOTE(review): `width` is unused — the muP scaling is derived from the
    model's own n_embd inside setup_optimizer. Kept for signature symmetry
    with setup_optimizer_sp.
    """
    return model.setup_optimizer(
        unembedding_lr=config.unembedding_lr,
        embedding_lr=config.embedding_lr,
        matrix_lr=config.matrix_lr,
        weight_decay=0.0,
        use_mup=True,
        base_width=config.base_width,
        muon_lr_exponent=config.muon_lr_exponent,
    )
def setup_optimizer_sp(model: GPT, config: CoordCheckConfig, width: int):
    """Build the optimizer with standard parameterization (current nanochat).

    NOTE(review): `width` is unused; kept for signature symmetry with
    setup_optimizer_mup.
    """
    return model.setup_optimizer(
        unembedding_lr=config.unembedding_lr,
        embedding_lr=config.embedding_lr,
        matrix_lr=config.matrix_lr,
        weight_decay=0.0,
        use_mup=False,
    )
def record_detailed_stats(model: GPT, results: Dict, width: int, step: int):
    """Append per-parameter gradient norms into results['detailed_stats'][width].

    NOTE(review): `step` is unused; values are simply appended in step order.
    """
    for pname, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            continue
        # Strip boilerplate prefixes/suffixes so plot labels stay short.
        label = pname.replace('transformer.', '').replace('.weight', '')
        results['detailed_stats'][width][f'grad norm: {label}'].append(grad.float().norm().item())
def record_weight_update_norms(model: GPT, params_before: Dict[str, torch.Tensor],
                               results: Dict, width: int):
    """Record ||delta_W|| (post-step minus pre-step weights) for each parameter.

    Parameters absent from `params_before` (no gradient snapshot) are skipped.
    """
    for pname, param in model.named_parameters():
        if pname not in params_before:
            continue
        # Same label shortening as record_detailed_stats for consistent plots.
        label = pname.replace('transformer.', '').replace('.weight', '')
        step_norm = (param.data.float() - params_before[pname]).norm().item()
        results['detailed_stats'][width][f'update norm: {label}'].append(step_norm)
def run_coord_check(config: CoordCheckConfig, device: torch.device,
                    x: torch.Tensor, y: torch.Tensor) -> Dict:
    """Run coordinate check across all widths.

    Trains a fresh model at each width in config.widths for config.num_steps
    steps on the single fixed batch (x, y), recording per-layer activation
    magnitudes (via ActivationRecorder hooks), losses, and — in detailed
    mode — gradient and weight-update norms.

    Returns a dict with keys 'widths', 'steps', 'stats', 'losses',
    'detailed_stats' (the latter three keyed by actual model width).
    """
    results = {
        'widths': [],
        'steps': list(range(config.num_steps)),
        'stats': defaultdict(lambda: defaultdict(list)),
        'losses': defaultdict(list),
        'detailed_stats': defaultdict(lambda: defaultdict(list)),
    }
    for width in config.widths:
        print(f"\nTraining width={width}...")
        # Re-seed per width so initialization draws are comparable across widths.
        torch.manual_seed(config.seed)
        mup_base_width = config.base_width if config.use_mup else 0
        model, gpt_config = create_model(width, config, device, mup_base_width=mup_base_width)
        # create_model rounds width to a multiple of head_dim; record the real value.
        actual_width = gpt_config.n_embd
        results['widths'].append(actual_width)
        if config.use_mup:
            optimizer = setup_optimizer_mup(model, config, actual_width)
        else:
            optimizer = setup_optimizer_sp(model, config, actual_width)
        recorder = ActivationRecorder(detailed=config.detailed)
        recorder.register_hooks(model)
        model.train()
        for step in range(config.num_steps):
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')):
                loss = model(x, y)
            results['losses'][actual_width].append(loss.item())
            # Hooks fired during the forward pass; collect and reset their stats.
            step_stats = recorder.get_step_stats()
            for layer, value in step_stats.items():
                results['stats'][actual_width][layer].append(value)
            if step == 0:
                print(f" Step {step}: loss={loss.item():.4f}, layers={list(step_stats.keys())}")
            # Record gradient norms before step (detailed mode)
            loss.backward()
            if config.detailed:
                record_detailed_stats(model, results, actual_width, step)
                # Snapshot params before optimizer step to compute update norms
                params_before = {name: p.data.float().clone()
                                 for name, p in model.named_parameters()
                                 if p.grad is not None}
            optimizer.step()
            if config.detailed:
                record_weight_update_norms(model, params_before, results, actual_width)
            optimizer.zero_grad(set_to_none=True)
        print(f" Final loss: {loss.item():.4f}")
        recorder.remove_hooks()
        # Free memory before building the next (possibly larger) model.
        del model, optimizer
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
    return results
def plot_coord_check(results: Dict, config: CoordCheckConfig, save_path: Optional[str] = None):
    """Plot coordinate check: one subplot per layer, x=width (log2), y=mean |activation|, lines=steps."""
    widths = results['widths']
    steps = results['steps']
    stats = results['stats']
    # Layer names taken from the first width; all widths record the same set.
    layer_names = list(stats[widths[0]].keys())
    n_layers = len(layer_names)
    n_cols = 4
    n_rows = (n_layers + n_cols - 1) // n_cols  # ceil division
    param_type = "muP" if config.use_mup else "SP"
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    # Flatten so one index works whether subplots returned a 1D or 2D grid.
    axes = np.array(axes).flatten()
    step_colors = plt.cm.plasma(np.linspace(0, 1, len(steps)))
    for i, layer in enumerate(layer_names):
        ax = axes[i]
        for s, step in enumerate(steps):
            values = [stats[w][layer][s] for w in widths]
            # Label each step line only once (on the first subplot).
            ax.plot(widths, values, 'o-', color=step_colors[s], linewidth=1.5,
                    label=f'step {step}' if i == 0 else None)
        ax.set_xscale('log', base=2)
        ax.set_xticks(widths)
        ax.set_xticklabels(widths, fontsize=7)
        ax.set_title(layer, fontsize=9)
        ax.set_xlabel('Width')
        ax.set_ylabel('Mean |activation|')
        ax.grid(True, alpha=0.3)
    axes[0].legend(fontsize=7, loc='best')
    # Hide any unused subplot cells.
    for i in range(n_layers, len(axes)):
        axes[i].set_visible(False)
    fig.suptitle(f'Coordinate Check ({param_type}): Activation Magnitude vs Width', fontsize=14)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    plt.show()
def plot_loss_curves(results: Dict, config: CoordCheckConfig, title: str = "", save_path: Optional[str] = None):
    """Plot loss curves across widths to verify HP transfer.

    Under correct muP the curves for different widths should nearly coincide
    (small final-loss spread); under SP they typically diverge.
    """
    widths = results['widths']
    steps = results['steps']
    losses = results['losses']
    fig, ax = plt.subplots(figsize=(5 * 2, 4))
    colors = plt.cm.viridis(np.linspace(0, 1, len(widths)))
    for i, w in enumerate(widths):
        ax.plot(steps, losses[w], label=f'width={w}', color=colors[i], linewidth=2)
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title(f'Loss Curves Across Widths{" - " + title if title else ""}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    # Add annotation for final loss spread
    final_losses = [losses[w][-1] for w in widths]
    spread = max(final_losses) - min(final_losses)
    ax.annotate(f'Final loss spread: {spread:.4f}', xy=(0.7, 0.95), xycoords='axes fraction', fontsize=10)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved loss curves to {save_path}")
    plt.show()
def plot_comparison(results_sp: Dict, results_mup: Dict, config: CoordCheckConfig, save_path: Optional[str] = None):
    """Plot SP vs muP: one subplot per layer (left=SP, right=muP), x=width (log2), y=mean |activation|, lines=steps.

    A final row shows loss curves for both parameterizations. Each layer row
    shares a y-range across the SP and muP columns so magnitudes are directly
    comparable. Assumes both results dicts cover the same widths/steps/layers.
    """
    widths = results_sp['widths']
    steps = results_sp['steps']
    layer_names = list(results_sp['stats'][widths[0]].keys())
    n_layers = len(layer_names)
    # n_layers activation rows + 1 loss row, 2 cols (SP | muP)
    n_rows, n_cols = n_layers + 1, 2
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 3 * n_rows))
    step_colors = plt.cm.plasma(np.linspace(0, 1, len(steps)))
    width_colors = plt.cm.viridis(np.linspace(0, 1, len(widths)))
    for row, layer in enumerate(layer_names):
        # Shared y-axis range across SP and muP for this layer
        all_vals = [results_sp['stats'][w][layer][s] for w in widths for s in range(len(steps))] + \
                   [results_mup['stats'][w][layer][s] for w in widths for s in range(len(steps))]
        y_min, y_max = min(all_vals) * 0.9, max(all_vals) * 1.1
        for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
            ax = axes[row, col]
            for s, step in enumerate(steps):
                values = [results['stats'][w][layer][s] for w in widths]
                # Label step lines only once (top-left subplot).
                ax.plot(widths, values, 'o-', color=step_colors[s], linewidth=1.5,
                        label=f'step {step}' if (row == 0 and col == 0) else None)
            ax.set_xscale('log', base=2)
            ax.set_xticks(widths)
            ax.set_xticklabels(widths, fontsize=7)
            ax.set_ylim(y_min, y_max)
            ax.set_title(f'{label}: {layer}', fontsize=9)
            ax.set_xlabel('Width')
            ax.set_ylabel('Mean |activation|')
            ax.grid(True, alpha=0.3)
    axes[0, 0].legend(fontsize=7, loc='best')
    # Loss curves row
    all_losses = [v for r in (results_sp, results_mup) for w in widths for v in r['losses'][w]]
    loss_min, loss_max = min(all_losses) * 0.95, max(all_losses) * 1.05
    for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
        ax = axes[n_layers, col]
        for j, w in enumerate(widths):
            ax.plot(steps, results['losses'][w], label=f'w={w}', color=width_colors[j], linewidth=2)
        ax.set_ylim(loss_min, loss_max)
        ax.set_xlabel('Step')
        ax.set_ylabel('Loss')
        ax.set_title(f'{label}: Loss Curves')
        ax.legend(fontsize=7)
        ax.grid(True, alpha=0.3)
        # Annotate how far apart the final losses are across widths.
        final_losses = [results['losses'][w][-1] for w in widths]
        spread = max(final_losses) - min(final_losses)
        ax.annotate(f'Spread: {spread:.4f}', xy=(0.65, 0.95), xycoords='axes fraction', fontsize=9)
    fig.suptitle('Coordinate Check: SP vs muP', fontsize=14)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved comparison plot to {save_path}")
    plt.show()
def plot_detailed(results: Dict, config: CoordCheckConfig, save_path: Optional[str] = None):
    """Plot detailed diagnostics: gradient norms, weight update norms, attention logits.

    Produces one figure per metric category; requires that run_coord_check was
    executed with config.detailed=True (otherwise prints a hint and returns).
    """
    widths = results['widths']
    detailed = results['detailed_stats']
    if not detailed or not detailed[widths[0]]:
        print("No detailed stats recorded. Use --detailed flag.")
        return
    # Collect all detailed metric names
    metric_names = sorted(detailed[widths[0]].keys())
    # Group by category
    categories = defaultdict(list)
    for name in metric_names:
        if name.startswith('grad norm:'):
            categories['Gradient Norms'].append(name)
        elif name.startswith('update norm:'):
            categories['Weight Update Norms'].append(name)
        elif name.startswith('attn logits'):
            categories['Attention Logit Magnitudes'].append(name)
    # One figure per category, one subplot per metric, one line per width.
    for cat_name, names in categories.items():
        n = len(names)
        n_cols = min(4, n)
        n_rows = (n + n_cols - 1) // n_cols  # ceil division
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
        # Normalize axes to a flat array regardless of grid shape.
        if n == 1:
            axes = np.array([axes])
        axes = np.array(axes).flatten()
        steps = results['steps']
        width_colors = plt.cm.viridis(np.linspace(0, 1, len(widths)))
        for i, name in enumerate(names):
            ax = axes[i]
            for j, w in enumerate(widths):
                values = detailed[w].get(name, [])
                if values:
                    ax.plot(range(len(values)), values, color=width_colors[j],
                            linewidth=1.5, label=f'w={w}' if i == 0 else None)
            # Drop the category prefix from the subplot title.
            ax.set_title(name.split(': ', 1)[-1] if ': ' in name else name, fontsize=8)
            ax.set_xlabel('Step')
            ax.set_ylabel('Norm')
            ax.grid(True, alpha=0.3)
            ax.set_yscale('log')
        # Hide unused subplot cells.
        for i in range(n, len(axes)):
            axes[i].set_visible(False)
        axes[0].legend(fontsize=7, loc='best')
        param_type = "muP" if config.use_mup else "SP"
        fig.suptitle(f'{cat_name} ({param_type})', fontsize=14)
        plt.tight_layout()
        if save_path:
            # Derive a per-category filename from the requested save path.
            cat_slug = cat_name.lower().replace(' ', '_')
            path = save_path.replace('.png', f'_{cat_slug}.png')
            plt.savefig(path, dpi=150, bbox_inches='tight')
            print(f"Saved {cat_name} plot to {path}")
        plt.show()
def compute_width_dependence(results: Dict) -> Dict[str, float]:
    """Fit log2(|activation|) vs log2(width) at the final step; return slope per layer.

    A slope near 0 means activations are width-independent (the muP goal);
    a positive slope means they grow with width.
    """
    widths = np.array(results['widths'])
    log_w = np.log2(widths)
    last = len(results['steps']) - 1
    per_layer = {}
    for layer in results['stats'][widths[0]]:
        vals = np.array([results['stats'][w][layer][last] for w in widths])
        # Epsilon guards against log2(0) for exactly-zero activations.
        log_v = np.log2(vals + 1e-10)
        coeffs = np.polyfit(log_w, log_v, 1)
        per_layer[layer] = coeffs[0]
    return per_layer
def main():
    """CLI entry point: parse args, run the coordinate check(s), report slopes, plot."""
    parser = argparse.ArgumentParser(description='muP Coordinate Check')
    parser.add_argument('--widths', type=str, default='128,256,512,1024',
                        help='Comma-separated list of widths to test')
    parser.add_argument('--steps', type=int, default=10,
                        help='Number of training steps')
    parser.add_argument('--batch-size', type=int, default=4,
                        help='Batch size')
    parser.add_argument('--seq-len', type=int, default=128,
                        help='Sequence length')
    parser.add_argument('--n-layer', type=int, default=2,
                        help='Number of transformer layers')
    parser.add_argument('--use-mup', action='store_true',
                        help='Use muP learning rate scaling')
    parser.add_argument('--base-width', type=int, default=128,
                        help='Base width for muP scaling')
    parser.add_argument('--compare', action='store_true',
                        help='Run both SP and muP and compare')
    parser.add_argument('--save-dir', type=str, default=None,
                        help='Directory to save plots')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--detailed', action='store_true',
                        help='Record detailed diagnostics: gradient norms, weight update norms, '
                             'attention logit magnitudes')
    parser.add_argument('--muon-lr-exponent', type=float, default=0.0,
                        help='Muon LR exponent for muP: 1.0 = (base/width)^1 (standard muP), '
                             '0.5 = (base/width)^0.5 (for Frobenius-normalizing optimizers, '
                             'see Yang et al. Section C.1)')
    args = parser.parse_args()
    # Parse widths
    widths = [int(w) for w in args.widths.split(',')]
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Load a single batch of real training data (reused every step)
    x, y, vocab_size = load_batch(args.batch_size, args.seq_len, device)
    # Create config
    config = CoordCheckConfig(
        widths=widths,
        num_steps=args.steps,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        vocab_size=vocab_size,
        n_layer=args.n_layer,
        seed=args.seed,
        use_mup=args.use_mup,
        base_width=args.base_width,
        detailed=args.detailed,
        muon_lr_exponent=args.muon_lr_exponent,
    )
    if args.compare:
        # Run both SP and muP
        print("\n" + "="*60)
        print("Running Standard Parameterization (SP)")
        print("="*60)
        config.use_mup = False
        results_sp = run_coord_check(config, device, x, y)
        print("\n" + "="*60)
        print("Running muP")
        if config.muon_lr_exponent != 1.0:
            print(f" (Muon LR exponent: {config.muon_lr_exponent})")
        print("="*60)
        config.use_mup = True
        results_mup = run_coord_check(config, device, x, y)
        # Compute slopes
        print("\n" + "="*60)
        print("Width Dependence (slope on log-log plot)")
        print("Expected: ~0 for width-independent, positive = grows with width")
        print("="*60)
        slopes_sp = compute_width_dependence(results_sp)
        slopes_mup = compute_width_dependence(results_mup)
        print(f"\n{'Layer':<20} {'SP Slope':>12} {'muP Slope':>12}")
        print("-"*46)
        for layer in slopes_sp:
            print(f"{layer:<20} {slopes_sp[layer]:>12.4f} {slopes_mup[layer]:>12.4f}")
        # Plot comparison
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, 'coord_check_comparison.png')
        plot_comparison(results_sp, results_mup, config, save_path)
        # Plot detailed diagnostics if requested
        if config.detailed:
            for results, label in [(results_sp, 'SP'), (results_mup, 'muP')]:
                # plot_detailed reads config.use_mup for titles; make it match the run.
                config.use_mup = (label == 'muP')
                detail_save = None
                if args.save_dir:
                    detail_save = os.path.join(args.save_dir, f'detailed_{label.lower()}.png')
                plot_detailed(results, config, detail_save)
    else:
        # Run single mode
        param_type = "muP" if config.use_mup else "SP"
        print(f"\n{'='*60}")
        print(f"Running Coordinate Check ({param_type})")
        print(f"{'='*60}")
        print(f"Widths: {widths}")
        print(f"Steps: {config.num_steps}")
        print(f"Base width: {config.base_width}")
        if config.use_mup and config.muon_lr_exponent != 1.0:
            print(f"Muon LR exponent: {config.muon_lr_exponent}")
        results = run_coord_check(config, device, x, y)
        # Compute slopes
        slopes = compute_width_dependence(results)
        print("\n" + "="*60)
        print("Width Dependence (slope on log-log plot)")
        print("Expected for muP: ~0 (width-independent)")
        print("="*60)
        for layer, slope in slopes.items():
            status = "OK" if abs(slope) < 0.1 else "WARN"
            print(f" {layer}: {slope:+.4f} [{status}]")
        # Loss curve analysis
        final_losses = [results['losses'][w][-1] for w in results['widths']]
        loss_spread = max(final_losses) - min(final_losses)
        print(f"\nFinal loss spread across widths: {loss_spread:.4f}")
        print(f"Expected for muP: low spread (similar losses across widths)")
        # Plot activations
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, f'coord_check_{param_type.lower()}.png')
        plot_coord_check(results, config, save_path)
        # Plot loss curves
        loss_save_path = None
        if args.save_dir:
            loss_save_path = os.path.join(args.save_dir, f'loss_curves_{param_type.lower()}.png')
        plot_loss_curves(results, config, title=param_type, save_path=loss_save_path)
        # Plot detailed diagnostics if requested
        if config.detailed:
            detail_save = None
            if args.save_dir:
                detail_save = os.path.join(args.save_dir, f'detailed_{param_type.lower()}.png')
            plot_detailed(results, config, detail_save)
# Script entry point.
if __name__ == '__main__':
    main()

View File

@ -0,0 +1,672 @@
"""
muP Transfer Check for nanochat
Validates that optimal learning rates transfer across model widths under muP.
For each width, sweeps over LR multipliers and records final loss. Under correct
muP, the optimal LR multiplier should be ~1.0 at all widths (i.e., the same LR
works everywhere). Under SP, the optimal LR typically shifts with width.
Reference: https://blog.eleuther.ai/mutransfer/
Reference: Yang et al., "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot
Hyperparameter Transfer" (arXiv:2203.03466), Section F.
Usage:
# Quick check (~2 min on GPU)
python -m scripts.mup_transfer_check
# Compare SP vs muP side-by-side
python -m scripts.mup_transfer_check --compare
# Wide LR sweep (paper-style, ~1000x range)
python -m scripts.mup_transfer_check --compare --widths 128,256,512,1024 --steps 200
# Random log-uniform LR trials (paper-style methodology)
python -m scripts.mup_transfer_check --compare --num-random-trials 20
# Multi-HP sweep (init scale + output multiplier)
python -m scripts.mup_transfer_check --compare --sweep-init-scale --sweep-output-mult
# Save plots
python -m scripts.mup_transfer_check --compare --save-dir plots/
"""
import argparse
import torch
import torch._dynamo
torch._dynamo.config.disable = True
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import os
from nanochat.gpt import GPT, GPTConfig
@dataclass
class TransferCheckConfig:
    """Settings for one muP transfer check: a (width x LR-multiplier) sweep."""
    # Target model widths (n_embd); each gets snapped to a multiple of head_dim=64.
    widths: List[int]
    # Multiplicative factors applied to the base LRs below during the sweep.
    lr_multipliers: List[float]
    num_steps: int = 200        # training steps per (width, lr_mult) run
    batch_size: int = 8
    seq_len: int = 128
    vocab_size: int = 32768     # replaced by the real tokenizer vocab when data loads
    n_layer: int = 2
    seed: int = 42              # torch.manual_seed before every run for comparability
    use_mup: bool = False       # enable muP init + LR scaling
    base_width: int = 128       # proxy width the base LRs were tuned at
    # Base learning rates (tuned at base_width)
    matrix_lr: float = 0.02
    embedding_lr: float = 0.2
    unembedding_lr: float = 0.004
    # Multi-HP sweeps
    sweep_init_scale: bool = False
    sweep_output_mult: bool = False
    # Data diversity
    num_batches: int = 1
    # Muon LR exponent for muP (1.0=standard, 0.5=Frobenius-norm optimizers)
    # NOTE(review): default is 0.0, which by the naming above would disable
    # width-based Muon LR scaling entirely — confirm against setup_optimizer.
    muon_lr_exponent: float = 0.0
    # Sweep mode: which optimizer groups the LR multiplier applies to
    # "all" = multiply all LRs (default), "muon-only" = only matrix_lr,
    # "adamw-only" = only embedding_lr/unembedding_lr
    sweep_mode: str = "all"
def load_batches(num_batches: int, batch_size: int, seq_len: int, device: torch.device):
    """Fetch `num_batches` (x, y) token batches for the sweep.

    Prefers the real nanochat tokenizer/dataloader pipeline. If that fails for
    any reason (missing dataset, missing tokenizer, ...), falls back to
    deterministic random tokens seeded per batch. Returns (batches, vocab_size).
    """
    try:
        from nanochat.tokenizer import get_tokenizer
        from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
        tok = get_tokenizer()
        n_vocab = tok.get_vocab_size()
        stream = tokenizing_distributed_data_loader_bos_bestfit(
            tok, batch_size, seq_len, split="train", device=device,
        )
        collected = []
        for xb, yb in stream:
            collected.append((xb, yb))
            if len(collected) >= num_batches:
                break
        print(f"Loaded {len(collected)} real training batch(es) (vocab_size={n_vocab})")
        return collected, n_vocab
    except Exception as e:
        print(f"Could not load training data ({e}), using random tokens")
        n_vocab = 32768
        collected = []
        for idx in range(num_batches):
            gen = torch.Generator(device=device)
            gen.manual_seed(42 + idx)  # per-batch seed -> reproducible fallback data
            inputs = torch.randint(0, n_vocab, (batch_size, seq_len), device=device, generator=gen)
            targets = torch.roll(inputs, -1, dims=1)
            targets[:, -1] = -1  # no next token for the last position
            collected.append((inputs, targets))
        return collected, n_vocab
def create_model(width: int, config: TransferCheckConfig, device: torch.device,
                 mup_base_width: int = 0, init_scale: float = 1.0,
                 output_mult: float = 1.0):
    """Build a GPT at roughly `width`, snapped down to a multiple of head_dim=64.

    Returns (model, gpt_config); `gpt_config.n_embd` is the actual width used.
    `init_scale` uniformly rescales all freshly initialized parameters;
    `output_mult` is stashed on the model as an override attribute.
    """
    head_dim = 64
    heads = max(1, width // head_dim)
    snapped_width = heads * head_dim
    gpt_config = GPTConfig(
        sequence_len=config.seq_len,
        vocab_size=config.vocab_size,
        n_layer=config.n_layer,
        n_head=heads,
        n_kv_head=heads,
        n_embd=snapped_width,
        window_pattern="L",
        mup_base_width=mup_base_width,
    )
    # Construct on the meta device first, then allocate + initialize for real.
    with torch.device('meta'):
        model = GPT(gpt_config)
    model.to_empty(device=device)
    model.init_weights()
    if init_scale != 1.0:
        # Multiply every parameter tensor in place by the scalar.
        with torch.no_grad():
            for param in model.parameters():
                param.mul_(init_scale)
    # Output-logit multiplier override, stored as an attribute on the model.
    # NOTE(review): presumably read by GPT.forward — confirm in nanochat/gpt.py.
    model._transfer_check_output_mult = output_mult
    return model, gpt_config
def train_model(width: int, lr_mult: float, config: TransferCheckConfig,
                device: torch.device, batches: List,
                init_scale: float = 1.0, output_mult: float = 1.0):
    """Train one model at (width, lr_mult); return (loss history, actual width).

    Seeds torch globally first so every run is comparable. The LR multiplier
    is routed to the Muon and/or AdamW parameter groups per config.sweep_mode.
    """
    torch.manual_seed(config.seed)
    proxy_width = config.base_width if config.use_mup else 0
    model, gpt_config = create_model(width, config, device, mup_base_width=proxy_width,
                                     init_scale=init_scale, output_mult=output_mult)
    actual_width = gpt_config.n_embd
    # Map sweep_mode -> (muon multiplier, adamw multiplier); "all" scales both.
    routing = {
        "muon-only": (lr_mult, 1.0),
        "adamw-only": (1.0, lr_mult),
    }
    muon_mult, adamw_mult = routing.get(config.sweep_mode, (lr_mult, lr_mult))
    optimizer = model.setup_optimizer(
        unembedding_lr=config.unembedding_lr * adamw_mult,
        embedding_lr=config.embedding_lr * adamw_mult,
        matrix_lr=config.matrix_lr * muon_mult,
        weight_decay=0.0,
        use_mup=config.use_mup,
        base_width=config.base_width,
        muon_lr_exponent=config.muon_lr_exponent,
    )
    model.train()
    history = []
    n_avail = len(batches)
    for step in range(config.num_steps):
        inputs, targets = batches[step % n_avail]  # cycle through the fixed batches
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')):
            step_loss = model(inputs, targets)
        history.append(step_loss.item())
        step_loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    # Free the model before the next run to keep GPU memory bounded.
    del model, optimizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return history, actual_width
def run_transfer_check(config: TransferCheckConfig, device: torch.device,
                       batches: List) -> Dict:
    """Sweep all LR multipliers at every width; collect curves and final losses."""
    out = {
        'widths': [],
        'lr_multipliers': config.lr_multipliers,
        'losses': {},                       # (actual_width, lr_mult) -> per-step losses
        'final_losses': defaultdict(dict),  # actual_width -> {lr_mult: final loss}
    }
    for width in config.widths:
        actual_width = None
        for lr_mult in config.lr_multipliers:
            print(f"  width={width}, lr_mult={lr_mult:.4f}...", end=" ", flush=True)
            curve, actual_width = train_model(width, lr_mult, config, device, batches)
            out['losses'][(actual_width, lr_mult)] = curve
            out['final_losses'][actual_width][lr_mult] = curve[-1]
            print(f"final_loss={curve[-1]:.4f}")
        # Record the snapped width once (it may differ from the requested width).
        if actual_width not in out['widths']:
            out['widths'].append(actual_width)
    return out
def run_hp_sweep(config: TransferCheckConfig, device: torch.device,
                 batches: List, hp_name: str, hp_values: List[float]) -> Dict:
    """Sweep one non-LR hyperparameter ('init_scale' or 'output_mult') at lr_mult=1."""
    out = {
        'widths': [],
        'hp_values': hp_values,
        'hp_name': hp_name,
        'final_losses': defaultdict(dict),
    }
    for width in config.widths:
        actual_width = None
        for hp_val in hp_values:
            # Only the swept HP deviates from 1.0; the other stays at its default.
            scale_for_init = hp_val if hp_name == 'init_scale' else 1.0
            scale_for_output = hp_val if hp_name == 'output_mult' else 1.0
            print(f"  width={width}, {hp_name}={hp_val:.4f}...", end=" ", flush=True)
            curve, actual_width = train_model(
                width, 1.0, config, device, batches,
                init_scale=scale_for_init, output_mult=scale_for_output)
            out['final_losses'][actual_width][hp_val] = curve[-1]
            print(f"final_loss={curve[-1]:.4f}")
        if actual_width not in out['widths']:
            out['widths'].append(actual_width)
    return out
def find_optimal_lr(final_losses: Dict[float, float]) -> float:
    """Return the LR multiplier with the lowest final loss.

    Diverged runs record NaN (or inf) final losses. NaN compares False
    against everything, so ``min(d, key=d.get)`` over the raw dict is
    iteration-order dependent and can silently return a diverged entry.
    Restrict the argmin to finite losses whenever any exist; if every loss
    is non-finite, fall back to the unfiltered dict (original behavior).
    """
    import math  # local import: keeps the file's top-of-file imports untouched
    finite = {mult: loss for mult, loss in final_losses.items() if math.isfinite(loss)}
    pool = finite if finite else final_losses
    return min(pool, key=pool.get)
def plot_lr_sweep(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None):
    """Plot LR sweep: final loss vs LR multiplier for each width.

    Left panel: one loss-vs-multiplier curve per width, star at the argmin.
    Right panel: optimal multiplier vs width — flat at 1.0 means the LR
    transfers across widths. Shows the figure; optionally saves to save_path.
    Note: `config` is unused here; kept for signature symmetry with the
    other plot helpers.
    """
    widths = results['widths']
    lr_mults = results['lr_multipliers']
    final_losses = results['final_losses']
    n_cols = 2
    fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 4))
    # One viridis color per width, same palette as the other plots.
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))
    # Left: final loss vs LR multiplier
    ax = axes[0]
    for i, w in enumerate(widths):
        losses = [final_losses[w][m] for m in lr_mults]
        ax.plot(lr_mults, losses, 'o-', color=colors[i], linewidth=2, label=f'width={w}')
        # Star marks the lowest-loss multiplier for this width.
        opt_mult = find_optimal_lr(final_losses[w])
        opt_loss = final_losses[w][opt_mult]
        ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5)
    ax.set_xscale('log', base=2)
    ax.set_xlabel('LR Multiplier')
    ax.set_ylabel('Final Loss')
    ax.set_title(f'LR Sweep{" - " + title if title else ""}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    # Right: optimal LR multiplier vs width
    ax = axes[1]
    opt_mults = [find_optimal_lr(final_losses[w]) for w in widths]
    ax.plot(widths, opt_mults, 'o-', linewidth=2, markersize=8, color='tab:blue')
    ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target (1.0)')
    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_xticks(widths)
    ax.set_xticklabels(widths)
    ax.set_xlabel('Width')
    ax.set_ylabel('Optimal LR Multiplier')
    ax.set_title(f'Optimal LR vs Width{" - " + title if title else ""}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    plt.show()
def plot_comparison(results_sp: Dict, results_mup: Dict, config: TransferCheckConfig, save_path: Optional[str] = None):
    """Plot SP vs muP comparison side by side.

    2x2 grid: top row is the per-width LR-sweep curve for SP (left) and muP
    (right) on a shared y-range; bottom-left overlays optimal-LR-vs-width for
    both parameterizations; bottom-right overlays best-achievable-loss-vs-width.
    NOTE: widths/lr grid are taken from results_sp — assumes both runs used
    the same widths and multipliers. `config` is unused; kept for symmetry.
    """
    widths = results_sp['widths']
    lr_mults = results_sp['lr_multipliers']
    n_rows, n_cols = 2, 2
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))
    # Top row: LR sweep curves
    for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
        ax = axes[0, col]
        for i, w in enumerate(widths):
            losses = [results['final_losses'][w][m] for m in lr_mults]
            ax.plot(lr_mults, losses, 'o-', color=colors[i], linewidth=2, label=f'w={w}')
            # Star marks the argmin multiplier for this width.
            opt_mult = find_optimal_lr(results['final_losses'][w])
            opt_loss = results['final_losses'][w][opt_mult]
            ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5)
        ax.set_xscale('log', base=2)
        ax.set_xlabel('LR Multiplier')
        ax.set_ylabel('Final Loss')
        ax.set_title(f'{label}: Final Loss vs LR Multiplier')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
    # Shared y-axis for top row (2% margin below the global min / above the max)
    y_min = min(
        min(results_sp['final_losses'][w][m] for m in lr_mults for w in widths),
        min(results_mup['final_losses'][w][m] for m in lr_mults for w in widths),
    ) * 0.98
    y_max = max(
        max(results_sp['final_losses'][w][m] for m in lr_mults for w in widths),
        max(results_mup['final_losses'][w][m] for m in lr_mults for w in widths),
    ) * 1.02
    axes[0, 0].set_ylim(y_min, y_max)
    axes[0, 1].set_ylim(y_min, y_max)
    # Bottom left: optimal LR vs width for both
    ax = axes[1, 0]
    for results, label, color in [(results_sp, 'SP', 'tab:red'), (results_mup, 'muP', 'tab:blue')]:
        opt_mults = [find_optimal_lr(results['final_losses'][w]) for w in widths]
        ax.plot(widths, opt_mults, 'o-', linewidth=2, markersize=8, color=color, label=label)
    ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target')
    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_xticks(widths)
    ax.set_xticklabels(widths)
    ax.set_xlabel('Width')
    ax.set_ylabel('Optimal LR Multiplier')
    ax.set_title('Optimal LR Multiplier vs Width')
    ax.legend()
    ax.grid(True, alpha=0.3)
    # Bottom right: loss at optimal LR vs width
    ax = axes[1, 1]
    for results, label, color in [(results_sp, 'SP', 'tab:red'), (results_mup, 'muP', 'tab:blue')]:
        opt_losses = [results['final_losses'][w][find_optimal_lr(results['final_losses'][w])] for w in widths]
        ax.plot(widths, opt_losses, 'o-', linewidth=2, markersize=8, color=color, label=label)
    ax.set_xscale('log', base=2)
    ax.set_xticks(widths)
    ax.set_xticklabels(widths)
    ax.set_xlabel('Width')
    ax.set_ylabel('Best Final Loss')
    ax.set_title('Best Achievable Loss vs Width')
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.suptitle('muP Transfer Check: SP vs muP', fontsize=14)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved comparison plot to {save_path}")
    plt.show()
def plot_hp_sweep(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None):
    """Plot HP sweep: final loss vs HP value for each width.

    Left panel: loss vs the swept hyperparameter (results['hp_name']) per
    width, star at the argmin. Right panel: the optimal HP value vs width —
    flat at 1.0 indicates the HP transfers across widths. Shows the figure;
    optionally saves to save_path. Note: `config` is unused here.
    """
    widths = results['widths']
    hp_values = results['hp_values']
    hp_name = results['hp_name']
    final_losses = results['final_losses']
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))
    # Left: final loss vs HP value, one curve per width.
    ax = axes[0]
    for i, w in enumerate(widths):
        losses = [final_losses[w][v] for v in hp_values]
        ax.plot(hp_values, losses, 'o-', color=colors[i], linewidth=2, label=f'w={w}')
        # Star marks the best HP value at this width.
        opt_v = min(final_losses[w], key=final_losses[w].get)
        ax.plot(opt_v, final_losses[w][opt_v], '*', color=colors[i], markersize=15, zorder=5)
    ax.set_xscale('log', base=2)
    ax.set_xlabel(hp_name)
    ax.set_ylabel('Final Loss')
    ax.set_title(f'{hp_name} Sweep{" - " + title if title else ""}')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)
    # Right: optimal HP value as a function of width.
    ax = axes[1]
    opt_vals = [min(final_losses[w], key=final_losses[w].get) for w in widths]
    ax.plot(widths, opt_vals, 'o-', linewidth=2, markersize=8, color='tab:blue')
    ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target (1.0)')
    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_xlabel('Width')
    ax.set_ylabel(f'Optimal {hp_name}')
    ax.set_title(f'Optimal {hp_name} vs Width')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    plt.show()
def plot_loss_curves_at_optimal(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None):
    """Plot full loss curves at the optimal LR for each width.

    One curve per width, using the per-step history stored in
    results['losses'][(width, best_multiplier)]. Shows the figure;
    optionally saves to save_path. Note: `config` is unused here.
    """
    widths = results['widths']
    fig, ax = plt.subplots(figsize=(5 * 2, 4))
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))
    for i, w in enumerate(widths):
        # Look up the best multiplier for this width, then its full curve.
        opt_mult = find_optimal_lr(results['final_losses'][w])
        losses = results['losses'][(w, opt_mult)]
        ax.plot(losses, color=colors[i], linewidth=2, label=f'w={w} (lr_mult={opt_mult})')
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title(f'Loss Curves at Optimal LR{" - " + title if title else ""}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    plt.show()
def print_summary(results: Dict, label: str):
    """Print a results table (final loss per width x LR mult) plus a transfer metric.

    The metric is the spread, in log2 space, of the per-width optimal LR
    multiplier: 0 means the same multiplier wins at every width.
    """
    widths = results['widths']
    lr_mults = results['lr_multipliers']
    final_losses = results['final_losses']
    print(f"\n{'='*70}")
    print(f" {label}: LR Sweep Results")
    print(f"{'='*70}")
    # Header: width column, one column per multiplier, then best-loss / opt-LR.
    columns = [f"{'Width':>8}"] + [f"{m:>7.3f}" for m in lr_mults]
    columns += [f"{'Best':>7}", f"{'Opt LR':>7}"]
    header = " | ".join(columns)
    print(header)
    print("-" * len(header))
    winners = []
    for w in widths:
        cells = [f"{w:>8}"] + [f"{final_losses[w][m]:>7.4f}" for m in lr_mults]
        best_mult = find_optimal_lr(final_losses[w])
        winners.append(best_mult)
        cells += [f"{final_losses[w][best_mult]:>7.4f}", f"{best_mult:>7.3f}"]
        print(" | ".join(cells))
    # Transfer quality metric: how much does the optimal LR shift?
    log_winners = np.log2(np.array(winners))
    spread = np.ptp(log_winners)  # max - min in log2 space
    print(f"\nOptimal LR spread (log2): {spread:.3f}")
    print(f"  (0 = perfect transfer, >1 = poor transfer)")
def main():
    """CLI entry point: parse args, load data, run the sweep(s), print + plot.

    Two modes: --compare runs SP then muP on the same data/config and plots
    them side by side (plus optional HP sweeps under muP); otherwise a single
    sweep runs under whichever parameterization --use-mup selects.
    """
    parser = argparse.ArgumentParser(description='muP Transfer Check')
    parser.add_argument('--widths', type=str, default='128,256,512,1024',
                        help='Comma-separated list of widths to test')
    # Paper-style default: ~1000x range, 11 log-spaced points
    parser.add_argument('--lr-mults', type=str,
                        default='0.03125,0.0625,0.125,0.25,0.5,1.0,2.0,4.0,8.0,16.0,32.0',
                        help='Comma-separated LR multipliers to sweep (default: 1024x range, 11 points)')
    parser.add_argument('--num-random-trials', type=int, default=0,
                        help='If >0, use N log-uniform random LR multipliers from 10^Uniform(-1.5,1.5) '
                             'instead of the grid. Paper-style methodology (Section F).')
    parser.add_argument('--steps', type=int, default=200,
                        help='Number of training steps per run')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--seq-len', type=int, default=128,
                        help='Sequence length')
    parser.add_argument('--n-layer', type=int, default=2,
                        help='Number of transformer layers')
    parser.add_argument('--use-mup', action='store_true',
                        help='Use muP learning rate scaling')
    parser.add_argument('--base-width', type=int, default=128,
                        help='Base width for muP scaling')
    parser.add_argument('--compare', action='store_true',
                        help='Run both SP and muP and compare')
    parser.add_argument('--save-dir', type=str, default=None,
                        help='Directory to save plots')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--num-batches', type=int, default=1,
                        help='Number of data batches to cycle through (default 1 for backward compat, '
                             'recommend 8 for thorough checks)')
    # Multi-HP sweeps
    parser.add_argument('--sweep-init-scale', action='store_true',
                        help='Also sweep init scale multiplier (sampled from 10^Uniform(-1,1))')
    parser.add_argument('--sweep-output-mult', action='store_true',
                        help='Also sweep output logit multiplier (sampled from 4^Uniform(-1,1))')
    parser.add_argument('--muon-lr-exponent', type=float, default=0.0,
                        help='Muon LR exponent for muP: 1.0 = (base/width)^1 (standard), '
                             '0.5 = (base/width)^0.5 (for Frobenius-normalizing optimizers like Muon)')
    parser.add_argument('--sweep-mode', type=str, default='all',
                        choices=['all', 'muon-only', 'adamw-only'],
                        help='Which optimizer groups the LR multiplier applies to: '
                             '"all" = all LRs (default), "muon-only" = only Muon/matrix LR, '
                             '"adamw-only" = only AdamW/embedding/output LR')
    args = parser.parse_args()
    widths = [int(w) for w in args.widths.split(',')]
    # Generate LR multipliers
    if args.num_random_trials > 0:
        # Log-uniform random sampling: 10^Uniform(-1.5, 1.5)
        rng = np.random.RandomState(args.seed)
        lr_mults = sorted(10 ** rng.uniform(-1.5, 1.5, args.num_random_trials))
        lr_mults = [round(float(m), 6) for m in lr_mults]
        print(f"Using {args.num_random_trials} random log-uniform LR multipliers: "
              f"[{lr_mults[0]:.4f}, ..., {lr_mults[-1]:.4f}]")
    else:
        lr_mults = sorted(float(m) for m in args.lr_mults.split(','))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Load batches (vocab_size comes from the tokenizer when real data loads)
    batches, vocab_size = load_batches(args.num_batches, args.batch_size, args.seq_len, device)
    # NOTE: config.use_mup is mutated in place below for each leg of --compare.
    config = TransferCheckConfig(
        widths=widths,
        lr_multipliers=lr_mults,
        num_steps=args.steps,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        vocab_size=vocab_size,
        n_layer=args.n_layer,
        seed=args.seed,
        use_mup=args.use_mup,
        base_width=args.base_width,
        sweep_init_scale=args.sweep_init_scale,
        sweep_output_mult=args.sweep_output_mult,
        num_batches=args.num_batches,
        muon_lr_exponent=args.muon_lr_exponent,
        sweep_mode=args.sweep_mode,
    )
    if args.compare:
        # Run SP
        print("\n" + "=" * 60)
        print("Running Standard Parameterization (SP)")
        print("=" * 60)
        config.use_mup = False
        results_sp = run_transfer_check(config, device, batches)
        print_summary(results_sp, "SP")
        # Run muP
        print("\n" + "=" * 60)
        print("Running muP")
        print("=" * 60)
        config.use_mup = True
        results_mup = run_transfer_check(config, device, batches)
        print_summary(results_mup, "muP")
        # Compare: spread of the optimal LR multiplier in log2 space per run
        print("\n" + "=" * 60)
        print("COMPARISON")
        print("=" * 60)
        sp_opts = [find_optimal_lr(results_sp['final_losses'][w]) for w in results_sp['widths']]
        mup_opts = [find_optimal_lr(results_mup['final_losses'][w]) for w in results_mup['widths']]
        sp_spread = np.log2(max(sp_opts)) - np.log2(min(sp_opts))
        mup_spread = np.log2(max(mup_opts)) - np.log2(min(mup_opts))
        print(f"SP optimal LR spread (log2): {sp_spread:.3f}")
        print(f"muP optimal LR spread (log2): {mup_spread:.3f}")
        if mup_spread < sp_spread:
            # max(..., 0.001) guards against division by a zero muP spread
            print(f"muP shows {sp_spread/max(mup_spread, 0.001):.1f}x better LR transfer!")
        else:
            print("muP does NOT show better LR transfer (check implementation)")
        # Plot
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, 'transfer_check_comparison.png')
        plot_comparison(results_sp, results_mup, config, save_path)
        # Also plot loss curves at optimal LR
        for results, label in [(results_sp, 'SP'), (results_mup, 'muP')]:
            lc_save = None
            if args.save_dir:
                lc_save = os.path.join(args.save_dir, f'optimal_loss_curves_{label.lower()}.png')
            plot_loss_curves_at_optimal(results, config, title=label, save_path=lc_save)
        # Multi-HP sweeps (only for muP, to demonstrate transfer)
        if args.sweep_init_scale or args.sweep_output_mult:
            config.use_mup = True
            if args.sweep_init_scale:
                print("\n" + "=" * 60)
                print("muP: Init Scale Sweep")
                print("=" * 60)
                # 10^Uniform(-1, 1) => range [0.1, 10]
                init_scales = [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]
                init_results = run_hp_sweep(config, device, batches, 'init_scale', init_scales)
                save_hp = None
                if args.save_dir:
                    save_hp = os.path.join(args.save_dir, 'init_scale_sweep.png')
                plot_hp_sweep(init_results, config, title="muP", save_path=save_hp)
            if args.sweep_output_mult:
                print("\n" + "=" * 60)
                print("muP: Output Multiplier Sweep")
                print("=" * 60)
                # 4^Uniform(-1, 1) => range [0.25, 4]
                output_mults = [0.25, 0.5, 1.0, 2.0, 4.0]
                output_results = run_hp_sweep(config, device, batches, 'output_mult', output_mults)
                save_hp = None
                if args.save_dir:
                    save_hp = os.path.join(args.save_dir, 'output_mult_sweep.png')
                plot_hp_sweep(output_results, config, title="muP", save_path=save_hp)
    else:
        # Single-parameterization run (SP unless --use-mup was given)
        param_type = "muP" if config.use_mup else "SP"
        print(f"\n{'='*60}")
        print(f"Running Transfer Check ({param_type})")
        print(f"{'='*60}")
        print(f"Widths: {widths}")
        print(f"LR multipliers: {lr_mults}")
        print(f"Steps: {config.num_steps}")
        if config.sweep_mode != "all":
            print(f"Sweep mode: {config.sweep_mode}")
        results = run_transfer_check(config, device, batches)
        print_summary(results, param_type)
        # Plot LR sweep
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, f'transfer_check_{param_type.lower()}.png')
        plot_lr_sweep(results, config, title=param_type, save_path=save_path)
        # Plot loss curves at optimal LR
        lc_save = None
        if args.save_dir:
            lc_save = os.path.join(args.save_dir, f'optimal_loss_curves_{param_type.lower()}.png')
        plot_loss_curves_at_optimal(results, config, title=param_type, save_path=lc_save)
# Script entry point.
if __name__ == '__main__':
    main()