Advanced Optimization Techniques
nanochat uses a hybrid optimization strategy: combining Muon for matrix parameters and AdamW for embeddings. This is more sophisticated than standard approaches.
Why Different Optimizers?
Different parameter types have different optimization needs:
| Parameter Type | Examples | Characteristics | Best Optimizer |
|---|---|---|---|
| 2D weight matrices | Attention and MLP projections | Dense, high-dimensional | Muon |
| Embeddings | Token embedding table | Sparse updates (only rows for tokens in the batch get gradients) | AdamW |
| LM head (unembedding) | Output projection over the vocabulary | Dense gradients, needs a much lower learning rate | AdamW |
Traditional approach: use AdamW for everything.
nanochat approach: use Muon for the transformer matrices, AdamW for the embeddings and LM head.
Result: Faster training, better convergence
1. Muon Optimizer
Muon is a novel optimizer designed specifically for matrix parameters in neural networks.
Core Idea
Standard optimizers (SGD, Adam) treat matrices as flat vectors:
Matrix [3×4] → Flatten to vector [12] → Update
Muon exploits matrix structure:
Matrix [3×4] → Update using matrix operations → Keep matrix shape
Mathematical Formulation
For a weight matrix $W \in \mathbb{R}^{m \times n}$:
Standard momentum:
v_t = \beta v_{t-1} + (1-\beta) g_t
W_t = W_{t-1} - \eta v_t
Muon:
- Compute the gradient $G_t = \nabla_W \mathcal{L}$
- Orthogonalize $G_t$ using the Newton-Schulz iteration
- Apply momentum in the tangent space
- Update with an adaptive step size
Implementation: nanochat/muon.py:53
```python
import torch

class Muon(torch.optim.Optimizer):
    def __init__(self, params, lr=0.02, momentum=0.95):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue
                g = p.grad  # gradient
                # Per-parameter optimizer state
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)
                buf = state['momentum_buffer']
                # Matrix parameters get the Muon update; small or 1-D tensors fall through
                if g.ndim == 2 and g.size(0) >= 16 and g.size(1) >= 16:
                    g = newton_schulz_orthogonalize(g, steps=5)
                # Momentum update
                buf.mul_(momentum).add_(g)
                # Parameter update
                p.data.add_(buf, alpha=-lr)
```
Newton-Schulz Orthogonalization: nanochat/muon.py:16
```python
def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    """
    Approximately orthogonalize a gradient matrix with the Newton-Schulz iteration.
    Works for rectangular matrices: every nonzero singular value is pushed toward 1.
    """
    # X_0 = G / ||G||_F puts all singular values in (0, 1] so the iteration converges
    A = G / (G.norm() + eps)
    n = A.size(1)
    I = torch.eye(n, device=A.device, dtype=A.dtype)
    # Newton-Schulz iteration: X_{k+1} = X_k (3I - X_k^T X_k) / 2
    for _ in range(steps):
        A = A @ (1.5 * I - 0.5 * (A.t() @ A))
    return A
```
What does this do?
For a matrix G, find orthogonal matrix Q closest to G:
Q = \arg\min_{\tilde{Q}^T\tilde{Q}=I} \|G - \tilde{Q}\|_F
Uses iterative formula:
X_{k+1} = X_k \left(\frac{3I - X_k^TX_k}{2}\right)
Converges to Q = G(G^TG)^{-1/2}, the orthogonal factor of the polar decomposition of $G$.
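A quick numerical check of this behavior (a sketch; it reuses the `newton_schulz_orthogonalize` function above):

```python
import torch

torch.manual_seed(0)
G = torch.randn(64, 64)  # stand-in for a gradient matrix

Q = newton_schulz_orthogonalize(G, steps=5)

# G's singular values are spread out; Q's are pushed toward 1
# (more iteration steps push the small ones further toward 1).
print(torch.linalg.svdvals(G)[:5])
print(torch.linalg.svdvals(Q)[:5])
```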
Why orthogonalize?
- Each update becomes (approximately) semi-orthogonal, i.e. it lies on the Stiefel manifold
- Better geometry for optimization: all directions of the weight matrix are updated at a similar rate
- Prevents gradient explosion/vanishing
- Faster convergence in practice
Distributed Muon: nanochat/muon.py:155
For multi-GPU training:
```python
import torch.distributed as dist

class DistMuon(Muon):
    @torch.no_grad()
    def step(self):
        # First, average gradients across all GPUs
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
        # Then apply the standard Muon update
        super().step()
```
Key point: all-reducing gradients before the Muon update keeps the parameters identical across ranks.
Muon Learning Rate Scaling
```python
# From scripts/base_train.py:238
dmodel_lr_scale = (model_dim / 768) ** -0.5
lr_scaled = matrix_lr  # No scaling for Muon (handles it internally)
```
Because the orthogonalized update has roughly unit scale regardless of model width, Muon's learning rate does not need to be rescaled by model dimension.
Momentum Schedule for Muon: scripts/base_train.py:160
```python
def get_muon_momentum(it):
    """Warmup momentum from 0.85 to 0.95"""
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum
```
Start with lower momentum (more responsive), increase to higher momentum (more stable).
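A quick check of the ramp, using `get_muon_momentum` from above:

```python
for it in [0, 100, 200, 300, 1000]:
    print(it, round(get_muon_momentum(it), 3))
# 0 0.85, 100 0.883, 200 0.917, 300 0.95, 1000 0.95
```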
2. AdamW Optimizer
AdamW is used for embedding and language model head parameters.
Standard Adam
Combines momentum and adaptive learning rates:
m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t \quad \text{(first moment)}
v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2 \quad \text{(second moment)}
Bias correction:
\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t}
Update:
\theta_t = \theta_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
AdamW: Decoupled Weight Decay
Adam with L2 regularization:
\mathcal{L}' = \mathcal{L} + \frac{\lambda}{2}\|\theta\|^2
Problem: with Adam, the L2 penalty flows through the adaptive denominator, so parameters with large gradient variance are effectively decayed less; the regularization strength gets entangled with the learning-rate adaptation.
AdamW solution: Decouple weight decay from gradient:
\theta_t = (1 - \lambda \eta) \theta_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
Benefits:
- Cleaner regularization
- Better generalization
- Less hyperparameter interaction
Implementation: nanochat/adamw.py:53
```python
import math
import torch
import torch.distributed as dist

class DistAdamW(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        # First, average gradients across GPUs
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
        # Then apply the AdamW update
        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                # Initialize state
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)     # m_t
                    state['exp_avg_sq'] = torch.zeros_like(p)  # v_t
                state['step'] += 1
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                step = state['step']
                # Update biased first and second moment estimates
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # Bias correction
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step
                step_size = lr / bias_correction1
                # Denominator (with bias correction)
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                # Decoupled weight decay
                if weight_decay != 0:
                    p.data.mul_(1 - lr * weight_decay)
                # Parameter update
                p.data.addcdiv_(exp_avg, denom, value=-step_size)
```
AdamW Hyperparameters in nanochat
```python
# From scripts/base_train.py:228
adam_groups = [
    dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),  # 0.004
    dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),  # 0.2
]
adamw_kwargs = dict(
    betas=(0.8, 0.95),          # instead of the default (0.9, 0.999)
    eps=1e-10,
    weight_decay=weight_decay,  # usually 0.0 for small models
)
```
Why different betas?
- $\beta_1 = 0.8$: slightly less momentum (more responsive)
- $\beta_2 = 0.95$: much less variance accumulation (adapts faster)
This is better tuned for LLM training than defaults.
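A rough heuristic (not from the nanochat source) for reading these values: an exponential moving average with decay $\beta$ averages over roughly $1/(1-\beta)$ recent steps.

```python
# Approximate averaging window of an EMA with decay beta: ~ 1 / (1 - beta)
for name, beta in [("beta1 default", 0.9), ("beta1 nanochat", 0.8),
                   ("beta2 default", 0.999), ("beta2 nanochat", 0.95)]:
    print(f"{name}: ~{1 / (1 - beta):.0f} steps")
# beta2 = 0.95 tracks the gradient's second moment over ~20 steps instead of ~1000,
# so the adaptive step size reacts much faster during training.
```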
Learning Rate Scaling by Model Dimension
```python
dmodel_lr_scale = (model_dim / 768) ** -0.5

# Examples:
#   model_dim = 1280 → scale = (1280/768)^(-0.5) ≈ 0.77
#   model_dim =  384 → scale = (384/768)^(-0.5)  ≈ 1.41
```
Why \propto 1/\sqrt{d_{model}}?
Larger models have larger gradients (sum over more dimensions). Scaling LR prevents instability.
3. Hybrid Optimizer Setup: nanochat/gpt.py:228
```python
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
    model_dim = self.config.n_embd
    ddp, rank, _, _ = get_dist_info()
    # Separate parameters into groups
    matrix_params = list(self.transformer.h.parameters())       # all transformer blocks
    embedding_params = list(self.transformer.wte.parameters())  # token embeddings
    lm_head_params = list(self.lm_head.parameters())            # output layer
    # Scale learning rates by model dimension
    dmodel_lr_scale = (model_dim / 768) ** -0.5
    # AdamW for embeddings and LM head
    adam_groups = [
        dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
        dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
    ]
    AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
    adamw_optimizer = AdamWFactory(adam_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
    # Muon for transformer matrices
    MuonFactory = DistMuon if ddp else Muon
    muon_optimizer = MuonFactory(matrix_params, lr=matrix_lr, momentum=0.95)
    # Return both optimizers
    optimizers = [adamw_optimizer, muon_optimizer]
    return optimizers
```
Why different learning rates?
| Parameter | Base LR | Reasoning |
|---|---|---|
| Embeddings | 0.2 | Sparse updates, can handle high LR |
| LM head | 0.004 | Dense gradients, needs lower LR |
| Matrices | 0.02 | Muon handles geometry, moderate LR |
Stepping Multiple Optimizers: scripts/base_train.py:269
```python
# Update learning rates for all optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
    for group in opt.param_groups:
        # "initial_lr" is the LR each group was created with, saved before the loop
        group["lr"] = group["initial_lr"] * lrm

# Update Muon momentum
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
    group["momentum"] = muon_momentum

# Step all optimizers
for opt in optimizers:
    opt.step()

# Clear gradients
model.zero_grad(set_to_none=True)
```
Important: `set_to_none=True` frees the gradient tensors instead of filling them with zeros, which saves memory.
4. Gradient Clipping
Prevents exploding gradients during training.
Global Norm Clipping: scripts/base_train.py:265
```python
if grad_clip > 0.0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
```
How it works:
- Compute global gradient norm:
\|\mathbf{g}\|_{global} = \sqrt{\sum_{\theta \in \Theta} \|\nabla_\theta \mathcal{L}\|^2}
- If too large, scale all gradients:
\mathbf{g}_\theta \leftarrow \frac{\text{max\_norm}}{\|\mathbf{g}\|_{global}} \mathbf{g}_\theta
Effect: Limits maximum gradient magnitude without changing direction.
Implementation Details
```python
def clip_grad_norm_(parameters, max_norm, norm_type=2):
    parameters = [p for p in parameters if p.grad is not None]
    # Compute the total norm across all parameters
    total_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]),
        norm_type
    )
    # Compute the clipping coefficient
    clip_coef = max_norm / (total_norm + 1e-6)
    # Scale all gradients only if the norm exceeds max_norm
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef)
    return total_norm
```
5. Warmup and Decay Schedules
Why Warmup?
At initialization:
- Weights are random
- Gradients can be very large
- Adam's second moment estimate is inaccurate
Solution: Start with low LR, gradually increase.
Why Decay?
Near end of training:
- Model is close to optimum
- Small steps refine solution
- Prevents oscillation
Solution: Gradually decrease LR to 0.
Schedule Implementation: scripts/base_train.py:148
```python
warmup_ratio = 0.0    # Skip warmup for simplicity
warmdown_ratio = 0.2  # Last 20% of training
final_lr_frac = 0.0   # Decay to 0

def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        # Linear warmup
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        # Constant LR
        return 1.0
    else:
        # Linear warmdown
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac
```
Alternative schedules:
- Cosine decay: Smoother than linear (sketched below)
- Exponential decay: Aggressive reduction
- Step decay: Discrete jumps
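For comparison, a cosine-decay multiplier would look roughly like this (a sketch, not nanochat's schedule; `num_iterations` and `warmup_iters` have the same meaning as above):

```python
import math

def get_lr_multiplier_cosine(it, num_iterations, warmup_iters=0, final_lr_frac=0.0):
    """Linear warmup, then cosine decay from 1.0 down to final_lr_frac."""
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    progress = (it - warmup_iters) / max(1, num_iterations - warmup_iters)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    return final_lr_frac + (1.0 - final_lr_frac) * cosine
```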
6. Optimization Best Practices
Learning Rate Tuning
Too high:
- Training unstable
- Loss oscillates or diverges
- NaN loss
Too low:
- Training very slow
- Gets stuck in local minima
- Underfits
Good LR:
- Steady loss decrease
- Occasional small oscillations
- Converges smoothly
Finding Good LR: Learning Rate Range Test
```python
# Start with a very low LR and increase it gradually.
# train_step(lr) is a placeholder for one forward/backward/update pass
# that returns the training loss at that learning rate.
lrs = []
losses = []
lr = 1e-8
for step in range(1000):
    loss = train_step(lr)
    lrs.append(lr)
    losses.append(loss)
    lr *= 1.01  # increase by 1% each step

# Plot losses vs. LRs: a good LR is in the region where the loss drops fastest.
```
Batch Size Effects
Larger batch size:
- More stable gradients
- Better GPU utilization
- Can use higher LR
- Slower wall-clock time per iteration
- May generalize worse
Smaller batch size:
- Noisier gradients (implicit regularization)
- Less GPU efficient
- Lower LR needed
- Faster iterations
nanochat choice: ~524K tokens per optimizer step (large, for gradient stability); a breakdown is sketched below.
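Here is the usual decomposition of such a number (illustrative values, not taken from the nanochat config):

```python
device_batch_size = 32   # sequences per GPU per step (illustrative)
sequence_length = 2048   # tokens per sequence (illustrative)
num_gpus = 8             # data-parallel ranks (illustrative)

tokens_per_step = device_batch_size * sequence_length * num_gpus
print(f"{tokens_per_step:,} tokens")  # 524,288 ≈ 524K
```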
7. Comparison: Different Optimization Strategies
| Strategy | Training Speed | Final Loss | Complexity |
|---|---|---|---|
| SGD | Slow | Good | Simple |
| Adam | Fast | Good | Medium |
| AdamW | Fast | Better | Medium |
| Muon (matrices only) | Very Fast | Best | High |
| Hybrid (AdamW + Muon) | Very Fast | Best | High |
nanochat's hybrid approach is cutting-edge!
8. Memory Optimization
Gradient Checkpointing (Not used in nanochat)
Trade compute for memory:
- Don't store intermediate activations
- Recompute during backward pass
- Costs extra compute (roughly one additional forward pass) in exchange for a large drop in activation memory; a usage sketch follows below
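nanochat doesn't use it, but in PyTorch it looks roughly like this (a sketch; `blocks` and `x` are hypothetical stand-ins for the transformer layers and the hidden states):

```python
import torch
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(blocks, x):
    # Activations inside each block are not kept for the backward pass;
    # they are recomputed from the block's input when gradients are needed.
    for block in blocks:
        x = checkpoint(block, x, use_reentrant=False)
    return x
```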
Optimizer State Management
AdamW stores:
- First moment (m): same size as parameters
- Second moment (v): same size as parameters
Memory: ~2× parameter size
For 270M param model:
- Parameters: 270M × 2 bytes (BF16) = 540 MB
- AdamW states: 270M × 8 bytes (FP32) = 2.16 GB
- Total: ~2.7 GB (quick arithmetic check below)
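The same arithmetic as a snippet, using the numbers above:

```python
n_params = 270e6                       # parameter count
param_bytes = n_params * 2             # BF16 weights: 2 bytes each
adamw_state_bytes = n_params * 4 * 2   # FP32 m and v: 4 bytes each, two buffers

print(f"parameters:  {param_bytes / 1e9:.2f} GB")                        # 0.54 GB
print(f"AdamW state: {adamw_state_bytes / 1e9:.2f} GB")                  # 2.16 GB
print(f"total:       {(param_bytes + adamw_state_bytes) / 1e9:.2f} GB")  # 2.70 GB
```

Muon, by contrast, keeps only a single momentum buffer per matrix (see the `Muon.step` code above), so its optimizer state is roughly half of AdamW's.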
Fused Optimizers
```python
AdamW(..., fused=True)  # Uses fused CUDA kernel
```
Benefits:
- Faster updates (single kernel launch)
- Less memory traffic
- ~10-20% speedup
Next Steps
We've covered optimization! Next, we'll explore Implementation Details - practical coding techniques used throughout nanochat.