mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 21:25:21 +00:00
adding muP considerations for muon
This commit is contained in:
parent
c7ba252142
commit
f07ca1b268
|
|
@ -37,6 +37,9 @@ class GPTConfig:
|
|||
# Characters: L=long (full context), S=short (half context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
# muP (Maximal Update Parametrization): set > 0 to enable. Value is the base/proxy width.
|
||||
# Enables: non-zero c_proj init scaled as 1/sqrt(m_d), output logit scaling by base_width/n_embd.
|
||||
mup_base_width: int = 0
|
||||
|
||||
|
||||
def norm(x):
|
||||
|
|
@ -203,7 +206,12 @@ class GPT(nn.Module):
|
|||
|
||||
# Embedding and unembedding
|
||||
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
|
||||
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
|
||||
lm_head_std = 0.001
|
||||
if self.config.mup_base_width > 0:
|
||||
# muP: scale lm_head init by 1/sqrt(m_d) so raw logit magnitude is O(1) across widths.
|
||||
# Without this, |logit| ~ 0.001 * sqrt(n_embd) grows as sqrt(width).
|
||||
lm_head_std *= (self.config.mup_base_width / self.config.n_embd) ** 0.5
|
||||
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std)
|
||||
|
||||
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
|
||||
n_embd = self.config.n_embd
|
||||
|
|
@ -212,9 +220,15 @@ class GPT(nn.Module):
|
|||
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
|
||||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
if self.config.mup_base_width > 0:
|
||||
# muP: output projections use same scale as hidden weights (std = sigma_base/sqrt(m_d))
|
||||
# Zero init causes attn/FFN outputs to vanish as width increases with muP LR scaling
|
||||
torch.nn.init.uniform_(block.attn.c_proj.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.mlp.c_proj.weight, -s, s)
|
||||
else:
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight) # SP: projections are zero
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
|
||||
# Per-layer scalars
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
|
|
@ -345,7 +359,7 @@ class GPT(nn.Module):
|
|||
'total': total,
|
||||
}
|
||||
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5, use_mup=False, base_width=256, muon_lr_exponent=0.0):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
|
||||
|
|
@ -358,16 +372,41 @@ class GPT(nn.Module):
|
|||
x0_params = [self.x0_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
# Compute LR scaling factors based on mode
|
||||
if use_mup:
|
||||
# muP for mixed Muon+AdamW optimizer:
|
||||
#
|
||||
# Key insight: the output logit scaling (logits *= base/width in forward()) already
|
||||
# propagates a base/width factor into the lm_head gradient. Adam normalizes gradient
|
||||
# magnitude via its second moment, but applying ANOTHER base/width to the LR compounds
|
||||
# the scaling, making lm_head updates O(base²/width²) instead of O(1). So we do NOT
|
||||
# apply width-dependent LR scaling to the output layer — the logit scaling alone suffices.
|
||||
#
|
||||
# For Muon (hidden weights): Polar Express orthogonalization normalizes ||update||_F ≈ 1
|
||||
# regardless of width, making the update already O(1). No width-dependent LR scaling
|
||||
# is needed (empirically confirmed: exponent 0 and 1 give identical transfer behavior).
|
||||
#
|
||||
# The muon_lr_exponent parameter is kept for experimentation but defaults to 0.
|
||||
width_ratio = base_width / model_dim # e.g., 128/1024 = 0.125
|
||||
emb_lr_scale = 1.0 # Embeddings: NO width scaling (standard muP)
|
||||
hidden_lr_scale = width_ratio ** muon_lr_exponent # Hidden (Muon): default 0 = no scaling
|
||||
output_lr_scale = 1.0 # Output (AdamW): NO LR scaling (logit scaling in forward suffices)
|
||||
self.config.mup_base_width = base_width # enables output logit scaling in forward()
|
||||
print0(f"muP scaling: base_width={base_width}, model_dim={model_dim}, width_ratio={width_ratio:.6f}, muon_lr_exp={muon_lr_exponent}")
|
||||
else:
|
||||
# Standard (SP): scale AdamW params by 1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
emb_lr_scale = dmodel_lr_scale
|
||||
hidden_lr_scale = 1.0 # Muon params: no scaling in SP mode
|
||||
output_lr_scale = dmodel_lr_scale
|
||||
print0(f"Standard scaling: dmodel_lr_scale={dmodel_lr_scale:.6f}")
|
||||
|
||||
# Build param_groups with all required fields explicit
|
||||
param_groups = [
|
||||
# AdamW groups (embeddings, lm_head, scalars)
|
||||
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * output_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=embedding_params, lr=embedding_lr * emb_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * emb_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
]
|
||||
|
|
@ -375,7 +414,7 @@ class GPT(nn.Module):
|
|||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
group_params = [p for p in matrix_params if p.shape == shape]
|
||||
param_groups.append(dict(
|
||||
kind='muon', params=group_params, lr=matrix_lr,
|
||||
kind='muon', params=group_params, lr=matrix_lr * hidden_lr_scale,
|
||||
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
|
||||
))
|
||||
|
||||
|
|
@ -409,6 +448,11 @@ class GPT(nn.Module):
|
|||
# Forward the lm_head (compute logits)
|
||||
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
|
||||
logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
|
||||
if self.config.mup_base_width > 0:
|
||||
# muP: scale output logits by base_width/n_embd (= 1/m_d).
|
||||
# Without this, logits grow with width because the lm_head dot product sums over n_embd terms.
|
||||
# 1/sqrt(m_d) only corrects at init; 1/m_d is required for all training steps (see Eleuther blog Fig 8-9).
|
||||
logits = logits * (self.config.mup_base_width / self.config.n_embd)
|
||||
logits = logits[..., :self.config.vocab_size] # slice to remove padding
|
||||
logits = logits.float() # switch to fp32 for logit softcap and loss computation
|
||||
logits = softcap * torch.tanh(logits / softcap) # squash the logits
|
||||
|
|
|
|||
|
|
@ -68,6 +68,8 @@ parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 f
|
|||
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--use-mup", action="store_true", help="use muP (Maximal Update Parameterization) LR scaling")
|
||||
parser.add_argument("--base-width", type=int, default=256, help="base width for muP LR scaling (LRs tuned at this width)")
|
||||
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
|
|
@ -309,6 +311,9 @@ optimizer = model.setup_optimizer(
|
|||
# Muon hyperparameters
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
weight_decay=weight_decay_scaled,
|
||||
# muP scaling
|
||||
use_mup=args.use_mup,
|
||||
base_width=args.base_width,
|
||||
)
|
||||
|
||||
if resuming:
|
||||
|
|
|
|||
710
scripts/mup_coord_check.py
Normal file
710
scripts/mup_coord_check.py
Normal file
|
|
@ -0,0 +1,710 @@
|
|||
"""
|
||||
muP Coordinate Check for nanochat
|
||||
|
||||
This script validates muP implementation by checking that activation magnitudes
|
||||
are independent of model width. Based on EleutherAI's nanoGPT-mup and Microsoft's
|
||||
mup library.
|
||||
|
||||
Reference: https://blog.eleuther.ai/mutransfer/
|
||||
Reference: Yang et al., "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot
|
||||
Hyperparameter Transfer" (arXiv:2203.03466), Sections B.1 and F.
|
||||
|
||||
Usage:
|
||||
python -m scripts.mup_coord_check --widths 128,256,512,1024 --steps 10
|
||||
python -m scripts.mup_coord_check --use-mup --widths 128,256,512,1024
|
||||
python -m scripts.mup_coord_check --compare --detailed
|
||||
python -m scripts.mup_coord_check --compare --muon-lr-exponent 0.5
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import torch._dynamo
|
||||
torch._dynamo.config.disable = True
|
||||
import torch.nn.functional as F
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
import os
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
|
||||
|
||||
def load_batch(batch_size: int, seq_len: int, device: torch.device):
    """Fetch one (inputs, targets, vocab_size) batch for the coordinate check.

    Prefers real nanochat training data; when the tokenizer/dataset stack is
    unavailable, deterministically synthesizes random token ids (seed 42) so
    runs remain reproducible.
    """
    try:
        from nanochat.tokenizer import get_tokenizer
        from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
        tok = get_tokenizer()
        n_vocab = tok.get_vocab_size()
        stream = tokenizing_distributed_data_loader_bos_bestfit(
            tok, batch_size, seq_len, split="train", device=device,
        )
        inputs, targets = next(stream)
        print(f"Loaded real training data (vocab_size={n_vocab})")
    except Exception as e:
        # Best-effort fallback: random tokens with next-token targets.
        print(f"Could not load training data ({e}), using random tokens")
        n_vocab = 32768
        gen = torch.Generator(device=device)
        gen.manual_seed(42)
        inputs = torch.randint(0, n_vocab, (batch_size, seq_len), device=device, generator=gen)
        targets = torch.roll(inputs, -1, dims=1)
        targets[:, -1] = -1  # no target for the final position
    return inputs, targets, n_vocab
|
||||
|
||||
@dataclass
class CoordCheckConfig:
    """Settings for one muP coordinate-check sweep (widths, steps, LRs)."""
    # Model widths (n_embd) to sweep; create_model rounds each to a multiple of head_dim=64.
    widths: List[int]
    # Training steps per width; activations are recorded at every step.
    num_steps: int = 10
    batch_size: int = 4
    seq_len: int = 128
    vocab_size: int = 32768
    n_layer: int = 2
    # Seed reused for every width so init streams are comparable.
    seed: int = 42
    # If True, LRs are scaled per muP; otherwise standard parameterization.
    use_mup: bool = False
    base_width: int = 128
    # Learning rates (tuned at base_width)
    matrix_lr: float = 0.02
    embedding_lr: float = 0.2
    unembedding_lr: float = 0.004
    # Detailed diagnostics: gradient norms, update norms, attention logit stats.
    detailed: bool = False
    # Muon LR exponent: 1.0 = base/width (standard muP), 0.5 = sqrt(base/width)
    # Paper Section C.1: Frobenius-normalizing optimizers may need exponent 0.5
    muon_lr_exponent: float = 0.0
|
||||
|
||||
|
||||
class ActivationRecorder:
    """Records activation statistics during forward pass using hooks.

    Per-layer mean-|activation| values accumulate in ``self.stats`` across
    forward calls; call :meth:`get_step_stats` once per training step to
    collect the means and reset the accumulator.
    """

    def __init__(self, detailed: bool = False):
        # layer name -> list of per-forward-call mean |activation| values
        self.stats: Dict[str, List[float]] = defaultdict(list)
        # hook handles from register_forward_hook, kept for later removal
        self.hooks = []
        # when True, also record pre-softmax attention logit magnitudes
        self.detailed = detailed

    def _get_stat(self, tensor: torch.Tensor) -> float:
        """Compute mean absolute value (l1 norm per element)."""
        if tensor is None:
            return 0.0
        # .float() upcasts any dtype (incl. bool/bf16), so one path suffices.
        return tensor.float().abs().mean().item()

    def _make_hook(self, name: str):
        """Create a forward hook that records output statistics under `name`."""
        def hook(module, input, output):
            # Some modules return tuples; the activation is the first element.
            if isinstance(output, tuple):
                output = output[0]
            if output is not None and isinstance(output, torch.Tensor):
                self.stats[name].append(self._get_stat(output))
        return hook

    def _make_attn_logit_hook(self, name: str, n_head: int, n_kv_head: int, head_dim: int):
        """Create a hook on c_k that computes pre-softmax attention logit magnitudes.

        We hook onto c_k's forward, then use the most recent c_q output to compute
        q @ k^T / sqrt(d) for a single batch element to measure attention logit scale.
        """
        # We'll store q output and compute logits when k is available
        self._last_q = None

        def q_hook(module, input, output):
            self._last_q = output.detach()

        def k_hook(module, input, output):
            if self._last_q is None:
                return
            q = self._last_q
            k = output.detach()
            B, T, _ = q.shape
            # Only the first batch element is measured — cheap but representative.
            q = q[0:1].view(1, T, n_head, head_dim)
            k = k[0:1].view(1, T, n_kv_head, head_dim)
            # Apply QK norm (same as model)
            q = F.rms_norm(q, (q.size(-1),))
            k = F.rms_norm(k, (k.size(-1),))
            # Expand k for GQA
            if n_head != n_kv_head:
                k = k.repeat_interleave(n_head // n_kv_head, dim=2)
            # Compute logits: q @ k^T / sqrt(d) — just for first few positions
            T_sub = min(T, 32)
            q_sub = q[:, :T_sub].transpose(1, 2)  # (1, H, T_sub, D)
            k_sub = k[:, :T_sub].transpose(1, 2)  # (1, H, T_sub, D)
            logits = torch.matmul(q_sub, k_sub.transpose(-2, -1)) / (head_dim ** 0.5)
            self.stats[name].append(logits.float().abs().mean().item())
            self._last_q = None  # consume q so stale values are never reused

        return q_hook, k_hook

    def register_hooks(self, model: "GPT") -> None:
        """Register forward hooks on key layers: embedding, each block's
        attention/FFN output projections, and the LM head."""
        # Embedding
        h = model.transformer.wte.register_forward_hook(self._make_hook('word embedding'))
        self.hooks.append(h)

        # Each transformer block
        for i, block in enumerate(model.transformer.h):
            # Attention output
            h = block.attn.c_proj.register_forward_hook(self._make_hook(f'attn output.{i}'))
            self.hooks.append(h)
            # MLP output
            h = block.mlp.c_proj.register_forward_hook(self._make_hook(f'FFN output.{i}'))
            self.hooks.append(h)

            # Detailed: attention logit magnitudes
            if self.detailed:
                n_head = block.attn.n_head
                n_kv_head = block.attn.n_kv_head
                head_dim = block.attn.head_dim
                q_hook, k_hook = self._make_attn_logit_hook(
                    f'attn logits.{i}', n_head, n_kv_head, head_dim)
                h1 = block.attn.c_q.register_forward_hook(q_hook)
                h2 = block.attn.c_k.register_forward_hook(k_hook)
                self.hooks.extend([h1, h2])

        # LM head
        h = model.lm_head.register_forward_hook(self._make_hook('output logits'))
        self.hooks.append(h)

    def remove_hooks(self) -> None:
        """Remove all registered hooks."""
        for h in self.hooks:
            h.remove()
        self.hooks = []

    def get_step_stats(self) -> Dict[str, float]:
        """Get mean stats for the current step and reset the accumulator."""
        step_stats = {}
        for name, values in self.stats.items():
            if values:
                step_stats[name] = np.mean(values)
        self.stats = defaultdict(list)
        return step_stats
|
||||
|
||||
|
||||
def create_model(width: int, config: CoordCheckConfig, device: torch.device, mup_base_width: int = 0) -> Tuple[GPT, GPTConfig]:
    """Create a model with the specified width.

    The requested width is rounded to a multiple of head_dim=64 (at least one
    head); the GPTConfig actually used is returned alongside the model so the
    caller can read the effective n_embd.
    """
    head_dim = 64
    n_head = max(1, width // head_dim)
    actual_width = n_head * head_dim  # may differ from `width` if not a multiple of 64

    gpt_config = GPTConfig(
        sequence_len=config.seq_len,
        vocab_size=config.vocab_size,
        n_layer=config.n_layer,
        n_head=n_head,
        n_kv_head=n_head,  # no GQA: query and KV head counts kept equal
        n_embd=actual_width,
        window_pattern="L",  # full-context attention in every layer
        mup_base_width=mup_base_width,  # > 0 enables muP init/logit scaling in GPT
    )

    # Construct on the meta device (no storage), then materialize storage on
    # the target device and run the model's own weight initialization.
    with torch.device('meta'):
        model = GPT(gpt_config)
    model.to_empty(device=device)
    model.init_weights()

    return model, gpt_config
|
||||
|
||||
|
||||
def setup_optimizer_mup(model: GPT, config: CoordCheckConfig, width: int):
    """Build the model's optimizer with muP LR scaling enabled (native use_mup flag)."""
    opt_kwargs = dict(
        unembedding_lr=config.unembedding_lr,
        embedding_lr=config.embedding_lr,
        matrix_lr=config.matrix_lr,
        weight_decay=0.0,
        use_mup=True,
        base_width=config.base_width,
        muon_lr_exponent=config.muon_lr_exponent,
    )
    return model.setup_optimizer(**opt_kwargs)
|
||||
|
||||
|
||||
def setup_optimizer_sp(model: GPT, config: CoordCheckConfig, width: int):
    """Build the model's optimizer under standard parameterization (muP disabled)."""
    opt_kwargs = dict(
        unembedding_lr=config.unembedding_lr,
        embedding_lr=config.embedding_lr,
        matrix_lr=config.matrix_lr,
        weight_decay=0.0,
        use_mup=False,
    )
    return model.setup_optimizer(**opt_kwargs)
|
||||
|
||||
|
||||
def record_detailed_stats(model: GPT, results: Dict, width: int, step: int):
|
||||
"""Record weight update norms and gradient norms per parameter group."""
|
||||
for name, p in model.named_parameters():
|
||||
if p.grad is None:
|
||||
continue
|
||||
# Simplify name for display
|
||||
short_name = name.replace('transformer.', '').replace('.weight', '')
|
||||
# Gradient norm
|
||||
grad_norm = p.grad.float().norm().item()
|
||||
results['detailed_stats'][width][f'grad norm: {short_name}'].append(grad_norm)
|
||||
|
||||
|
||||
def record_weight_update_norms(model: GPT, params_before: Dict[str, torch.Tensor],
|
||||
results: Dict, width: int):
|
||||
"""Record ||delta_W|| for each parameter after optimizer step."""
|
||||
for name, p in model.named_parameters():
|
||||
if name not in params_before:
|
||||
continue
|
||||
short_name = name.replace('transformer.', '').replace('.weight', '')
|
||||
delta = (p.data.float() - params_before[name]).norm().item()
|
||||
results['detailed_stats'][width][f'update norm: {short_name}'].append(delta)
|
||||
|
||||
|
||||
def run_coord_check(config: CoordCheckConfig, device: torch.device,
                    x: torch.Tensor, y: torch.Tensor) -> Dict:
    """Run coordinate check across all widths.

    For each width, trains a fresh model on the fixed batch (x, y) for
    config.num_steps steps, recording per-layer activation stats, losses,
    and (with config.detailed) gradient/update-norm diagnostics.
    Returns a dict with keys 'widths', 'steps', 'stats', 'losses',
    'detailed_stats'.
    """
    results = {
        'widths': [],
        'steps': list(range(config.num_steps)),
        'stats': defaultdict(lambda: defaultdict(list)),
        'losses': defaultdict(list),
        'detailed_stats': defaultdict(lambda: defaultdict(list)),
    }

    for width in config.widths:
        print(f"\nTraining width={width}...")

        # Same seed per width so initialization streams are comparable.
        torch.manual_seed(config.seed)

        mup_base_width = config.base_width if config.use_mup else 0
        model, gpt_config = create_model(width, config, device, mup_base_width=mup_base_width)
        actual_width = gpt_config.n_embd  # width after rounding to head_dim multiple
        results['widths'].append(actual_width)

        if config.use_mup:
            optimizer = setup_optimizer_mup(model, config, actual_width)
        else:
            optimizer = setup_optimizer_sp(model, config, actual_width)

        recorder = ActivationRecorder(detailed=config.detailed)
        recorder.register_hooks(model)

        model.train()

        for step in range(config.num_steps):
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')):
                loss = model(x, y)

            results['losses'][actual_width].append(loss.item())

            # Activation stats were accumulated by the forward hooks above.
            step_stats = recorder.get_step_stats()
            for layer, value in step_stats.items():
                results['stats'][actual_width][layer].append(value)

            if step == 0:
                print(f" Step {step}: loss={loss.item():.4f}, layers={list(step_stats.keys())}")

            # Record gradient norms before step (detailed mode)
            loss.backward()

            params_before = None  # only populated in detailed mode
            if config.detailed:
                record_detailed_stats(model, results, actual_width, step)
                # Snapshot params before optimizer step to compute update norms
                params_before = {name: p.data.float().clone()
                                 for name, p in model.named_parameters()
                                 if p.grad is not None}

            optimizer.step()

            if config.detailed:
                record_weight_update_norms(model, params_before, results, actual_width)

            optimizer.zero_grad(set_to_none=True)

        print(f" Final loss: {loss.item():.4f}")

        recorder.remove_hooks()
        del model, optimizer
        # Release cached GPU memory between width runs (no-op on CPU-only hosts).
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return results
|
||||
|
||||
|
||||
def plot_coord_check(results: Dict, config: CoordCheckConfig, save_path: Optional[str] = None):
    """Plot coordinate check: one subplot per layer, x=width (log2), y=mean |activation|, lines=steps.

    A flat curve across widths (per step) indicates width-independent
    activation scale — the muP success criterion.
    """
    widths = results['widths']
    steps = results['steps']
    stats = results['stats']

    layer_names = list(stats[widths[0]].keys())
    n_layers = len(layer_names)
    n_cols = 4
    n_rows = (n_layers + n_cols - 1) // n_cols  # ceil(n_layers / n_cols)

    param_type = "muP" if config.use_mup else "SP"
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    # Flatten so a single index works regardless of the subplot grid shape.
    axes = np.array(axes).flatten()

    # One color per training step.
    step_colors = plt.cm.plasma(np.linspace(0, 1, len(steps)))

    for i, layer in enumerate(layer_names):
        ax = axes[i]
        for s, step in enumerate(steps):
            values = [stats[w][layer][s] for w in widths]
            ax.plot(widths, values, 'o-', color=step_colors[s], linewidth=1.5,
                    label=f'step {step}' if i == 0 else None)  # legend only once
        ax.set_xscale('log', base=2)
        ax.set_xticks(widths)
        ax.set_xticklabels(widths, fontsize=7)
        ax.set_title(layer, fontsize=9)
        ax.set_xlabel('Width')
        ax.set_ylabel('Mean |activation|')
        ax.grid(True, alpha=0.3)

    axes[0].legend(fontsize=7, loc='best')

    # Hide unused trailing subplots in the grid.
    for i in range(n_layers, len(axes)):
        axes[i].set_visible(False)

    fig.suptitle(f'Coordinate Check ({param_type}): Activation Magnitude vs Width', fontsize=14)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved plot to {save_path}")

    plt.show()
|
||||
|
||||
|
||||
def plot_loss_curves(results: Dict, config: CoordCheckConfig, title: str = "", save_path: Optional[str] = None):
    """Plot loss curves across widths to verify HP transfer.

    Under correct muP, losses at a shared LR should track closely across
    widths; the annotated final-loss spread summarizes the divergence.
    """
    widths = results['widths']
    steps = results['steps']
    losses = results['losses']

    fig, ax = plt.subplots(figsize=(5 * 2, 4))
    # One color per width.
    colors = plt.cm.viridis(np.linspace(0, 1, len(widths)))

    for i, w in enumerate(widths):
        ax.plot(steps, losses[w], label=f'width={w}', color=colors[i], linewidth=2)

    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title(f'Loss Curves Across Widths{" - " + title if title else ""}')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Add annotation for final loss spread
    final_losses = [losses[w][-1] for w in widths]
    spread = max(final_losses) - min(final_losses)
    ax.annotate(f'Final loss spread: {spread:.4f}', xy=(0.7, 0.95), xycoords='axes fraction', fontsize=10)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved loss curves to {save_path}")

    plt.show()
|
||||
|
||||
|
||||
def plot_comparison(results_sp: Dict, results_mup: Dict, config: CoordCheckConfig, save_path: Optional[str] = None):
    """Plot SP vs muP: one subplot per layer (left=SP, right=muP), x=width (log2), y=mean |activation|, lines=steps.

    Assumes both result dicts were produced with the same widths and step
    count (the SP run's 'widths'/'steps' are used for both columns).
    """
    widths = results_sp['widths']
    steps = results_sp['steps']

    layer_names = list(results_sp['stats'][widths[0]].keys())
    n_layers = len(layer_names)

    # n_layers activation rows + 1 loss row, 2 cols (SP | muP)
    n_rows, n_cols = n_layers + 1, 2
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 3 * n_rows))

    step_colors = plt.cm.plasma(np.linspace(0, 1, len(steps)))
    width_colors = plt.cm.viridis(np.linspace(0, 1, len(widths)))

    for row, layer in enumerate(layer_names):
        # Shared y-axis range across SP and muP for this layer
        all_vals = [results_sp['stats'][w][layer][s] for w in widths for s in range(len(steps))] + \
                   [results_mup['stats'][w][layer][s] for w in widths for s in range(len(steps))]
        y_min, y_max = min(all_vals) * 0.9, max(all_vals) * 1.1

        for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
            ax = axes[row, col]
            for s, step in enumerate(steps):
                values = [results['stats'][w][layer][s] for w in widths]
                ax.plot(widths, values, 'o-', color=step_colors[s], linewidth=1.5,
                        label=f'step {step}' if (row == 0 and col == 0) else None)  # legend once
            ax.set_xscale('log', base=2)
            ax.set_xticks(widths)
            ax.set_xticklabels(widths, fontsize=7)
            ax.set_ylim(y_min, y_max)
            ax.set_title(f'{label}: {layer}', fontsize=9)
            ax.set_xlabel('Width')
            ax.set_ylabel('Mean |activation|')
            ax.grid(True, alpha=0.3)

    axes[0, 0].legend(fontsize=7, loc='best')

    # Loss curves row (shared y-range across both columns)
    all_losses = [v for r in (results_sp, results_mup) for w in widths for v in r['losses'][w]]
    loss_min, loss_max = min(all_losses) * 0.95, max(all_losses) * 1.05

    for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
        ax = axes[n_layers, col]
        for j, w in enumerate(widths):
            ax.plot(steps, results['losses'][w], label=f'w={w}', color=width_colors[j], linewidth=2)
        ax.set_ylim(loss_min, loss_max)
        ax.set_xlabel('Step')
        ax.set_ylabel('Loss')
        ax.set_title(f'{label}: Loss Curves')
        ax.legend(fontsize=7)
        ax.grid(True, alpha=0.3)
        # Final-loss spread summarizes HP-transfer quality for this column.
        final_losses = [results['losses'][w][-1] for w in widths]
        spread = max(final_losses) - min(final_losses)
        ax.annotate(f'Spread: {spread:.4f}', xy=(0.65, 0.95), xycoords='axes fraction', fontsize=9)

    fig.suptitle('Coordinate Check: SP vs muP', fontsize=14)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved comparison plot to {save_path}")

    plt.show()
