diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 208acd1..33b9346 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -37,6 +37,9 @@ class GPTConfig: # Characters: L=long (full context), S=short (half context) # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long window_pattern: str = "SSSL" + # muP (Maximal Update Parametrization): set > 0 to enable. Value is the base/proxy width. + # Enables: non-zero c_proj init scaled as 1/sqrt(m_d), output logit scaling by base_width/n_embd. + mup_base_width: int = 0 def norm(x): @@ -203,7 +206,12 @@ class GPT(nn.Module): # Embedding and unembedding torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0) - torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001) + lm_head_std = 0.001 + if self.config.mup_base_width > 0: + # muP: scale lm_head init by 1/sqrt(m_d) so raw logit magnitude is O(1) across widths. + # Without this, |logit| ~ 0.001 * sqrt(n_embd) grows as sqrt(width). + lm_head_std *= (self.config.mup_base_width / self.config.n_embd) ** 0.5 + torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std) # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal) n_embd = self.config.n_embd @@ -212,9 +220,15 @@ class GPT(nn.Module): torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers torch.nn.init.uniform_(block.attn.c_k.weight, -s, s) torch.nn.init.uniform_(block.attn.c_v.weight, -s, s) - torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s) - torch.nn.init.zeros_(block.mlp.c_proj.weight) + if self.config.mup_base_width > 0: + # muP: output projections use same scale as hidden weights (std = sigma_base/sqrt(m_d)) + # Zero init causes attn/FFN outputs to vanish as width increases with muP LR scaling + torch.nn.init.uniform_(block.attn.c_proj.weight, -s, s) + torch.nn.init.uniform_(block.mlp.c_proj.weight, -s, s) 
+ else: + torch.nn.init.zeros_(block.attn.c_proj.weight) # SP: projections are zero + torch.nn.init.zeros_(block.mlp.c_proj.weight) # Per-layer scalars self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init @@ -345,7 +359,7 @@ class GPT(nn.Module): 'total': total, } - def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): + def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5, use_mup=False, base_width=256, muon_lr_exponent=0.0): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() @@ -358,16 +372,41 @@ class GPT(nn.Module): x0_params = [self.x0_lambdas] assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) - # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) - dmodel_lr_scale = (model_dim / 768) ** -0.5 - print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") + # Compute LR scaling factors based on mode + if use_mup: + # muP for mixed Muon+AdamW optimizer: + # + # Key insight: the output logit scaling (logits *= base/width in forward()) already + # propagates a base/width factor into the lm_head gradient. Adam normalizes gradient + # magnitude via its second moment, but applying ANOTHER base/width to the LR compounds + # the scaling, making lm_head updates O(base²/width²) instead of O(1). So we do NOT + # apply width-dependent LR scaling to the output layer — the logit scaling alone suffices. + # + # For Muon (hidden weights): Polar Express orthogonalization normalizes ||update||_F ≈ 1 + # regardless of width, making the update already O(1). No width-dependent LR scaling + # is needed (empirically confirmed: exponent 0 and 1 give identical transfer behavior). 
+ # + # The muon_lr_exponent parameter is kept for experimentation but defaults to 0. + width_ratio = base_width / model_dim # e.g., 128/1024 = 0.125 + emb_lr_scale = 1.0 # Embeddings: NO width scaling (standard muP) + hidden_lr_scale = width_ratio ** muon_lr_exponent # Hidden (Muon): default 0 = no scaling + output_lr_scale = 1.0 # Output (AdamW): NO LR scaling (logit scaling in forward suffices) + self.config.mup_base_width = base_width # enables output logit scaling in forward() + print0(f"muP scaling: base_width={base_width}, model_dim={model_dim}, width_ratio={width_ratio:.6f}, muon_lr_exp={muon_lr_exponent}") + else: + # Standard (SP): scale AdamW params by 1/√dmodel (tuned for 768 dim model) + dmodel_lr_scale = (model_dim / 768) ** -0.5 + emb_lr_scale = dmodel_lr_scale + hidden_lr_scale = 1.0 # Muon params: no scaling in SP mode + output_lr_scale = dmodel_lr_scale + print0(f"Standard scaling: dmodel_lr_scale={dmodel_lr_scale:.6f}") # Build param_groups with all required fields explicit param_groups = [ # AdamW groups (embeddings, lm_head, scalars) - dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), - dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), - dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), + dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * output_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), + dict(kind='adamw', params=embedding_params, lr=embedding_lr * emb_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), + dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * emb_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0), dict(kind='adamw', params=x0_params, 
lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 ] @@ -375,7 +414,7 @@ class GPT(nn.Module): for shape in sorted({p.shape for p in matrix_params}): group_params = [p for p in matrix_params if p.shape == shape] param_groups.append(dict( - kind='muon', params=group_params, lr=matrix_lr, + kind='muon', params=group_params, lr=matrix_lr * hidden_lr_scale, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay, )) @@ -409,6 +448,11 @@ class GPT(nn.Module): # Forward the lm_head (compute logits) softcap = 15 # smoothly cap the logits to the range [-softcap, softcap] logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory + if self.config.mup_base_width > 0: + # muP: scale output logits by base_width/n_embd (= 1/m_d). + # Without this, logits grow with width because the lm_head dot product sums over n_embd terms. + # 1/sqrt(m_d) only corrects at init; 1/m_d is required for all training steps (see Eleuther blog Fig 8-9). 
+ logits = logits * (self.config.mup_base_width / self.config.n_embd) logits = logits[..., :self.config.vocab_size] # slice to remove padding logits = logits.float() # switch to fp32 for logit softcap and loss computation logits = softcap * torch.tanh(logits / softcap) # squash the logits diff --git a/scripts/base_train.py b/scripts/base_train.py index 24091b6..b07109f 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -68,6 +68,8 @@ parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 f parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") +parser.add_argument("--use-mup", action="store_true", help="use muP (Maximal Update Parameterization) LR scaling") +parser.add_argument("--base-width", type=int, default=256, help="base width for muP LR scaling (LRs tuned at this width)") parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") # Evaluation parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") @@ -309,6 +311,9 @@ optimizer = model.setup_optimizer( # Muon hyperparameters matrix_lr=args.matrix_lr * batch_lr_scale, weight_decay=weight_decay_scaled, + # muP scaling + use_mup=args.use_mup, + base_width=args.base_width, ) if resuming: diff --git a/scripts/mup_coord_check.py b/scripts/mup_coord_check.py new file mode 100644 index 0000000..ce9edf6 --- /dev/null +++ b/scripts/mup_coord_check.py @@ -0,0 +1,710 @@ +""" +muP Coordinate Check for nanochat + +This script validates muP implementation by checking that activation magnitudes +are independent of model width. Based on EleutherAI's nanoGPT-mup and Microsoft's +mup library. 
+ +Reference: https://blog.eleuther.ai/mutransfer/ +Reference: Yang et al., "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot + Hyperparameter Transfer" (arXiv:2203.03466), Sections B.1 and F. + +Usage: + python -m scripts.mup_coord_check --widths 128,256,512,1024 --steps 10 + python -m scripts.mup_coord_check --use-mup --widths 128,256,512,1024 + python -m scripts.mup_coord_check --compare --detailed + python -m scripts.mup_coord_check --compare --muon-lr-exponent 0.5 +""" + +import argparse +import torch +import torch._dynamo +torch._dynamo.config.disable = True +import torch.nn.functional as F +import matplotlib.pyplot as plt +import numpy as np +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List, Tuple, Optional +import os + +from nanochat.gpt import GPT, GPTConfig + + +def load_batch(batch_size: int, seq_len: int, device: torch.device): + """Load a single batch from the nanochat training pipeline. + Falls back to random data if the tokenizer/dataset isn't available.""" + try: + from nanochat.tokenizer import get_tokenizer + from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit + tokenizer = get_tokenizer() + vocab_size = tokenizer.get_vocab_size() + loader = tokenizing_distributed_data_loader_bos_bestfit( + tokenizer, batch_size, seq_len, split="train", device=device, + ) + x, y = next(loader) + print(f"Loaded real training data (vocab_size={vocab_size})") + return x, y, vocab_size + except Exception as e: + print(f"Could not load training data ({e}), using random tokens") + vocab_size = 32768 + rng = torch.Generator(device=device) + rng.manual_seed(42) + x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, generator=rng) + y = torch.roll(x, -1, dims=1) + y[:, -1] = -1 + return x, y, vocab_size + + +@dataclass +class CoordCheckConfig: + widths: List[int] + num_steps: int = 10 + batch_size: int = 4 + seq_len: int = 128 + vocab_size: int = 32768 + n_layer: 
int = 2 + seed: int = 42 + use_mup: bool = False + base_width: int = 128 + # Learning rates (tuned at base_width) + matrix_lr: float = 0.02 + embedding_lr: float = 0.2 + unembedding_lr: float = 0.004 + # Detailed diagnostics + detailed: bool = False + # Muon LR exponent: 0.0 = no width scaling (default), 1.0 = base/width (standard muP), 0.5 = sqrt(base/width) + # Paper Section C.1: Frobenius-normalizing optimizers may need exponent 0.5 + muon_lr_exponent: float = 0.0 + + +class ActivationRecorder: + """Records activation statistics during forward pass using hooks.""" + + def __init__(self, detailed: bool = False): + self.stats: Dict[str, List[float]] = defaultdict(list) + self.hooks = [] + self.detailed = detailed + + def _get_stat(self, tensor: torch.Tensor) -> float: + """Compute the mean absolute value over all elements of the tensor.""" + if tensor is None: + return 0.0 + if tensor.dtype == torch.bool: + return tensor.float().abs().mean().item() + return tensor.float().abs().mean().item() + + def _make_hook(self, name: str): + """Create a forward hook that records output statistics.""" + def hook(module, input, output): + if isinstance(output, tuple): + output = output[0] + if output is not None and isinstance(output, torch.Tensor): + self.stats[name].append(self._get_stat(output)) + return hook + + def _make_attn_logit_hook(self, name: str, n_head: int, n_kv_head: int, head_dim: int): + """Create a hook on c_k that computes pre-softmax attention logit magnitudes. + + We hook onto c_k's forward, then use the most recent c_q output to compute + q @ k^T / sqrt(d) for a single batch element to measure attention logit scale. 
+ """ + # We'll store q output and compute logits when k is available + self._last_q = None + + def q_hook(module, input, output): + self._last_q = output.detach() + + def k_hook(module, input, output): + if self._last_q is None: + return + q = self._last_q + k = output.detach() + B, T, _ = q.shape + q = q[0:1].view(1, T, n_head, head_dim) + k = k[0:1].view(1, T, n_kv_head, head_dim) + # Apply QK norm (same as model) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + # Expand k for GQA + if n_head != n_kv_head: + k = k.repeat_interleave(n_head // n_kv_head, dim=2) + # Compute logits: q @ k^T / sqrt(d) — just for first few positions + T_sub = min(T, 32) + q_sub = q[:, :T_sub].transpose(1, 2) # (1, H, T_sub, D) + k_sub = k[:, :T_sub].transpose(1, 2) # (1, H, T_sub, D) + logits = torch.matmul(q_sub, k_sub.transpose(-2, -1)) / (head_dim ** 0.5) + self.stats[name].append(logits.float().abs().mean().item()) + self._last_q = None + + return q_hook, k_hook + + def register_hooks(self, model: GPT) -> None: + """Register forward hooks on key layers.""" + # Embedding + h = model.transformer.wte.register_forward_hook(self._make_hook('word embedding')) + self.hooks.append(h) + + # Each transformer block + for i, block in enumerate(model.transformer.h): + # Attention output + h = block.attn.c_proj.register_forward_hook(self._make_hook(f'attn output.{i}')) + self.hooks.append(h) + # MLP output + h = block.mlp.c_proj.register_forward_hook(self._make_hook(f'FFN output.{i}')) + self.hooks.append(h) + + # Detailed: attention logit magnitudes + if self.detailed: + n_head = block.attn.n_head + n_kv_head = block.attn.n_kv_head + head_dim = block.attn.head_dim + q_hook, k_hook = self._make_attn_logit_hook( + f'attn logits.{i}', n_head, n_kv_head, head_dim) + h1 = block.attn.c_q.register_forward_hook(q_hook) + h2 = block.attn.c_k.register_forward_hook(k_hook) + self.hooks.extend([h1, h2]) + + # LM head + h = model.lm_head.register_forward_hook(self._make_hook('output 
logits')) + self.hooks.append(h) + + def remove_hooks(self) -> None: + """Remove all registered hooks.""" + for h in self.hooks: + h.remove() + self.hooks = [] + + def get_step_stats(self) -> Dict[str, float]: + """Get mean stats for the current step and reset.""" + step_stats = {} + for name, values in self.stats.items(): + if values: + step_stats[name] = np.mean(values) + self.stats = defaultdict(list) + return step_stats + + +def create_model(width: int, config: CoordCheckConfig, device: torch.device, mup_base_width: int = 0) -> Tuple[GPT, GPTConfig]: + """Create a model with the specified width.""" + head_dim = 64 + n_head = max(1, width // head_dim) + actual_width = n_head * head_dim + + gpt_config = GPTConfig( + sequence_len=config.seq_len, + vocab_size=config.vocab_size, + n_layer=config.n_layer, + n_head=n_head, + n_kv_head=n_head, + n_embd=actual_width, + window_pattern="L", + mup_base_width=mup_base_width, + ) + + with torch.device('meta'): + model = GPT(gpt_config) + model.to_empty(device=device) + model.init_weights() + + return model, gpt_config + + +def setup_optimizer_mup(model: GPT, config: CoordCheckConfig, width: int): + """Set up optimizer with muP scaling using the native use_mup flag.""" + optimizer = model.setup_optimizer( + unembedding_lr=config.unembedding_lr, + embedding_lr=config.embedding_lr, + matrix_lr=config.matrix_lr, + weight_decay=0.0, + use_mup=True, + base_width=config.base_width, + muon_lr_exponent=config.muon_lr_exponent, + ) + return optimizer + + +def setup_optimizer_sp(model: GPT, config: CoordCheckConfig, width: int): + """Set up optimizer with standard parameterization (current nanochat).""" + optimizer = model.setup_optimizer( + unembedding_lr=config.unembedding_lr, + embedding_lr=config.embedding_lr, + matrix_lr=config.matrix_lr, + weight_decay=0.0, + use_mup=False, + ) + return optimizer + + +def record_detailed_stats(model: GPT, results: Dict, width: int, step: int): + """Record weight update norms and gradient norms 
per parameter group.""" + for name, p in model.named_parameters(): + if p.grad is None: + continue + # Simplify name for display + short_name = name.replace('transformer.', '').replace('.weight', '') + # Gradient norm + grad_norm = p.grad.float().norm().item() + results['detailed_stats'][width][f'grad norm: {short_name}'].append(grad_norm) + + +def record_weight_update_norms(model: GPT, params_before: Dict[str, torch.Tensor], + results: Dict, width: int): + """Record ||delta_W|| for each parameter after optimizer step.""" + for name, p in model.named_parameters(): + if name not in params_before: + continue + short_name = name.replace('transformer.', '').replace('.weight', '') + delta = (p.data.float() - params_before[name]).norm().item() + results['detailed_stats'][width][f'update norm: {short_name}'].append(delta) + + +def run_coord_check(config: CoordCheckConfig, device: torch.device, + x: torch.Tensor, y: torch.Tensor) -> Dict: + """Run coordinate check across all widths.""" + results = { + 'widths': [], + 'steps': list(range(config.num_steps)), + 'stats': defaultdict(lambda: defaultdict(list)), + 'losses': defaultdict(list), + 'detailed_stats': defaultdict(lambda: defaultdict(list)), + } + + for width in config.widths: + print(f"\nTraining width={width}...") + + torch.manual_seed(config.seed) + + mup_base_width = config.base_width if config.use_mup else 0 + model, gpt_config = create_model(width, config, device, mup_base_width=mup_base_width) + actual_width = gpt_config.n_embd + results['widths'].append(actual_width) + + if config.use_mup: + optimizer = setup_optimizer_mup(model, config, actual_width) + else: + optimizer = setup_optimizer_sp(model, config, actual_width) + + recorder = ActivationRecorder(detailed=config.detailed) + recorder.register_hooks(model) + + model.train() + + for step in range(config.num_steps): + with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')): + loss = model(x, y) + + 
results['losses'][actual_width].append(loss.item()) + + step_stats = recorder.get_step_stats() + for layer, value in step_stats.items(): + results['stats'][actual_width][layer].append(value) + + if step == 0: + print(f" Step {step}: loss={loss.item():.4f}, layers={list(step_stats.keys())}") + + # Record gradient norms before step (detailed mode) + loss.backward() + + if config.detailed: + record_detailed_stats(model, results, actual_width, step) + # Snapshot params before optimizer step to compute update norms + params_before = {name: p.data.float().clone() + for name, p in model.named_parameters() + if p.grad is not None} + + optimizer.step() + + if config.detailed: + record_weight_update_norms(model, params_before, results, actual_width) + + optimizer.zero_grad(set_to_none=True) + + print(f" Final loss: {loss.item():.4f}") + + recorder.remove_hooks() + del model, optimizer + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + return results + + +def plot_coord_check(results: Dict, config: CoordCheckConfig, save_path: Optional[str] = None): + """Plot coordinate check: one subplot per layer, x=width (log2), y=mean |activation|, lines=steps.""" + widths = results['widths'] + steps = results['steps'] + stats = results['stats'] + + layer_names = list(stats[widths[0]].keys()) + n_layers = len(layer_names) + n_cols = 4 + n_rows = (n_layers + n_cols - 1) // n_cols + + param_type = "muP" if config.use_mup else "SP" + fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows)) + axes = np.array(axes).flatten() + + step_colors = plt.cm.plasma(np.linspace(0, 1, len(steps))) + + for i, layer in enumerate(layer_names): + ax = axes[i] + for s, step in enumerate(steps): + values = [stats[w][layer][s] for w in widths] + ax.plot(widths, values, 'o-', color=step_colors[s], linewidth=1.5, + label=f'step {step}' if i == 0 else None) + ax.set_xscale('log', base=2) + ax.set_xticks(widths) + ax.set_xticklabels(widths, fontsize=7) + ax.set_title(layer, 
fontsize=9) + ax.set_xlabel('Width') + ax.set_ylabel('Mean |activation|') + ax.grid(True, alpha=0.3) + + axes[0].legend(fontsize=7, loc='best') + + for i in range(n_layers, len(axes)): + axes[i].set_visible(False) + + fig.suptitle(f'Coordinate Check ({param_type}): Activation Magnitude vs Width', fontsize=14) + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"Saved plot to {save_path}") + + plt.show() + + +def plot_loss_curves(results: Dict, config: CoordCheckConfig, title: str = "", save_path: Optional[str] = None): + """Plot loss curves across widths to verify HP transfer.""" + widths = results['widths'] + steps = results['steps'] + losses = results['losses'] + + fig, ax = plt.subplots(figsize=(5 * 2, 4)) + colors = plt.cm.viridis(np.linspace(0, 1, len(widths))) + + for i, w in enumerate(widths): + ax.plot(steps, losses[w], label=f'width={w}', color=colors[i], linewidth=2) + + ax.set_xlabel('Step') + ax.set_ylabel('Loss') + ax.set_title(f'Loss Curves Across Widths{" - " + title if title else ""}') + ax.legend() + ax.grid(True, alpha=0.3) + + # Add annotation for final loss spread + final_losses = [losses[w][-1] for w in widths] + spread = max(final_losses) - min(final_losses) + ax.annotate(f'Final loss spread: {spread:.4f}', xy=(0.7, 0.95), xycoords='axes fraction', fontsize=10) + + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"Saved loss curves to {save_path}") + + plt.show() + + +def plot_comparison(results_sp: Dict, results_mup: Dict, config: CoordCheckConfig, save_path: Optional[str] = None): + """Plot SP vs muP: one subplot per layer (left=SP, right=muP), x=width (log2), y=mean |activation|, lines=steps.""" + widths = results_sp['widths'] + steps = results_sp['steps'] + + layer_names = list(results_sp['stats'][widths[0]].keys()) + n_layers = len(layer_names) + + # n_layers activation rows + 1 loss row, 2 cols (SP | muP) + n_rows, n_cols = n_layers + 
1, 2 + fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 3 * n_rows)) + + step_colors = plt.cm.plasma(np.linspace(0, 1, len(steps))) + width_colors = plt.cm.viridis(np.linspace(0, 1, len(widths))) + + for row, layer in enumerate(layer_names): + # Shared y-axis range across SP and muP for this layer + all_vals = [results_sp['stats'][w][layer][s] for w in widths for s in range(len(steps))] + \ + [results_mup['stats'][w][layer][s] for w in widths for s in range(len(steps))] + y_min, y_max = min(all_vals) * 0.9, max(all_vals) * 1.1 + + for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]): + ax = axes[row, col] + for s, step in enumerate(steps): + values = [results['stats'][w][layer][s] for w in widths] + ax.plot(widths, values, 'o-', color=step_colors[s], linewidth=1.5, + label=f'step {step}' if (row == 0 and col == 0) else None) + ax.set_xscale('log', base=2) + ax.set_xticks(widths) + ax.set_xticklabels(widths, fontsize=7) + ax.set_ylim(y_min, y_max) + ax.set_title(f'{label}: {layer}', fontsize=9) + ax.set_xlabel('Width') + ax.set_ylabel('Mean |activation|') + ax.grid(True, alpha=0.3) + + axes[0, 0].legend(fontsize=7, loc='best') + + # Loss curves row + all_losses = [v for r in (results_sp, results_mup) for w in widths for v in r['losses'][w]] + loss_min, loss_max = min(all_losses) * 0.95, max(all_losses) * 1.05 + + for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]): + ax = axes[n_layers, col] + for j, w in enumerate(widths): + ax.plot(steps, results['losses'][w], label=f'w={w}', color=width_colors[j], linewidth=2) + ax.set_ylim(loss_min, loss_max) + ax.set_xlabel('Step') + ax.set_ylabel('Loss') + ax.set_title(f'{label}: Loss Curves') + ax.legend(fontsize=7) + ax.grid(True, alpha=0.3) + final_losses = [results['losses'][w][-1] for w in widths] + spread = max(final_losses) - min(final_losses) + ax.annotate(f'Spread: {spread:.4f}', xy=(0.65, 0.95), xycoords='axes fraction', fontsize=9) + + 
fig.suptitle('Coordinate Check: SP vs muP', fontsize=14) + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"Saved comparison plot to {save_path}") + + plt.show() + + +def plot_detailed(results: Dict, config: CoordCheckConfig, save_path: Optional[str] = None): + """Plot detailed diagnostics: gradient norms, weight update norms, attention logits.""" + widths = results['widths'] + detailed = results['detailed_stats'] + if not detailed or not detailed[widths[0]]: + print("No detailed stats recorded. Use --detailed flag.") + return + + # Collect all detailed metric names + metric_names = sorted(detailed[widths[0]].keys()) + + # Group by category + categories = defaultdict(list) + for name in metric_names: + if name.startswith('grad norm:'): + categories['Gradient Norms'].append(name) + elif name.startswith('update norm:'): + categories['Weight Update Norms'].append(name) + elif name.startswith('attn logits'): + categories['Attention Logit Magnitudes'].append(name) + + for cat_name, names in categories.items(): + n = len(names) + n_cols = min(4, n) + n_rows = (n + n_cols - 1) // n_cols + fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows)) + if n == 1: + axes = np.array([axes]) + axes = np.array(axes).flatten() + + steps = results['steps'] + width_colors = plt.cm.viridis(np.linspace(0, 1, len(widths))) + + for i, name in enumerate(names): + ax = axes[i] + for j, w in enumerate(widths): + values = detailed[w].get(name, []) + if values: + ax.plot(range(len(values)), values, color=width_colors[j], + linewidth=1.5, label=f'w={w}' if i == 0 else None) + ax.set_title(name.split(': ', 1)[-1] if ': ' in name else name, fontsize=8) + ax.set_xlabel('Step') + ax.set_ylabel('Norm') + ax.grid(True, alpha=0.3) + ax.set_yscale('log') + + for i in range(n, len(axes)): + axes[i].set_visible(False) + + axes[0].legend(fontsize=7, loc='best') + param_type = "muP" if config.use_mup else "SP" + 
fig.suptitle(f'{cat_name} ({param_type})', fontsize=14) + plt.tight_layout() + + if save_path: + cat_slug = cat_name.lower().replace(' ', '_') + path = save_path.replace('.png', f'_{cat_slug}.png') + plt.savefig(path, dpi=150, bbox_inches='tight') + print(f"Saved {cat_name} plot to {path}") + + plt.show() + + +def compute_width_dependence(results: Dict) -> Dict[str, float]: + """Compute how much activations scale with width (slope on log-log plot).""" + widths = np.array(results['widths']) + log_widths = np.log2(widths) + final_step = len(results['steps']) - 1 + + slopes = {} + for layer in results['stats'][widths[0]].keys(): + values = [results['stats'][w][layer][final_step] for w in widths] + log_values = np.log2(np.array(values) + 1e-10) + slope, _ = np.polyfit(log_widths, log_values, 1) + slopes[layer] = slope + + return slopes + + +def main(): + parser = argparse.ArgumentParser(description='muP Coordinate Check') + parser.add_argument('--widths', type=str, default='128,256,512,1024', + help='Comma-separated list of widths to test') + parser.add_argument('--steps', type=int, default=10, + help='Number of training steps') + parser.add_argument('--batch-size', type=int, default=4, + help='Batch size') + parser.add_argument('--seq-len', type=int, default=128, + help='Sequence length') + parser.add_argument('--n-layer', type=int, default=2, + help='Number of transformer layers') + parser.add_argument('--use-mup', action='store_true', + help='Use muP learning rate scaling') + parser.add_argument('--base-width', type=int, default=128, + help='Base width for muP scaling') + parser.add_argument('--compare', action='store_true', + help='Run both SP and muP and compare') + parser.add_argument('--save-dir', type=str, default=None, + help='Directory to save plots') + parser.add_argument('--seed', type=int, default=42, + help='Random seed') + parser.add_argument('--detailed', action='store_true', + help='Record detailed diagnostics: gradient norms, weight update norms, ' + 
'attention logit magnitudes') + parser.add_argument('--muon-lr-exponent', type=float, default=0.0, + help='Muon LR exponent for muP: 1.0 = (base/width)^1 (standard muP), ' + '0.5 = (base/width)^0.5 (for Frobenius-normalizing optimizers, ' + 'see Yang et al. Section C.1)') + + args = parser.parse_args() + + # Parse widths + widths = [int(w) for w in args.widths.split(',')] + + # Setup device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # Load a single batch of real training data (reused every step) + x, y, vocab_size = load_batch(args.batch_size, args.seq_len, device) + + # Create config + config = CoordCheckConfig( + widths=widths, + num_steps=args.steps, + batch_size=args.batch_size, + seq_len=args.seq_len, + vocab_size=vocab_size, + n_layer=args.n_layer, + seed=args.seed, + use_mup=args.use_mup, + base_width=args.base_width, + detailed=args.detailed, + muon_lr_exponent=args.muon_lr_exponent, + ) + + if args.compare: + # Run both SP and muP + print("\n" + "="*60) + print("Running Standard Parameterization (SP)") + print("="*60) + config.use_mup = False + results_sp = run_coord_check(config, device, x, y) + + print("\n" + "="*60) + print("Running muP") + if config.muon_lr_exponent != 1.0: + print(f" (Muon LR exponent: {config.muon_lr_exponent})") + print("="*60) + config.use_mup = True + results_mup = run_coord_check(config, device, x, y) + + # Compute slopes + print("\n" + "="*60) + print("Width Dependence (slope on log-log plot)") + print("Expected: ~0 for width-independent, positive = grows with width") + print("="*60) + + slopes_sp = compute_width_dependence(results_sp) + slopes_mup = compute_width_dependence(results_mup) + + print(f"\n{'Layer':<20} {'SP Slope':>12} {'muP Slope':>12}") + print("-"*46) + for layer in slopes_sp: + print(f"{layer:<20} {slopes_sp[layer]:>12.4f} {slopes_mup[layer]:>12.4f}") + + # Plot comparison + save_path = None + if args.save_dir: + os.makedirs(args.save_dir, 
exist_ok=True) + save_path = os.path.join(args.save_dir, 'coord_check_comparison.png') + plot_comparison(results_sp, results_mup, config, save_path) + + # Plot detailed diagnostics if requested + if config.detailed: + for results, label in [(results_sp, 'SP'), (results_mup, 'muP')]: + config.use_mup = (label == 'muP') + detail_save = None + if args.save_dir: + detail_save = os.path.join(args.save_dir, f'detailed_{label.lower()}.png') + plot_detailed(results, config, detail_save) + + else: + # Run single mode + param_type = "muP" if config.use_mup else "SP" + print(f"\n{'='*60}") + print(f"Running Coordinate Check ({param_type})") + print(f"{'='*60}") + print(f"Widths: {widths}") + print(f"Steps: {config.num_steps}") + print(f"Base width: {config.base_width}") + if config.use_mup and config.muon_lr_exponent != 1.0: + print(f"Muon LR exponent: {config.muon_lr_exponent}") + + results = run_coord_check(config, device, x, y) + + # Compute slopes + slopes = compute_width_dependence(results) + print("\n" + "="*60) + print("Width Dependence (slope on log-log plot)") + print("Expected for muP: ~0 (width-independent)") + print("="*60) + for layer, slope in slopes.items(): + status = "OK" if abs(slope) < 0.1 else "WARN" + print(f" {layer}: {slope:+.4f} [{status}]") + + # Loss curve analysis + final_losses = [results['losses'][w][-1] for w in results['widths']] + loss_spread = max(final_losses) - min(final_losses) + print(f"\nFinal loss spread across widths: {loss_spread:.4f}") + print(f"Expected for muP: low spread (similar losses across widths)") + + # Plot activations + save_path = None + if args.save_dir: + os.makedirs(args.save_dir, exist_ok=True) + save_path = os.path.join(args.save_dir, f'coord_check_{param_type.lower()}.png') + plot_coord_check(results, config, save_path) + + # Plot loss curves + loss_save_path = None + if args.save_dir: + loss_save_path = os.path.join(args.save_dir, f'loss_curves_{param_type.lower()}.png') + plot_loss_curves(results, config, 
title=param_type, save_path=loss_save_path) + + # Plot detailed diagnostics if requested + if config.detailed: + detail_save = None + if args.save_dir: + detail_save = os.path.join(args.save_dir, f'detailed_{param_type.lower()}.png') + plot_detailed(results, config, detail_save) + + +if __name__ == '__main__': + main() diff --git a/scripts/mup_transfer_check.py b/scripts/mup_transfer_check.py new file mode 100644 index 0000000..567f1b4 --- /dev/null +++ b/scripts/mup_transfer_check.py @@ -0,0 +1,672 @@ +""" +muP Transfer Check for nanochat + +Validates that optimal learning rates transfer across model widths under muP. +For each width, sweeps over LR multipliers and records final loss. Under correct +muP, the optimal LR multiplier should be ~1.0 at all widths (i.e., the same LR +works everywhere). Under SP, the optimal LR typically shifts with width. + +Reference: https://blog.eleuther.ai/mutransfer/ +Reference: Yang et al., "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot + Hyperparameter Transfer" (arXiv:2203.03466), Section F. 
+ +Usage: + # Quick check (~2 min on GPU) + python -m scripts.mup_transfer_check + + # Compare SP vs muP side-by-side + python -m scripts.mup_transfer_check --compare + + # Wide LR sweep (paper-style, ~1000x range) + python -m scripts.mup_transfer_check --compare --widths 128,256,512,1024 --steps 200 + + # Random log-uniform LR trials (paper-style methodology) + python -m scripts.mup_transfer_check --compare --num-random-trials 20 + + # Multi-HP sweep (init scale + output multiplier) + python -m scripts.mup_transfer_check --compare --sweep-init-scale --sweep-output-mult + + # Save plots + python -m scripts.mup_transfer_check --compare --save-dir plots/ +""" + +import argparse +import torch +import torch._dynamo +torch._dynamo.config.disable = True +import matplotlib.pyplot as plt +import numpy as np +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, List, Optional +import os + +from nanochat.gpt import GPT, GPTConfig + + +@dataclass +class TransferCheckConfig: + widths: List[int] + lr_multipliers: List[float] + num_steps: int = 200 + batch_size: int = 8 + seq_len: int = 128 + vocab_size: int = 32768 + n_layer: int = 2 + seed: int = 42 + use_mup: bool = False + base_width: int = 128 + # Base learning rates (tuned at base_width) + matrix_lr: float = 0.02 + embedding_lr: float = 0.2 + unembedding_lr: float = 0.004 + # Multi-HP sweeps + sweep_init_scale: bool = False + sweep_output_mult: bool = False + # Data diversity + num_batches: int = 1 + # Muon LR exponent for muP (1.0=standard, 0.5=Frobenius-norm optimizers) + muon_lr_exponent: float = 0.0 + # Sweep mode: which optimizer groups the LR multiplier applies to + # "all" = multiply all LRs (default), "muon-only" = only matrix_lr, + # "adamw-only" = only embedding_lr/unembedding_lr + sweep_mode: str = "all" + + +def load_batches(num_batches: int, batch_size: int, seq_len: int, device: torch.device): + """Load multiple batches from the nanochat training pipeline. 
+ Falls back to random data if the tokenizer/dataset isn't available.""" + try: + from nanochat.tokenizer import get_tokenizer + from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit + tokenizer = get_tokenizer() + vocab_size = tokenizer.get_vocab_size() + loader = tokenizing_distributed_data_loader_bos_bestfit( + tokenizer, batch_size, seq_len, split="train", device=device, + ) + batches = [] + for i, (x, y) in enumerate(loader): + batches.append((x, y)) + if len(batches) >= num_batches: + break + print(f"Loaded {len(batches)} real training batch(es) (vocab_size={vocab_size})") + return batches, vocab_size + except Exception as e: + print(f"Could not load training data ({e}), using random tokens") + vocab_size = 32768 + batches = [] + for i in range(num_batches): + rng = torch.Generator(device=device) + rng.manual_seed(42 + i) + x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, generator=rng) + y = torch.roll(x, -1, dims=1) + y[:, -1] = -1 + batches.append((x, y)) + return batches, vocab_size + + +def create_model(width: int, config: TransferCheckConfig, device: torch.device, + mup_base_width: int = 0, init_scale: float = 1.0, + output_mult: float = 1.0): + """Create a model with the specified width and optional HP overrides.""" + head_dim = 64 + n_head = max(1, width // head_dim) + actual_width = n_head * head_dim + + gpt_config = GPTConfig( + sequence_len=config.seq_len, + vocab_size=config.vocab_size, + n_layer=config.n_layer, + n_head=n_head, + n_kv_head=n_head, + n_embd=actual_width, + window_pattern="L", + mup_base_width=mup_base_width, + ) + + with torch.device('meta'): + model = GPT(gpt_config) + model.to_empty(device=device) + model.init_weights() + + # Apply init_scale: multiply all parameter inits by scalar + if init_scale != 1.0: + with torch.no_grad(): + for p in model.parameters(): + p.mul_(init_scale) + + # Apply output_mult: scale the output logit multiplier + # We store it as an attribute that 
forward() checks + model._transfer_check_output_mult = output_mult + + return model, gpt_config + + +def train_model(width: int, lr_mult: float, config: TransferCheckConfig, + device: torch.device, batches: List, + init_scale: float = 1.0, output_mult: float = 1.0): + """Train a model at given width and LR multiplier, return loss history.""" + torch.manual_seed(config.seed) + + mup_base_width = config.base_width if config.use_mup else 0 + model, gpt_config = create_model(width, config, device, mup_base_width=mup_base_width, + init_scale=init_scale, output_mult=output_mult) + actual_width = gpt_config.n_embd + + # Scale the learning rates by the multiplier, respecting sweep_mode + if config.sweep_mode == "muon-only": + muon_mult, adamw_mult = lr_mult, 1.0 + elif config.sweep_mode == "adamw-only": + muon_mult, adamw_mult = 1.0, lr_mult + else: # "all" + muon_mult, adamw_mult = lr_mult, lr_mult + + optimizer = model.setup_optimizer( + unembedding_lr=config.unembedding_lr * adamw_mult, + embedding_lr=config.embedding_lr * adamw_mult, + matrix_lr=config.matrix_lr * muon_mult, + weight_decay=0.0, + use_mup=config.use_mup, + base_width=config.base_width, + muon_lr_exponent=config.muon_lr_exponent, + ) + + model.train() + losses = [] + num_batches = len(batches) + + for step in range(config.num_steps): + x, y = batches[step % num_batches] + with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')): + loss = model(x, y) + + losses.append(loss.item()) + loss.backward() + optimizer.step() + optimizer.zero_grad(set_to_none=True) + + del model, optimizer + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return losses, actual_width + + +def run_transfer_check(config: TransferCheckConfig, device: torch.device, + batches: List) -> Dict: + """Run LR sweep across all widths.""" + results = { + 'widths': [], + 'lr_multipliers': config.lr_multipliers, + 'losses': {}, # losses[(width, lr_mult)] = [loss_step0, ...] 
+ 'final_losses': defaultdict(dict), # final_losses[width][lr_mult] = final_loss + } + + for width in config.widths: + actual_width = None + for lr_mult in config.lr_multipliers: + print(f" width={width}, lr_mult={lr_mult:.4f}...", end=" ", flush=True) + + losses, actual_width = train_model(width, lr_mult, config, device, batches) + results['losses'][(actual_width, lr_mult)] = losses + results['final_losses'][actual_width][lr_mult] = losses[-1] + print(f"final_loss={losses[-1]:.4f}") + + if actual_width not in results['widths']: + results['widths'].append(actual_width) + + return results + + +def run_hp_sweep(config: TransferCheckConfig, device: torch.device, + batches: List, hp_name: str, hp_values: List[float]) -> Dict: + """Run a sweep over a single HP (init_scale or output_mult) at fixed LR.""" + results = { + 'widths': [], + 'hp_values': hp_values, + 'hp_name': hp_name, + 'final_losses': defaultdict(dict), + } + + for width in config.widths: + actual_width = None + for hp_val in hp_values: + init_scale = hp_val if hp_name == 'init_scale' else 1.0 + output_mult = hp_val if hp_name == 'output_mult' else 1.0 + print(f" width={width}, {hp_name}={hp_val:.4f}...", end=" ", flush=True) + + losses, actual_width = train_model( + width, 1.0, config, device, batches, + init_scale=init_scale, output_mult=output_mult) + results['final_losses'][actual_width][hp_val] = losses[-1] + print(f"final_loss={losses[-1]:.4f}") + + if actual_width not in results['widths']: + results['widths'].append(actual_width) + + return results + + +def find_optimal_lr(final_losses: Dict[float, float]) -> float: + """Find the LR multiplier with the lowest final loss.""" + return min(final_losses, key=final_losses.get) + + +def plot_lr_sweep(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None): + """Plot LR sweep: final loss vs LR multiplier for each width.""" + widths = results['widths'] + lr_mults = results['lr_multipliers'] + final_losses = 
results['final_losses'] + + n_cols = 2 + fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 4)) + colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths))) + + # Left: final loss vs LR multiplier + ax = axes[0] + for i, w in enumerate(widths): + losses = [final_losses[w][m] for m in lr_mults] + ax.plot(lr_mults, losses, 'o-', color=colors[i], linewidth=2, label=f'width={w}') + opt_mult = find_optimal_lr(final_losses[w]) + opt_loss = final_losses[w][opt_mult] + ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5) + + ax.set_xscale('log', base=2) + ax.set_xlabel('LR Multiplier') + ax.set_ylabel('Final Loss') + ax.set_title(f'LR Sweep{" - " + title if title else ""}') + ax.legend() + ax.grid(True, alpha=0.3) + + # Right: optimal LR multiplier vs width + ax = axes[1] + opt_mults = [find_optimal_lr(final_losses[w]) for w in widths] + ax.plot(widths, opt_mults, 'o-', linewidth=2, markersize=8, color='tab:blue') + ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target (1.0)') + ax.set_xscale('log', base=2) + ax.set_yscale('log', base=2) + ax.set_xticks(widths) + ax.set_xticklabels(widths) + ax.set_xlabel('Width') + ax.set_ylabel('Optimal LR Multiplier') + ax.set_title(f'Optimal LR vs Width{" - " + title if title else ""}') + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + if save_path: + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"Saved plot to {save_path}") + plt.show() + + +def plot_comparison(results_sp: Dict, results_mup: Dict, config: TransferCheckConfig, save_path: Optional[str] = None): + """Plot SP vs muP comparison side by side.""" + widths = results_sp['widths'] + lr_mults = results_sp['lr_multipliers'] + + n_rows, n_cols = 2, 2 + fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows)) + colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths))) + + # Top row: LR sweep curves + for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]): + 
ax = axes[0, col] + for i, w in enumerate(widths): + losses = [results['final_losses'][w][m] for m in lr_mults] + ax.plot(lr_mults, losses, 'o-', color=colors[i], linewidth=2, label=f'w={w}') + opt_mult = find_optimal_lr(results['final_losses'][w]) + opt_loss = results['final_losses'][w][opt_mult] + ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5) + ax.set_xscale('log', base=2) + ax.set_xlabel('LR Multiplier') + ax.set_ylabel('Final Loss') + ax.set_title(f'{label}: Final Loss vs LR Multiplier') + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # Shared y-axis for top row + y_min = min( + min(results_sp['final_losses'][w][m] for m in lr_mults for w in widths), + min(results_mup['final_losses'][w][m] for m in lr_mults for w in widths), + ) * 0.98 + y_max = max( + max(results_sp['final_losses'][w][m] for m in lr_mults for w in widths), + max(results_mup['final_losses'][w][m] for m in lr_mults for w in widths), + ) * 1.02 + axes[0, 0].set_ylim(y_min, y_max) + axes[0, 1].set_ylim(y_min, y_max) + + # Bottom left: optimal LR vs width for both + ax = axes[1, 0] + for results, label, color in [(results_sp, 'SP', 'tab:red'), (results_mup, 'muP', 'tab:blue')]: + opt_mults = [find_optimal_lr(results['final_losses'][w]) for w in widths] + ax.plot(widths, opt_mults, 'o-', linewidth=2, markersize=8, color=color, label=label) + ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target') + ax.set_xscale('log', base=2) + ax.set_yscale('log', base=2) + ax.set_xticks(widths) + ax.set_xticklabels(widths) + ax.set_xlabel('Width') + ax.set_ylabel('Optimal LR Multiplier') + ax.set_title('Optimal LR Multiplier vs Width') + ax.legend() + ax.grid(True, alpha=0.3) + + # Bottom right: loss at optimal LR vs width + ax = axes[1, 1] + for results, label, color in [(results_sp, 'SP', 'tab:red'), (results_mup, 'muP', 'tab:blue')]: + opt_losses = [results['final_losses'][w][find_optimal_lr(results['final_losses'][w])] for w in widths] + ax.plot(widths, 
opt_losses, 'o-', linewidth=2, markersize=8, color=color, label=label) + ax.set_xscale('log', base=2) + ax.set_xticks(widths) + ax.set_xticklabels(widths) + ax.set_xlabel('Width') + ax.set_ylabel('Best Final Loss') + ax.set_title('Best Achievable Loss vs Width') + ax.legend() + ax.grid(True, alpha=0.3) + + fig.suptitle('muP Transfer Check: SP vs muP', fontsize=14) + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"Saved comparison plot to {save_path}") + plt.show() + + +def plot_hp_sweep(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None): + """Plot HP sweep: final loss vs HP value for each width.""" + widths = results['widths'] + hp_values = results['hp_values'] + hp_name = results['hp_name'] + final_losses = results['final_losses'] + + fig, axes = plt.subplots(1, 2, figsize=(10, 4)) + colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths))) + + ax = axes[0] + for i, w in enumerate(widths): + losses = [final_losses[w][v] for v in hp_values] + ax.plot(hp_values, losses, 'o-', color=colors[i], linewidth=2, label=f'w={w}') + opt_v = min(final_losses[w], key=final_losses[w].get) + ax.plot(opt_v, final_losses[w][opt_v], '*', color=colors[i], markersize=15, zorder=5) + ax.set_xscale('log', base=2) + ax.set_xlabel(hp_name) + ax.set_ylabel('Final Loss') + ax.set_title(f'{hp_name} Sweep{" - " + title if title else ""}') + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + ax = axes[1] + opt_vals = [min(final_losses[w], key=final_losses[w].get) for w in widths] + ax.plot(widths, opt_vals, 'o-', linewidth=2, markersize=8, color='tab:blue') + ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target (1.0)') + ax.set_xscale('log', base=2) + ax.set_yscale('log', base=2) + ax.set_xlabel('Width') + ax.set_ylabel(f'Optimal {hp_name}') + ax.set_title(f'Optimal {hp_name} vs Width') + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + if save_path: + 
plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"Saved plot to {save_path}") + plt.show() + + +def plot_loss_curves_at_optimal(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None): + """Plot full loss curves at the optimal LR for each width.""" + widths = results['widths'] + + fig, ax = plt.subplots(figsize=(5 * 2, 4)) + colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths))) + + for i, w in enumerate(widths): + opt_mult = find_optimal_lr(results['final_losses'][w]) + losses = results['losses'][(w, opt_mult)] + ax.plot(losses, color=colors[i], linewidth=2, label=f'w={w} (lr_mult={opt_mult})') + + ax.set_xlabel('Step') + ax.set_ylabel('Loss') + ax.set_title(f'Loss Curves at Optimal LR{" - " + title if title else ""}') + ax.legend() + ax.grid(True, alpha=0.3) + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"Saved plot to {save_path}") + plt.show() + + +def print_summary(results: Dict, label: str): + """Print a summary table of the LR sweep results.""" + widths = results['widths'] + lr_mults = results['lr_multipliers'] + final_losses = results['final_losses'] + + print(f"\n{'='*70}") + print(f" {label}: LR Sweep Results") + print(f"{'='*70}") + + # Header + header = f"{'Width':>8}" + for m in lr_mults: + header += f" | {m:>7.3f}" + header += f" | {'Best':>7} | {'Opt LR':>7}" + print(header) + print("-" * len(header)) + + opt_mults = [] + for w in widths: + row = f"{w:>8}" + for m in lr_mults: + loss = final_losses[w][m] + row += f" | {loss:>7.4f}" + opt_m = find_optimal_lr(final_losses[w]) + opt_mults.append(opt_m) + opt_loss = final_losses[w][opt_m] + row += f" | {opt_loss:>7.4f} | {opt_m:>7.3f}" + print(row) + + # Transfer quality metric: how much does the optimal LR shift? 
+ opt_mults_arr = np.array(opt_mults) + log_opt = np.log2(opt_mults_arr) + spread = log_opt.max() - log_opt.min() # spread in log2 space + print(f"\nOptimal LR spread (log2): {spread:.3f}") + print(f" (0 = perfect transfer, >1 = poor transfer)") + + +def main(): + parser = argparse.ArgumentParser(description='muP Transfer Check') + parser.add_argument('--widths', type=str, default='128,256,512,1024', + help='Comma-separated list of widths to test') + # Paper-style default: ~1000x range, 11 log-spaced points + parser.add_argument('--lr-mults', type=str, + default='0.03125,0.0625,0.125,0.25,0.5,1.0,2.0,4.0,8.0,16.0,32.0', + help='Comma-separated LR multipliers to sweep (default: 1024x range, 11 points)') + parser.add_argument('--num-random-trials', type=int, default=0, + help='If >0, use N log-uniform random LR multipliers from 10^Uniform(-1.5,1.5) ' + 'instead of the grid. Paper-style methodology (Section F).') + parser.add_argument('--steps', type=int, default=200, + help='Number of training steps per run') + parser.add_argument('--batch-size', type=int, default=8, + help='Batch size') + parser.add_argument('--seq-len', type=int, default=128, + help='Sequence length') + parser.add_argument('--n-layer', type=int, default=2, + help='Number of transformer layers') + parser.add_argument('--use-mup', action='store_true', + help='Use muP learning rate scaling') + parser.add_argument('--base-width', type=int, default=128, + help='Base width for muP scaling') + parser.add_argument('--compare', action='store_true', + help='Run both SP and muP and compare') + parser.add_argument('--save-dir', type=str, default=None, + help='Directory to save plots') + parser.add_argument('--seed', type=int, default=42, + help='Random seed') + parser.add_argument('--num-batches', type=int, default=1, + help='Number of data batches to cycle through (default 1 for backward compat, ' + 'recommend 8 for thorough checks)') + # Multi-HP sweeps + parser.add_argument('--sweep-init-scale', 
action='store_true', + help='Also sweep init scale multiplier (sampled from 10^Uniform(-1,1))') + parser.add_argument('--sweep-output-mult', action='store_true', + help='Also sweep output logit multiplier (sampled from 4^Uniform(-1,1))') + parser.add_argument('--muon-lr-exponent', type=float, default=0.0, + help='Muon LR exponent for muP: 1.0 = (base/width)^1 (standard), ' + '0.5 = (base/width)^0.5 (for Frobenius-normalizing optimizers like Muon)') + parser.add_argument('--sweep-mode', type=str, default='all', + choices=['all', 'muon-only', 'adamw-only'], + help='Which optimizer groups the LR multiplier applies to: ' + '"all" = all LRs (default), "muon-only" = only Muon/matrix LR, ' + '"adamw-only" = only AdamW/embedding/output LR') + + args = parser.parse_args() + + widths = [int(w) for w in args.widths.split(',')] + + # Generate LR multipliers + if args.num_random_trials > 0: + # Log-uniform random sampling: 10^Uniform(-1.5, 1.5) + rng = np.random.RandomState(args.seed) + lr_mults = sorted(10 ** rng.uniform(-1.5, 1.5, args.num_random_trials)) + lr_mults = [round(float(m), 6) for m in lr_mults] + print(f"Using {args.num_random_trials} random log-uniform LR multipliers: " + f"[{lr_mults[0]:.4f}, ..., {lr_mults[-1]:.4f}]") + else: + lr_mults = sorted(float(m) for m in args.lr_mults.split(',')) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # Load batches + batches, vocab_size = load_batches(args.num_batches, args.batch_size, args.seq_len, device) + + config = TransferCheckConfig( + widths=widths, + lr_multipliers=lr_mults, + num_steps=args.steps, + batch_size=args.batch_size, + seq_len=args.seq_len, + vocab_size=vocab_size, + n_layer=args.n_layer, + seed=args.seed, + use_mup=args.use_mup, + base_width=args.base_width, + sweep_init_scale=args.sweep_init_scale, + sweep_output_mult=args.sweep_output_mult, + num_batches=args.num_batches, + muon_lr_exponent=args.muon_lr_exponent, + 
sweep_mode=args.sweep_mode, + ) + + if args.compare: + # Run SP + print("\n" + "=" * 60) + print("Running Standard Parameterization (SP)") + print("=" * 60) + config.use_mup = False + results_sp = run_transfer_check(config, device, batches) + print_summary(results_sp, "SP") + + # Run muP + print("\n" + "=" * 60) + print("Running muP") + print("=" * 60) + config.use_mup = True + results_mup = run_transfer_check(config, device, batches) + print_summary(results_mup, "muP") + + # Compare + print("\n" + "=" * 60) + print("COMPARISON") + print("=" * 60) + sp_opts = [find_optimal_lr(results_sp['final_losses'][w]) for w in results_sp['widths']] + mup_opts = [find_optimal_lr(results_mup['final_losses'][w]) for w in results_mup['widths']] + sp_spread = np.log2(max(sp_opts)) - np.log2(min(sp_opts)) + mup_spread = np.log2(max(mup_opts)) - np.log2(min(mup_opts)) + print(f"SP optimal LR spread (log2): {sp_spread:.3f}") + print(f"muP optimal LR spread (log2): {mup_spread:.3f}") + if mup_spread < sp_spread: + print(f"muP shows {sp_spread/max(mup_spread, 0.001):.1f}x better LR transfer!") + else: + print("muP does NOT show better LR transfer (check implementation)") + + # Plot + save_path = None + if args.save_dir: + os.makedirs(args.save_dir, exist_ok=True) + save_path = os.path.join(args.save_dir, 'transfer_check_comparison.png') + plot_comparison(results_sp, results_mup, config, save_path) + + # Also plot loss curves at optimal LR + for results, label in [(results_sp, 'SP'), (results_mup, 'muP')]: + lc_save = None + if args.save_dir: + lc_save = os.path.join(args.save_dir, f'optimal_loss_curves_{label.lower()}.png') + plot_loss_curves_at_optimal(results, config, title=label, save_path=lc_save) + + # Multi-HP sweeps (only for muP, to demonstrate transfer) + if args.sweep_init_scale or args.sweep_output_mult: + config.use_mup = True + + if args.sweep_init_scale: + print("\n" + "=" * 60) + print("muP: Init Scale Sweep") + print("=" * 60) + # 10^Uniform(-1, 1) => range [0.1, 10] + 
init_scales = [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0] + init_results = run_hp_sweep(config, device, batches, 'init_scale', init_scales) + save_hp = None + if args.save_dir: + save_hp = os.path.join(args.save_dir, 'init_scale_sweep.png') + plot_hp_sweep(init_results, config, title="muP", save_path=save_hp) + + if args.sweep_output_mult: + print("\n" + "=" * 60) + print("muP: Output Multiplier Sweep") + print("=" * 60) + # 4^Uniform(-1, 1) => range [0.25, 4] + output_mults = [0.25, 0.5, 1.0, 2.0, 4.0] + output_results = run_hp_sweep(config, device, batches, 'output_mult', output_mults) + save_hp = None + if args.save_dir: + save_hp = os.path.join(args.save_dir, 'output_mult_sweep.png') + plot_hp_sweep(output_results, config, title="muP", save_path=save_hp) + + else: + param_type = "muP" if config.use_mup else "SP" + print(f"\n{'='*60}") + print(f"Running Transfer Check ({param_type})") + print(f"{'='*60}") + print(f"Widths: {widths}") + print(f"LR multipliers: {lr_mults}") + print(f"Steps: {config.num_steps}") + if config.sweep_mode != "all": + print(f"Sweep mode: {config.sweep_mode}") + + results = run_transfer_check(config, device, batches) + print_summary(results, param_type) + + # Plot LR sweep + save_path = None + if args.save_dir: + os.makedirs(args.save_dir, exist_ok=True) + save_path = os.path.join(args.save_dir, f'transfer_check_{param_type.lower()}.png') + plot_lr_sweep(results, config, title=param_type, save_path=save_path) + + # Plot loss curves at optimal LR + lc_save = None + if args.save_dir: + lc_save = os.path.join(args.save_dir, f'optimal_loss_curves_{param_type.lower()}.png') + plot_loss_curves_at_optimal(results, config, title=param_type, save_path=lc_save) + + +if __name__ == '__main__': + main()