|
||||
|
||||
|
||||
def plot_detailed(results: Dict, config: CoordCheckConfig, save_path: Optional[str] = None):
    """Plot detailed diagnostics: gradient norms, weight update norms, attention logits.

    Emits one figure per metric category; requires the run to have been made
    with detailed recording enabled (otherwise prints a notice and returns).
    """
    widths = results['widths']
    detailed = results['detailed_stats']
    if not detailed or not detailed[widths[0]]:
        print("No detailed stats recorded. Use --detailed flag.")
        return

    # Collect all detailed metric names
    metric_names = sorted(detailed[widths[0]].keys())

    # Group by category (prefixes written by the record_* helpers and the
    # attention-logit hooks)
    categories = defaultdict(list)
    for name in metric_names:
        if name.startswith('grad norm:'):
            categories['Gradient Norms'].append(name)
        elif name.startswith('update norm:'):
            categories['Weight Update Norms'].append(name)
        elif name.startswith('attn logits'):
            categories['Attention Logit Magnitudes'].append(name)

    for cat_name, names in categories.items():
        n = len(names)
        n_cols = min(4, n)
        n_rows = (n + n_cols - 1) // n_cols  # ceil(n / n_cols)
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
        # Normalize to a flat array so a single index works for any grid shape.
        if n == 1:
            axes = np.array([axes])
        axes = np.array(axes).flatten()

        steps = results['steps']
        width_colors = plt.cm.viridis(np.linspace(0, 1, len(widths)))

        for i, name in enumerate(names):
            ax = axes[i]
            for j, w in enumerate(widths):
                values = detailed[w].get(name, [])
                if values:
                    ax.plot(range(len(values)), values, color=width_colors[j],
                            linewidth=1.5, label=f'w={w}' if i == 0 else None)
            # Title is the parameter name without the category prefix.
            ax.set_title(name.split(': ', 1)[-1] if ': ' in name else name, fontsize=8)
            ax.set_xlabel('Step')
            ax.set_ylabel('Norm')
            ax.grid(True, alpha=0.3)
            ax.set_yscale('log')

        # Hide unused trailing subplots.
        for i in range(n, len(axes)):
            axes[i].set_visible(False)

        axes[0].legend(fontsize=7, loc='best')
        param_type = "muP" if config.use_mup else "SP"
        fig.suptitle(f'{cat_name} ({param_type})', fontsize=14)
        plt.tight_layout()

        if save_path:
            # One file per category, suffixed with a slug of the category name.
            cat_slug = cat_name.lower().replace(' ', '_')
            path = save_path.replace('.png', f'_{cat_slug}.png')
            plt.savefig(path, dpi=150, bbox_inches='tight')
            print(f"Saved {cat_name} plot to {path}")

        plt.show()
|
||||
|
||||
|
||||
def compute_width_dependence(results: Dict) -> Dict[str, float]:
    """Estimate, per layer, the log2-log2 slope of final-step activation
    magnitude versus width (0 ≈ width-independent)."""
    width_arr = np.array(results['widths'])
    log_w = np.log2(width_arr)
    last_step = len(results['steps']) - 1

    slopes = {}
    for layer_name in results['stats'][width_arr[0]].keys():
        finals = np.array([results['stats'][w][layer_name][last_step] for w in width_arr])
        # Epsilon guards log2 against exactly-zero activations.
        coeffs = np.polyfit(log_w, np.log2(finals + 1e-10), 1)
        slopes[layer_name] = coeffs[0]

    return slopes
|
||||
|
||||
|
||||
def main():
    """CLI entry point for the muP coordinate check.

    Parses arguments, loads one batch of training data (reused every step),
    then either runs a single SP-or-muP coordinate check or (--compare) runs
    both and prints/plots side-by-side width-dependence diagnostics.
    """
    parser = argparse.ArgumentParser(description='muP Coordinate Check')
    parser.add_argument('--widths', type=str, default='128,256,512,1024',
                        help='Comma-separated list of widths to test')
    parser.add_argument('--steps', type=int, default=10,
                        help='Number of training steps')
    parser.add_argument('--batch-size', type=int, default=4,
                        help='Batch size')
    parser.add_argument('--seq-len', type=int, default=128,
                        help='Sequence length')
    parser.add_argument('--n-layer', type=int, default=2,
                        help='Number of transformer layers')
    parser.add_argument('--use-mup', action='store_true',
                        help='Use muP learning rate scaling')
    parser.add_argument('--base-width', type=int, default=128,
                        help='Base width for muP scaling')
    parser.add_argument('--compare', action='store_true',
                        help='Run both SP and muP and compare')
    parser.add_argument('--save-dir', type=str, default=None,
                        help='Directory to save plots')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--detailed', action='store_true',
                        help='Record detailed diagnostics: gradient norms, weight update norms, '
                        'attention logit magnitudes')
    parser.add_argument('--muon-lr-exponent', type=float, default=0.0,
                        help='Muon LR exponent for muP: 1.0 = (base/width)^1 (standard muP), '
                        '0.5 = (base/width)^0.5 (for Frobenius-normalizing optimizers, '
                        'see Yang et al. Section C.1)')

    args = parser.parse_args()

    # Parse widths
    widths = [int(w) for w in args.widths.split(',')]

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load a single batch of real training data (reused every step)
    x, y, vocab_size = load_batch(args.batch_size, args.seq_len, device)

    # Create config
    config = CoordCheckConfig(
        widths=widths,
        num_steps=args.steps,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        vocab_size=vocab_size,
        n_layer=args.n_layer,
        seed=args.seed,
        use_mup=args.use_mup,
        base_width=args.base_width,
        detailed=args.detailed,
        muon_lr_exponent=args.muon_lr_exponent,
    )

    if args.compare:
        # Run both SP and muP
        print("\n" + "="*60)
        print("Running Standard Parameterization (SP)")
        print("="*60)
        config.use_mup = False
        results_sp = run_coord_check(config, device, x, y)

        print("\n" + "="*60)
        print("Running muP")
        if config.muon_lr_exponent != 1.0:
            print(f" (Muon LR exponent: {config.muon_lr_exponent})")
        print("="*60)
        config.use_mup = True
        results_mup = run_coord_check(config, device, x, y)

        # Compute slopes: how the final-step activation stats scale with width
        print("\n" + "="*60)
        print("Width Dependence (slope on log-log plot)")
        print("Expected: ~0 for width-independent, positive = grows with width")
        print("="*60)

        slopes_sp = compute_width_dependence(results_sp)
        slopes_mup = compute_width_dependence(results_mup)

        print(f"\n{'Layer':<20} {'SP Slope':>12} {'muP Slope':>12}")
        print("-"*46)
        for layer in slopes_sp:
            print(f"{layer:<20} {slopes_sp[layer]:>12.4f} {slopes_mup[layer]:>12.4f}")

        # Plot comparison
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, 'coord_check_comparison.png')
        plot_comparison(results_sp, results_mup, config, save_path)

        # Plot detailed diagnostics if requested.
        # NOTE: config.use_mup is mutated per iteration so plot titles match the run.
        if config.detailed:
            for results, label in [(results_sp, 'SP'), (results_mup, 'muP')]:
                config.use_mup = (label == 'muP')
                detail_save = None
                if args.save_dir:
                    detail_save = os.path.join(args.save_dir, f'detailed_{label.lower()}.png')
                plot_detailed(results, config, detail_save)

    else:
        # Run single mode
        param_type = "muP" if config.use_mup else "SP"
        print(f"\n{'='*60}")
        print(f"Running Coordinate Check ({param_type})")
        print(f"{'='*60}")
        print(f"Widths: {widths}")
        print(f"Steps: {config.num_steps}")
        print(f"Base width: {config.base_width}")
        if config.use_mup and config.muon_lr_exponent != 1.0:
            print(f"Muon LR exponent: {config.muon_lr_exponent}")

        results = run_coord_check(config, device, x, y)

        # Compute slopes
        slopes = compute_width_dependence(results)
        print("\n" + "="*60)
        print("Width Dependence (slope on log-log plot)")
        print("Expected for muP: ~0 (width-independent)")
        print("="*60)
        for layer, slope in slopes.items():
            # |slope| < 0.1 is treated as effectively width-independent
            status = "OK" if abs(slope) < 0.1 else "WARN"
            print(f" {layer}: {slope:+.4f} [{status}]")

        # Loss curve analysis
        final_losses = [results['losses'][w][-1] for w in results['widths']]
        loss_spread = max(final_losses) - min(final_losses)
        print(f"\nFinal loss spread across widths: {loss_spread:.4f}")
        print(f"Expected for muP: low spread (similar losses across widths)")

        # Plot activations
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, f'coord_check_{param_type.lower()}.png')
        plot_coord_check(results, config, save_path)

        # Plot loss curves
        loss_save_path = None
        if args.save_dir:
            loss_save_path = os.path.join(args.save_dir, f'loss_curves_{param_type.lower()}.png')
        plot_loss_curves(results, config, title=param_type, save_path=loss_save_path)

        # Plot detailed diagnostics if requested
        if config.detailed:
            detail_save = None
            if args.save_dir:
                detail_save = os.path.join(args.save_dir, f'detailed_{param_type.lower()}.png')
            plot_detailed(results, config, detail_save)


if __name__ == '__main__':
    main()
|
||||
672
scripts/mup_transfer_check.py
Normal file
672
scripts/mup_transfer_check.py
Normal file
|
|
@ -0,0 +1,672 @@
|
|||
"""
|
||||
muP Transfer Check for nanochat
|
||||
|
||||
Validates that optimal learning rates transfer across model widths under muP.
|
||||
For each width, sweeps over LR multipliers and records final loss. Under correct
|
||||
muP, the optimal LR multiplier should be ~1.0 at all widths (i.e., the same LR
|
||||
works everywhere). Under SP, the optimal LR typically shifts with width.
|
||||
|
||||
Reference: https://blog.eleuther.ai/mutransfer/
|
||||
Reference: Yang et al., "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot
|
||||
Hyperparameter Transfer" (arXiv:2203.03466), Section F.
|
||||
|
||||
Usage:
|
||||
# Quick check (~2 min on GPU)
|
||||
python -m scripts.mup_transfer_check
|
||||
|
||||
# Compare SP vs muP side-by-side
|
||||
python -m scripts.mup_transfer_check --compare
|
||||
|
||||
# Wide LR sweep (paper-style, ~1000x range)
|
||||
python -m scripts.mup_transfer_check --compare --widths 128,256,512,1024 --steps 200
|
||||
|
||||
# Random log-uniform LR trials (paper-style methodology)
|
||||
python -m scripts.mup_transfer_check --compare --num-random-trials 20
|
||||
|
||||
# Multi-HP sweep (init scale + output multiplier)
|
||||
python -m scripts.mup_transfer_check --compare --sweep-init-scale --sweep-output-mult
|
||||
|
||||
# Save plots
|
||||
python -m scripts.mup_transfer_check --compare --save-dir plots/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import torch._dynamo
|
||||
torch._dynamo.config.disable = True
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
import os
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
|
||||
|
||||
@dataclass
class TransferCheckConfig:
    """Configuration for a muP learning-rate transfer sweep.

    One training run is executed per (width, LR multiplier) pair; the base
    learning rates below are multiplied by each swept multiplier.
    """
    widths: List[int]  # model widths (target n_embd values) to sweep over
    lr_multipliers: List[float]  # multipliers applied to the base LRs
    num_steps: int = 200  # training steps per run
    batch_size: int = 8
    seq_len: int = 128
    vocab_size: int = 32768  # default; presumably replaced by the tokenizer's vocab when real data loads — confirm in caller
    n_layer: int = 2
    seed: int = 42  # torch manual seed applied before each run
    use_mup: bool = False  # enable muP parameterization / LR scaling
    base_width: int = 128  # proxy width the base LRs were tuned at
    # Base learning rates (tuned at base_width)
    matrix_lr: float = 0.02
    embedding_lr: float = 0.2
    unembedding_lr: float = 0.004
    # Multi-HP sweeps
    sweep_init_scale: bool = False
    sweep_output_mult: bool = False
    # Data diversity
    num_batches: int = 1
    # Muon LR exponent for muP (1.0=standard, 0.5=Frobenius-norm optimizers)
    muon_lr_exponent: float = 0.0
    # Sweep mode: which optimizer groups the LR multiplier applies to
    # "all" = multiply all LRs (default), "muon-only" = only matrix_lr,
    # "adamw-only" = only embedding_lr/unembedding_lr
    sweep_mode: str = "all"
|
||||
|
||||
|
||||
def load_batches(num_batches: int, batch_size: int, seq_len: int, device: torch.device):
    """Load multiple batches from the nanochat training pipeline.

    Falls back to deterministic random-token batches (seeded 42 + batch index)
    if the tokenizer/dataset isn't available. Returns (batches, vocab_size)
    where batches is a list of (inputs, targets) tensor pairs.
    """
    try:
        from nanochat.tokenizer import get_tokenizer
        from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
        tok = get_tokenizer()
        vocab_size = tok.get_vocab_size()
        stream = tokenizing_distributed_data_loader_bos_bestfit(
            tok, batch_size, seq_len, split="train", device=device,
        )
        collected = []
        for xb, yb in stream:
            collected.append((xb, yb))
            if len(collected) >= num_batches:
                break
        print(f"Loaded {len(collected)} real training batch(es) (vocab_size={vocab_size})")
        return collected, vocab_size
    except Exception as e:
        # Best-effort fallback: synthetic next-token data so the sweep still runs.
        print(f"Could not load training data ({e}), using random tokens")
        vocab_size = 32768
        synthetic = []
        for idx in range(num_batches):
            gen = torch.Generator(device=device)
            gen.manual_seed(42 + idx)
            tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, generator=gen)
            # Targets are the inputs shifted left by one; last position is masked with -1.
            targets = torch.roll(tokens, -1, dims=1)
            targets[:, -1] = -1
            synthetic.append((tokens, targets))
        return synthetic, vocab_size
|
||||
|
||||
|
||||
def create_model(width: int, config: TransferCheckConfig, device: torch.device,
                 mup_base_width: int = 0, init_scale: float = 1.0,
                 output_mult: float = 1.0):
    """Create a model with the specified width and optional HP overrides.

    The requested width is snapped down to a multiple of the head dimension
    (64), so the returned gpt_config.n_embd may differ from `width`.
    Returns (model, gpt_config).
    """
    head_dim = 64
    heads = max(1, width // head_dim)
    snapped_width = heads * head_dim

    gpt_config = GPTConfig(
        sequence_len=config.seq_len,
        vocab_size=config.vocab_size,
        n_layer=config.n_layer,
        n_head=heads,
        n_kv_head=heads,
        n_embd=snapped_width,
        window_pattern="L",
        mup_base_width=mup_base_width,
    )

    # Cheap meta-device construction, then materialize + initialize on `device`.
    with torch.device('meta'):
        model = GPT(gpt_config)
    model.to_empty(device=device)
    model.init_weights()

    # HP override: rescale every freshly-initialized parameter by init_scale.
    if init_scale != 1.0:
        with torch.no_grad():
            for param in model.parameters():
                param.mul_(init_scale)

    # HP override: stash the output logit multiplier on the model.
    # NOTE(review): assumes GPT.forward consults this attribute — confirm in nanochat.gpt.
    model._transfer_check_output_mult = output_mult

    return model, gpt_config
|
||||
|
||||
|
||||
def train_model(width: int, lr_mult: float, config: TransferCheckConfig,
                device: torch.device, batches: List,
                init_scale: float = 1.0, output_mult: float = 1.0):
    """Train a model at given width and LR multiplier, return loss history.

    Args:
        width: requested model width (snapped inside create_model).
        lr_mult: multiplier applied to the base LRs per config.sweep_mode.
        config: sweep configuration (base LRs, steps, seed, muP flags).
        device: device to train on.
        batches: list of (x, y) pairs, cycled through across steps.
        init_scale: optional multiplier on all initial parameters.
        output_mult: optional output logit multiplier (see create_model).

    Returns:
        (losses, actual_width): per-step loss values and the snapped width.
    """
    # Same seed for every run so width/LR comparisons are not confounded by init noise.
    torch.manual_seed(config.seed)

    # muP is enabled in the model only when the sweep itself is in muP mode.
    mup_base_width = config.base_width if config.use_mup else 0
    model, gpt_config = create_model(width, config, device, mup_base_width=mup_base_width,
                                     init_scale=init_scale, output_mult=output_mult)
    actual_width = gpt_config.n_embd

    # Scale the learning rates by the multiplier, respecting sweep_mode
    if config.sweep_mode == "muon-only":
        muon_mult, adamw_mult = lr_mult, 1.0
    elif config.sweep_mode == "adamw-only":
        muon_mult, adamw_mult = 1.0, lr_mult
    else:  # "all"
        muon_mult, adamw_mult = lr_mult, lr_mult

    optimizer = model.setup_optimizer(
        unembedding_lr=config.unembedding_lr * adamw_mult,
        embedding_lr=config.embedding_lr * adamw_mult,
        matrix_lr=config.matrix_lr * muon_mult,
        weight_decay=0.0,
        use_mup=config.use_mup,
        base_width=config.base_width,
        muon_lr_exponent=config.muon_lr_exponent,
    )

    model.train()
    losses = []
    num_batches = len(batches)

    for step in range(config.num_steps):
        # Cycle through the available batches.
        x, y = batches[step % num_batches]
        # bf16 autocast only on CUDA; on CPU this context is a no-op.
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')):
            loss = model(x, y)

        # Record pre-step loss, then standard backward/step/zero-grad cycle.
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    # Free this run's memory before the next (width, lr) run.
    del model, optimizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return losses, actual_width
|
||||
|
||||
|
||||
def run_transfer_check(config: TransferCheckConfig, device: torch.device,
                       batches: List) -> Dict:
    """Run LR sweep across all widths.

    Returns a dict with 'widths' (actual snapped widths, in sweep order),
    'lr_multipliers', per-run 'losses' keyed by (width, lr_mult), and
    'final_losses' keyed as final_losses[width][lr_mult].
    """
    results = {
        'widths': [],
        'lr_multipliers': config.lr_multipliers,
        'losses': {},  # losses[(width, lr_mult)] = [loss_step0, ...]
        'final_losses': defaultdict(dict),  # final_losses[width][lr_mult] = final_loss
    }

    for requested in config.widths:
        actual = None
        for mult in config.lr_multipliers:
            print(f" width={requested}, lr_mult={mult:.4f}...", end=" ", flush=True)

            history, actual = train_model(requested, mult, config, device, batches)
            results['losses'][(actual, mult)] = history
            results['final_losses'][actual][mult] = history[-1]
            print(f"final_loss={history[-1]:.4f}")

        # Distinct requested widths can snap to the same actual width; dedupe.
        if actual not in results['widths']:
            results['widths'].append(actual)

    return results
|
||||
|
||||
|
||||
def run_hp_sweep(config: TransferCheckConfig, device: torch.device,
                 batches: List, hp_name: str, hp_values: List[float]) -> Dict:
    """Run a sweep over a single HP (init_scale or output_mult) at fixed LR.

    hp_name selects which create_model knob each value in hp_values drives;
    the LR multiplier is pinned to 1.0. Returns a dict with 'widths',
    'hp_values', 'hp_name', and 'final_losses'[width][hp_value].
    """
    results = {
        'widths': [],
        'hp_values': hp_values,
        'hp_name': hp_name,
        'final_losses': defaultdict(dict),
    }

    for requested in config.widths:
        actual = None
        for value in hp_values:
            # Route the swept value to exactly one of the two HP knobs.
            scale = value if hp_name == 'init_scale' else 1.0
            mult = value if hp_name == 'output_mult' else 1.0
            print(f" width={requested}, {hp_name}={value:.4f}...", end=" ", flush=True)

            history, actual = train_model(
                requested, 1.0, config, device, batches,
                init_scale=scale, output_mult=mult)
            results['final_losses'][actual][value] = history[-1]
            print(f"final_loss={history[-1]:.4f}")

        if actual not in results['widths']:
            results['widths'].append(actual)

    return results
|
||||
|
||||
|
||||
def find_optimal_lr(final_losses: Dict[float, float]) -> float:
    """Find the LR multiplier with the lowest final loss."""
    best_mult, _ = min(final_losses.items(), key=lambda kv: kv[1])
    return best_mult
|
||||
|
||||
|
||||
def plot_lr_sweep(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None):
    """Plot LR sweep: final loss vs LR multiplier for each width.

    Left panel: final loss vs LR multiplier, one curve per width, with a star
    at each width's best multiplier. Right panel: optimal multiplier vs width
    (flat at 1.0 indicates perfect muP transfer). Optionally saves to
    save_path, then shows the figure. `config` is accepted for signature
    symmetry with the other plot helpers and is not read here.
    """
    widths = results['widths']
    lr_mults = results['lr_multipliers']
    final_losses = results['final_losses']

    n_cols = 2
    fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 4))
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))

    # Left: final loss vs LR multiplier
    ax = axes[0]
    for i, w in enumerate(widths):
        losses = [final_losses[w][m] for m in lr_mults]
        ax.plot(lr_mults, losses, 'o-', color=colors[i], linewidth=2, label=f'width={w}')
        # Star marks the best (lowest final loss) multiplier for this width.
        opt_mult = find_optimal_lr(final_losses[w])
        opt_loss = final_losses[w][opt_mult]
        ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5)

    ax.set_xscale('log', base=2)
    ax.set_xlabel('LR Multiplier')
    ax.set_ylabel('Final Loss')
    ax.set_title(f'LR Sweep{" - " + title if title else ""}')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Right: optimal LR multiplier vs width
    ax = axes[1]
    opt_mults = [find_optimal_lr(final_losses[w]) for w in widths]
    ax.plot(widths, opt_mults, 'o-', linewidth=2, markersize=8, color='tab:blue')
    # Dashed line at 1.0: where the optimum sits under perfect transfer.
    ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target (1.0)')
    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_xticks(widths)
    ax.set_xticklabels(widths)
    ax.set_xlabel('Width')
    ax.set_ylabel('Optimal LR Multiplier')
    ax.set_title(f'Optimal LR vs Width{" - " + title if title else ""}')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    plt.show()
|
||||
|
||||
|
||||
def plot_comparison(results_sp: Dict, results_mup: Dict, config: TransferCheckConfig, save_path: Optional[str] = None):
    """Plot SP vs muP comparison side by side.

    2x2 layout: top row = final-loss-vs-LR curves for SP (left) and muP
    (right) on a shared y-range; bottom-left = optimal LR multiplier vs width
    for both; bottom-right = best achievable loss vs width. Widths and LR
    multipliers are taken from results_sp; results_mup is assumed to cover
    the same keys (TODO confirm at call sites). `config` is not read here.
    """
    widths = results_sp['widths']
    lr_mults = results_sp['lr_multipliers']

    n_rows, n_cols = 2, 2
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))

    # Top row: LR sweep curves
    for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
        ax = axes[0, col]
        for i, w in enumerate(widths):
            losses = [results['final_losses'][w][m] for m in lr_mults]
            ax.plot(lr_mults, losses, 'o-', color=colors[i], linewidth=2, label=f'w={w}')
            # Star marks the best multiplier for this width.
            opt_mult = find_optimal_lr(results['final_losses'][w])
            opt_loss = results['final_losses'][w][opt_mult]
            ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5)
        ax.set_xscale('log', base=2)
        ax.set_xlabel('LR Multiplier')
        ax.set_ylabel('Final Loss')
        ax.set_title(f'{label}: Final Loss vs LR Multiplier')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Shared y-axis for top row (2% margin on either side of the global range)
    y_min = min(
        min(results_sp['final_losses'][w][m] for m in lr_mults for w in widths),
        min(results_mup['final_losses'][w][m] for m in lr_mults for w in widths),
    ) * 0.98
    y_max = max(
        max(results_sp['final_losses'][w][m] for m in lr_mults for w in widths),
        max(results_mup['final_losses'][w][m] for m in lr_mults for w in widths),
    ) * 1.02
    axes[0, 0].set_ylim(y_min, y_max)
    axes[0, 1].set_ylim(y_min, y_max)

    # Bottom left: optimal LR vs width for both
    ax = axes[1, 0]
    for results, label, color in [(results_sp, 'SP', 'tab:red'), (results_mup, 'muP', 'tab:blue')]:
        opt_mults = [find_optimal_lr(results['final_losses'][w]) for w in widths]
        ax.plot(widths, opt_mults, 'o-', linewidth=2, markersize=8, color=color, label=label)
    ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target')
    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_xticks(widths)
    ax.set_xticklabels(widths)
    ax.set_xlabel('Width')
    ax.set_ylabel('Optimal LR Multiplier')
    ax.set_title('Optimal LR Multiplier vs Width')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Bottom right: loss at optimal LR vs width
    ax = axes[1, 1]
    for results, label, color in [(results_sp, 'SP', 'tab:red'), (results_mup, 'muP', 'tab:blue')]:
        opt_losses = [results['final_losses'][w][find_optimal_lr(results['final_losses'][w])] for w in widths]
        ax.plot(widths, opt_losses, 'o-', linewidth=2, markersize=8, color=color, label=label)
    ax.set_xscale('log', base=2)
    ax.set_xticks(widths)
    ax.set_xticklabels(widths)
    ax.set_xlabel('Width')
    ax.set_ylabel('Best Final Loss')
    ax.set_title('Best Achievable Loss vs Width')
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.suptitle('muP Transfer Check: SP vs muP', fontsize=14)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved comparison plot to {save_path}")
    plt.show()
|
||||
|
||||
|
||||
def plot_hp_sweep(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None):
    """Plot HP sweep: final loss vs HP value for each width.

    Left panel: final loss vs swept HP value (init_scale or output_mult), one
    curve per width, star at each width's best value. Right panel: optimal HP
    value vs width — flat at 1.0 means the HP transfers across widths.
    `config` is not read here.
    """
    widths = results['widths']
    hp_values = results['hp_values']
    hp_name = results['hp_name']
    final_losses = results['final_losses']

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))

    # Left: final loss vs HP value
    ax = axes[0]
    for i, w in enumerate(widths):
        losses = [final_losses[w][v] for v in hp_values]
        ax.plot(hp_values, losses, 'o-', color=colors[i], linewidth=2, label=f'w={w}')
        # Star marks the best HP value for this width.
        opt_v = min(final_losses[w], key=final_losses[w].get)
        ax.plot(opt_v, final_losses[w][opt_v], '*', color=colors[i], markersize=15, zorder=5)
    ax.set_xscale('log', base=2)
    ax.set_xlabel(hp_name)
    ax.set_ylabel('Final Loss')
    ax.set_title(f'{hp_name} Sweep{" - " + title if title else ""}')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # Right: optimal HP value vs width
    ax = axes[1]
    opt_vals = [min(final_losses[w], key=final_losses[w].get) for w in widths]
    ax.plot(widths, opt_vals, 'o-', linewidth=2, markersize=8, color='tab:blue')
    ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target (1.0)')
    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_xlabel('Width')
    ax.set_ylabel(f'Optimal {hp_name}')
    ax.set_title(f'Optimal {hp_name} vs Width')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    plt.show()
|
||||
|
||||
|
||||
def plot_loss_curves_at_optimal(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None):
    """Plot full loss curves at the optimal LR for each width.

    For each width, picks the LR multiplier with the lowest final loss and
    plots that run's full per-step loss history. Requires results['losses']
    keyed by (width, lr_mult), so it only works on run_transfer_check output
    (HP-sweep results lack 'losses'). `config` is not read here.
    """
    widths = results['widths']

    fig, ax = plt.subplots(figsize=(5 * 2, 4))
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))

    for i, w in enumerate(widths):
        opt_mult = find_optimal_lr(results['final_losses'][w])
        losses = results['losses'][(w, opt_mult)]
        ax.plot(losses, color=colors[i], linewidth=2, label=f'w={w} (lr_mult={opt_mult})')

    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title(f'Loss Curves at Optimal LR{" - " + title if title else ""}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    plt.show()
|
||||
|
||||
|
||||
def print_summary(results: Dict, label: str):
    """Print a summary table of the LR sweep results.

    One row per width with the final loss at each swept multiplier, the best
    loss, and the multiplier that achieved it, followed by the log2 spread of
    the per-width optimal multipliers (0 = perfect transfer).
    """
    widths = results['widths']
    lr_mults = results['lr_multipliers']
    final_losses = results['final_losses']

    print(f"\n{'='*70}")
    print(f" {label}: LR Sweep Results")
    print(f"{'='*70}")

    # Header: one column per swept multiplier, plus Best/Opt LR summary columns.
    pieces = [f"{'Width':>8}"]
    pieces.extend(f" | {m:>7.3f}" for m in lr_mults)
    pieces.append(f" | {'Best':>7} | {'Opt LR':>7}")
    header = "".join(pieces)
    print(header)
    print("-" * len(header))

    best_mults = []
    for w in widths:
        cells = [f"{w:>8}"]
        cells.extend(f" | {final_losses[w][m]:>7.4f}" for m in lr_mults)
        winner = find_optimal_lr(final_losses[w])
        best_mults.append(winner)
        cells.append(f" | {final_losses[w][winner]:>7.4f} | {winner:>7.3f}")
        print("".join(cells))

    # Transfer quality metric: how much does the optimal LR shift?
    log_opt = np.log2(np.array(best_mults))
    spread = log_opt.max() - log_opt.min()  # spread in log2 space
    print(f"\nOptimal LR spread (log2): {spread:.3f}")
    print(f" (0 = perfect transfer, >1 = poor transfer)")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: check muP learning-rate transfer across model widths.

    Three modes, selected by flags:
      * default      — one LR-multiplier sweep over all widths, in SP or muP
                       (``--use-mup``), printed and optionally plotted.
      * ``--compare``— run SP and muP back-to-back and compare the spread of
                       the optimal LR multiplier (smaller spread = better
                       transfer); optionally followed by extra muP HP sweeps.
      * ``--num-random-trials N`` — replace the LR grid with N log-uniform
                       random multipliers (paper-style methodology).
    """
    parser = argparse.ArgumentParser(description='muP Transfer Check')
    parser.add_argument('--widths', type=str, default='128,256,512,1024',
                        help='Comma-separated list of widths to test')
    # Paper-style default: ~1000x range, 11 log-spaced points
    parser.add_argument('--lr-mults', type=str,
                        default='0.03125,0.0625,0.125,0.25,0.5,1.0,2.0,4.0,8.0,16.0,32.0',
                        help='Comma-separated LR multipliers to sweep (default: 1024x range, 11 points)')
    parser.add_argument('--num-random-trials', type=int, default=0,
                        help='If >0, use N log-uniform random LR multipliers from 10^Uniform(-1.5,1.5) '
                             'instead of the grid. Paper-style methodology (Section F).')
    parser.add_argument('--steps', type=int, default=200,
                        help='Number of training steps per run')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--seq-len', type=int, default=128,
                        help='Sequence length')
    parser.add_argument('--n-layer', type=int, default=2,
                        help='Number of transformer layers')
    parser.add_argument('--use-mup', action='store_true',
                        help='Use muP learning rate scaling')
    parser.add_argument('--base-width', type=int, default=128,
                        help='Base width for muP scaling')
    parser.add_argument('--compare', action='store_true',
                        help='Run both SP and muP and compare')
    parser.add_argument('--save-dir', type=str, default=None,
                        help='Directory to save plots')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--num-batches', type=int, default=1,
                        help='Number of data batches to cycle through (default 1 for backward compat, '
                             'recommend 8 for thorough checks)')
    # Multi-HP sweeps
    parser.add_argument('--sweep-init-scale', action='store_true',
                        help='Also sweep init scale multiplier (sampled from 10^Uniform(-1,1))')
    parser.add_argument('--sweep-output-mult', action='store_true',
                        help='Also sweep output logit multiplier (sampled from 4^Uniform(-1,1))')
    # NOTE(review): default 0.0 means (base/width)^0 == 1, i.e. NO muP LR scaling
    # for Muon unless the caller opts in — confirm this default is intentional.
    parser.add_argument('--muon-lr-exponent', type=float, default=0.0,
                        help='Muon LR exponent for muP: 1.0 = (base/width)^1 (standard), '
                             '0.5 = (base/width)^0.5 (for Frobenius-normalizing optimizers like Muon)')
    parser.add_argument('--sweep-mode', type=str, default='all',
                        choices=['all', 'muon-only', 'adamw-only'],
                        help='Which optimizer groups the LR multiplier applies to: '
                             '"all" = all LRs (default), "muon-only" = only Muon/matrix LR, '
                             '"adamw-only" = only AdamW/embedding/output LR')

    args = parser.parse_args()

    widths = [int(w) for w in args.widths.split(',')]

    # Generate LR multipliers
    if args.num_random_trials > 0:
        # Log-uniform random sampling: 10^Uniform(-1.5, 1.5)
        rng = np.random.RandomState(args.seed)
        lr_mults = sorted(10 ** rng.uniform(-1.5, 1.5, args.num_random_trials))
        lr_mults = [round(float(m), 6) for m in lr_mults]
        print(f"Using {args.num_random_trials} random log-uniform LR multipliers: "
              f"[{lr_mults[0]:.4f}, ..., {lr_mults[-1]:.4f}]")
    else:
        lr_mults = sorted(float(m) for m in args.lr_mults.split(','))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load batches once so every (width, lr) run sees identical data.
    batches, vocab_size = load_batches(args.num_batches, args.batch_size, args.seq_len, device)

    config = TransferCheckConfig(
        widths=widths,
        lr_multipliers=lr_mults,
        num_steps=args.steps,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        vocab_size=vocab_size,
        n_layer=args.n_layer,
        seed=args.seed,
        use_mup=args.use_mup,
        base_width=args.base_width,
        sweep_init_scale=args.sweep_init_scale,
        sweep_output_mult=args.sweep_output_mult,
        num_batches=args.num_batches,
        muon_lr_exponent=args.muon_lr_exponent,
        sweep_mode=args.sweep_mode,
    )

    if args.compare:
        # Run SP
        print("\n" + "=" * 60)
        print("Running Standard Parameterization (SP)")
        print("=" * 60)
        config.use_mup = False
        results_sp = run_transfer_check(config, device, batches)
        print_summary(results_sp, "SP")

        # Run muP
        print("\n" + "=" * 60)
        print("Running muP")
        print("=" * 60)
        config.use_mup = True
        results_mup = run_transfer_check(config, device, batches)
        print_summary(results_mup, "muP")

        # Compare: the headline metric is the log2 spread of the optimal LR
        # multiplier across widths — muP should shrink it toward 0.
        print("\n" + "=" * 60)
        print("COMPARISON")
        print("=" * 60)
        sp_opts = [find_optimal_lr(results_sp['final_losses'][w]) for w in results_sp['widths']]
        mup_opts = [find_optimal_lr(results_mup['final_losses'][w]) for w in results_mup['widths']]
        sp_spread = np.log2(max(sp_opts)) - np.log2(min(sp_opts))
        mup_spread = np.log2(max(mup_opts)) - np.log2(min(mup_opts))
        print(f"SP optimal LR spread (log2): {sp_spread:.3f}")
        print(f"muP optimal LR spread (log2): {mup_spread:.3f}")
        if mup_spread < sp_spread:
            # max(..., 0.001) guards against division by a zero muP spread.
            print(f"muP shows {sp_spread/max(mup_spread, 0.001):.1f}x better LR transfer!")
        else:
            print("muP does NOT show better LR transfer (check implementation)")

        # Plot
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, 'transfer_check_comparison.png')
        plot_comparison(results_sp, results_mup, config, save_path)

        # Also plot loss curves at optimal LR
        for results, label in [(results_sp, 'SP'), (results_mup, 'muP')]:
            lc_save = None
            if args.save_dir:
                lc_save = os.path.join(args.save_dir, f'optimal_loss_curves_{label.lower()}.png')
            plot_loss_curves_at_optimal(results, config, title=label, save_path=lc_save)

        # Multi-HP sweeps (only for muP, to demonstrate transfer)
        # NOTE(review): indentation reconstructed from a mangled paste — this
        # sweep section is taken to live INSIDE the --compare branch, so the
        # sweep flags only take effect together with --compare; confirm.
        if args.sweep_init_scale or args.sweep_output_mult:
            config.use_mup = True

            if args.sweep_init_scale:
                print("\n" + "=" * 60)
                print("muP: Init Scale Sweep")
                print("=" * 60)
                # 10^Uniform(-1, 1) => range [0.1, 10]
                init_scales = [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]
                init_results = run_hp_sweep(config, device, batches, 'init_scale', init_scales)
                save_hp = None
                if args.save_dir:
                    save_hp = os.path.join(args.save_dir, 'init_scale_sweep.png')
                plot_hp_sweep(init_results, config, title="muP", save_path=save_hp)

            if args.sweep_output_mult:
                print("\n" + "=" * 60)
                print("muP: Output Multiplier Sweep")
                print("=" * 60)
                # 4^Uniform(-1, 1) => range [0.25, 4]
                output_mults = [0.25, 0.5, 1.0, 2.0, 4.0]
                output_results = run_hp_sweep(config, device, batches, 'output_mult', output_mults)
                save_hp = None
                if args.save_dir:
                    save_hp = os.path.join(args.save_dir, 'output_mult_sweep.png')
                plot_hp_sweep(output_results, config, title="muP", save_path=save_hp)

    else:
        # Single-parameterization run: one LR sweep, summary table, plots.
        param_type = "muP" if config.use_mup else "SP"
        print(f"\n{'='*60}")
        print(f"Running Transfer Check ({param_type})")
        print(f"{'='*60}")
        print(f"Widths: {widths}")
        print(f"LR multipliers: {lr_mults}")
        print(f"Steps: {config.num_steps}")
        if config.sweep_mode != "all":
            print(f"Sweep mode: {config.sweep_mode}")

        results = run_transfer_check(config, device, batches)
        print_summary(results, param_type)

        # Plot LR sweep
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, f'transfer_check_{param_type.lower()}.png')
        plot_lr_sweep(results, config, title=param_type, save_path=save_path)

        # Plot loss curves at optimal LR
        lc_save = None
        if args.save_dir:
            lc_save = os.path.join(args.save_dir, f'optimal_loss_curves_{param_type.lower()}.png')
        plot_loss_curves_at_optimal(results, config, title=param_type, save_path=lc_save)
|
||||
|
||||
|
||||
# Script entry point: only run the sweep when executed directly, not on import.
if __name__ == '__main__':
    main()
|
||||
Loading…
Reference in New Issue
Block a user