# Advanced Optimization Techniques
nanochat uses a **hybrid optimization** strategy: combining **Muon** for matrix parameters and **AdamW** for embeddings. This is more sophisticated than standard approaches.
## Why Different Optimizers?
Different parameter types have different optimization needs:
| Parameter Type | Examples | Characteristics | Best Optimizer |
|----------------|----------|-----------------|----------------|
| **Matrices** | Attention, MLP | Dense, high-dimensional | Muon |
| **Embeddings** | Token embedding table | Row-sparse updates (only rows for tokens in the batch get gradients) | AdamW |
| **LM head** | Output (unembedding) projection | Dense gradients over the full vocabulary | AdamW |
**Traditional approach:** Use AdamW for everything
**nanochat approach:** Use Muon for matrices, AdamW for embeddings/head
**Result:** Faster training, better convergence
## 1. Muon Optimizer
Muon is a novel optimizer designed specifically for **matrix parameters** in neural networks.
### Core Idea
Standard optimizers (SGD, Adam) treat matrices as flat vectors:
```
Matrix [3×4] → Flatten to vector [12] → Update
```
Muon exploits **matrix structure**:
```
Matrix [3×4] → Update using matrix operations → Keep matrix shape
```
### Mathematical Formulation
For weight matrix $W \in \mathbb{R}^{m \times n}$:
**Standard momentum:**
$$v_t = \beta v_{t-1} + (1-\beta) g_t$$
$$W_t = W_{t-1} - \eta v_t$$
**Muon:**
1. Compute gradient $G_t = \nabla_W \mathcal{L}$
2. Orthogonalize using Newton-Schulz iteration
3. Apply momentum in tangent space
4. Update with adaptive step size
### Implementation: `nanochat/muon.py:53`
```python
class Muon(torch.optim.Optimizer):
    def __init__(self, params, lr=0.02, momentum=0.95):
        defaults = dict(lr=lr, momentum=momentum)
        super(Muon, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue
                g = p.grad  # Gradient
                # Get state
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)
                buf = state['momentum_buffer']
                # Handle matrix vs non-matrix parameters
                if g.ndim == 2 and g.size(0) >= 16 and g.size(1) >= 16:
                    # Matrix parameter: use Muon update
                    g = newton_schulz_orthogonalize(g, steps=5)
                # Momentum update
                buf.mul_(momentum).add_(g)
                # Parameter update
                p.data.add_(buf, alpha=-lr)
```
### Newton-Schulz Orthogonalization: `nanochat/muon.py:16`
```python
def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    """
    Orthogonalize gradient matrix using Newton-Schulz iteration
    """
    # Make square by cropping the longer dimension (the shape is restored at the end)
    a, b = G.size()
    if a > b:
        G = G[:b, :]
    elif a < b:
        G = G[:, :a]
    # Initialize: X_0 = G / ||G||_F (keeps all singular values <= 1)
    t = G.size(0)  # size of the square working matrix
    A = G / (G.norm() + eps)
    # Newton-Schulz iteration: X_{k+1} = X_k (3I - X_k^T X_k) / 2
    for _ in range(steps):
        A_T_A = A.t() @ A
        A = A @ (1.5 * torch.eye(t, device=A.device, dtype=A.dtype) - 0.5 * A_T_A)
    # Restore original shape by zero-padding the cropped rows/columns
    if a > b:
        A = torch.cat([A, torch.zeros(a - b, b, device=A.device, dtype=A.dtype)], dim=0)
    elif a < b:
        A = torch.cat([A, torch.zeros(a, b - a, device=A.device, dtype=A.dtype)], dim=1)
    return A
```
**What does this do?**
For a matrix $G$, find orthogonal matrix $Q$ closest to $G$:
$$Q = \arg\min_{\tilde{Q}^T\tilde{Q}=I} \|G - \tilde{Q}\|_F$$
Uses iterative formula:
$$X_{k+1} = X_k \left(\frac{3I - X_k^TX_k}{2}\right)$$
Converges to $Q = G(G^TG)^{-1/2}$, the orthogonal factor of the polar decomposition $G = QP$.
**Why orthogonalize?**
- Keeps gradients on Stiefel manifold
- Better geometry for optimization
- Prevents gradient explosion/vanishing
- Faster convergence
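To see what the iteration produces, here is a quick check using the function above (illustrative sizes and step counts, not from the source):
```python
import torch

# Newton-Schulz pushes the singular values of the (normalized) input toward 1,
# so the output is approximately orthogonal.
G = torch.randn(64, 64)
Q = newton_schulz_orthogonalize(G, steps=5)

print(torch.linalg.svdvals(G)[:4])  # spread-out singular values
print(torch.linalg.svdvals(Q)[:4])  # largest singular values pulled toward 1.0
# Deviation from exact orthogonality; it shrinks with more steps, although
# directions with tiny singular values converge slowly.
print(torch.linalg.norm(Q.t() @ Q - torch.eye(64)))
```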
### Distributed Muon: `nanochat/muon.py:155`
For multi-GPU training:
```python
class DistMuon(Muon):
    def step(self):
        # First, average gradients across all GPUs
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG)
        # Then apply standard Muon update
        super().step()
```
**Key:** All-reduce gradients before Muon update ensures synchronization.
### Muon Learning Rate Scaling
```python
# From scripts/base_train.py:238
dmodel_lr_scale = (model_dim / 768) ** -0.5
lr_scaled = matrix_lr # No scaling for Muon (handles it internally)
```
Muon is **scale-invariant**, so no need to scale LR by model dimension!
### Momentum Schedule for Muon: `scripts/base_train.py:160`
```python
def get_muon_momentum(it):
    """Warmup momentum from 0.85 to 0.95"""
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum
```
Start with lower momentum (more responsive), increase to higher momentum (more stable).
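Evaluating the schedule at a few iterations shows the ramp:
```python
# Momentum rises linearly from 0.85 to 0.95 over the first 300 iterations,
# then stays at 0.95.
for it in (0, 150, 300, 1000):
    print(it, round(get_muon_momentum(it), 3))
# 0    0.85
# 150  0.9
# 300  0.95
# 1000 0.95
```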
## 2. AdamW Optimizer
AdamW is used for embedding and language model head parameters.
### Standard Adam
Combines **momentum** and **adaptive learning rates**:
$$m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t \quad \text{(first moment)}$$
$$v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2 \quad \text{(second moment)}$$
Bias correction:
$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t}$$
Update:
$$\theta_t = \theta_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
### AdamW: Decoupled Weight Decay
**Adam with L2 regularization:**
$$\mathcal{L}' = \mathcal{L} + \frac{\lambda}{2}\|\theta\|^2$$
**Problem:** With Adam, the L2 term $\lambda\theta$ is added to the gradient and then rescaled by the adaptive denominator $\sqrt{\hat{v}_t}$, so parameters with large gradient variance are barely regularized.
**AdamW solution:** Decouple weight decay from gradient:
$$\theta_t = (1 - \lambda \eta) \theta_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
**Benefits:**
- Cleaner regularization
- Better generalization
- Less hyperparameter interaction
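To make the difference concrete, here is a one-parameter sketch (not nanochat code) of a single step under each rule:
```python
import math

# One illustrative step on a single parameter theta, comparing
# Adam + L2 (decay folded into the gradient) vs AdamW (decoupled decay).
theta, g = 2.0, 0.1          # parameter value and raw gradient
lr, lam = 0.01, 0.1          # learning rate, weight decay
beta1, beta2, eps = 0.9, 0.999, 1e-8

def adam_step(theta, grad, m=0.0, v=0.0, t=1):
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)

# Adam + L2: the decay term lam * theta passes through the adaptive rescaling,
# so at t=1 the update is essentially a sign step regardless of the penalty size.
adam_l2 = adam_step(theta, g + lam * theta)

# AdamW: theta is shrunk directly by (1 - lr * lam), gradient left untouched,
# matching the decoupled update formula above.
adamw = adam_step(theta * (1 - lr * lam), g)

print(adam_l2, adamw)
```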
### Implementation: `nanochat/adamw.py:53`
```python
class DistAdamW(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        # First, all-reduce gradients across GPUs
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG)
        # Then apply AdamW update
        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                # Initialize state
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)     # m_t
                    state['exp_avg_sq'] = torch.zeros_like(p)  # v_t
                state['step'] += 1
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                step = state['step']
                # Update biased first and second moments
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # Bias correction
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step
                step_size = lr / bias_correction1
                # Compute denominator (with bias correction)
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                # Weight decay (decoupled)
                if weight_decay != 0:
                    p.data.mul_(1 - lr * weight_decay)
                # Update parameters
                p.data.addcdiv_(exp_avg, denom, value=-step_size)
```
### AdamW Hyperparameters in nanochat
```python
# From scripts/base_train.py:228
adam_groups = [
    dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),  # 0.004
    dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),  # 0.2
]
adamw_kwargs = dict(
    betas=(0.8, 0.95),          # Instead of default (0.9, 0.999)
    eps=1e-10,
    weight_decay=weight_decay,  # Usually 0.0 for small models
)
```
**Why different betas?**
- $\beta_1 = 0.8$: Slightly less momentum (more responsive)
- $\beta_2 = 0.95$: Much less variance accumulation (adapts faster)
This is better tuned for LLM training than defaults.
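A handy heuristic for reading these values (not from the source): an exponential moving average with coefficient $\beta$ averages over roughly $1/(1-\beta)$ recent steps.
```python
# Approximate averaging horizon of an exponential moving average: ~1 / (1 - beta)
for name, beta in [("default beta2", 0.999), ("nanochat beta2", 0.95),
                   ("default beta1", 0.9), ("nanochat beta1", 0.8)]:
    print(f"{name}: ~{1 / (1 - beta):.0f} steps")
# default beta2: ~1000 steps, nanochat beta2: ~20 steps
# default beta1: ~10 steps,   nanochat beta1: ~5 steps
```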
### Learning Rate Scaling by Model Dimension
```python
dmodel_lr_scale = (model_dim / 768) ** -0.5
# Example:
# model_dim = 1280 → scale = (1280/768)^{-0.5} ≈ 0.77
# model_dim = 384 → scale = (384/768)^{-0.5} ≈ 1.41
```
**Why $\propto 1/\sqrt{d_{model}}$?**
Gradients into a wider layer sum contributions over more dimensions, so update magnitudes tend to grow with $d_{model}$; shrinking the LR by $\sqrt{d_{model}}$ keeps the effective step size roughly comparable across model sizes and prevents instability.
## 3. Hybrid Optimizer Setup: `nanochat/gpt.py:228`
```python
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
    model_dim = self.config.n_embd
    ddp, rank, _, _ = get_dist_info()
    # Separate parameters into groups
    matrix_params = list(self.transformer.h.parameters())      # All transformer blocks
    embedding_params = list(self.transformer.wte.parameters())  # Token embeddings
    lm_head_params = list(self.lm_head.parameters())            # Output layer
    # Scale learning rates
    dmodel_lr_scale = (model_dim / 768) ** -0.5
    # AdamW for embeddings and LM head
    adam_groups = [
        dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
        dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
    ]
    AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
    adamw_optimizer = AdamWFactory(adam_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
    # Muon for transformer matrices
    MuonFactory = DistMuon if ddp else Muon
    muon_optimizer = MuonFactory(matrix_params, lr=matrix_lr, momentum=0.95)
    # Return both optimizers
    optimizers = [adamw_optimizer, muon_optimizer]
    return optimizers
```
**Why different learning rates?**
| Parameter | LR | Reasoning |
|-----------|-----|-----------|
| Embeddings | 0.2 | Sparse updates, can handle high LR |
| LM head | 0.004 | Dense gradients, needs lower LR |
| Matrices | 0.02 | Muon handles geometry, moderate LR |
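The stepping code below reads `group["initial_lr"]`, which implies each group's base LR is cached right after the optimizers are created. A hedged sketch of that wiring (the exact lines in `scripts/base_train.py` may differ):
```python
# Build both optimizers from the model (method shown above)
optimizers = model.setup_optimizers(
    unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0
)
adamw_optimizer, muon_optimizer = optimizers

# Cache each group's starting LR so the schedule can rescale it every step
for opt in optimizers:
    for group in opt.param_groups:
        group["initial_lr"] = group["lr"]
```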
### Stepping Multiple Optimizers: `scripts/base_train.py:269`
```python
# Update learning rates for all optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["initial_lr"] * lrm
# Update Muon momentum
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
    group["momentum"] = muon_momentum
# Step all optimizers
for opt in optimizers:
    opt.step()
# Clear gradients
model.zero_grad(set_to_none=True)
```
**Important:** `set_to_none=True` frees the gradient tensors instead of filling them with zeros, which saves memory and skips a zeroing kernel per parameter.
## 4. Gradient Clipping
Prevents exploding gradients during training.
### Global Norm Clipping: `scripts/base_train.py:265`
```python
if grad_clip > 0.0:
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
```
**How it works:**
1. Compute global gradient norm:
$$\|\mathbf{g}\|_{global} = \sqrt{\sum_{\theta \in \Theta} \|\nabla_\theta \mathcal{L}\|^2}$$
2. If too large, scale all gradients:
$$\mathbf{g}_\theta \leftarrow \frac{\text{max\_norm}}{\|\mathbf{g}\|_{global}} \mathbf{g}_\theta$$
**Effect:** Limits maximum gradient magnitude without changing direction.
### Implementation Details
```python
def clip_grad_norm_(parameters, max_norm, norm_type=2):
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    # Compute total norm
    total_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]),
        norm_type
    )
    # Compute clipping coefficient
    clip_coef = max_norm / (total_norm + 1e-6)
    # Clip if necessary
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef)
    return total_norm
```
## 5. Warmup and Decay Schedules
### Why Warmup?
At initialization:
- Weights are random
- Gradients can be very large
- Adam's second moment estimate is inaccurate
**Solution:** Start with low LR, gradually increase.
### Why Decay?
Near end of training:
- Model is close to optimum
- Small steps refine solution
- Prevents oscillation
**Solution:** Gradually decrease LR to 0.
### Schedule Implementation: `scripts/base_train.py:148`
```python
warmup_ratio = 0.0    # Skip warmup for simplicity
warmdown_ratio = 0.2  # Last 20% of training
final_lr_frac = 0.0   # Decay to 0

def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        # Linear warmup
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        # Constant LR
        return 1.0
    else:
        # Linear warmdown
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac
```
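A quick check of the resulting shape (illustrative, assuming `num_iterations = 1000`):
```python
# With warmup_ratio = 0.0 and warmdown_ratio = 0.2, the last 200 steps decay linearly.
num_iterations = 1000
print(get_lr_multiplier(0))    # 1.0   (no warmup, since warmup_ratio = 0.0)
print(get_lr_multiplier(800))  # 1.0   (end of the constant phase)
print(get_lr_multiplier(900))  # 0.5   (halfway through the 200-step warmdown)
print(get_lr_multiplier(999))  # 0.005 (nearly decayed to final_lr_frac = 0)
```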
**Alternative schedules:**
- Cosine decay: Smoother than linear (see the sketch below)
- Exponential decay: Aggressive reduction
- Step decay: Discrete jumps
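For comparison, a minimal cosine-decay multiplier (a sketch, not part of nanochat):
```python
import math

def cosine_lr_multiplier(it, num_iterations, warmup_iters=0, final_lr_frac=0.0):
    # Linear warmup, then cosine decay from 1.0 down to final_lr_frac
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    progress = (it - warmup_iters) / max(1, num_iterations - warmup_iters)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))
    return final_lr_frac + (1 - final_lr_frac) * cosine
```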
## 6. Optimization Best Practices
### Learning Rate Tuning
**Too high:**
- Training unstable
- Loss oscillates or diverges
- NaN loss
**Too low:**
- Training very slow
- Gets stuck in local minima
- Underfits
**Good LR:**
- Steady loss decrease
- Occasional small oscillations
- Converges smoothly
### Finding Good LR: Learning Rate Range Test
```python
# Start with a very low LR, gradually increase
# (train_step is a placeholder for one forward/backward/update pass at the given LR)
lrs = []
losses = []
lr = 1e-8
for step in range(1000):
    loss = train_step(lr)
    lrs.append(lr)
    losses.append(loss)
    lr *= 1.01  # Increase by 1%
# Plot losses vs LRs
# Good LR is where loss decreases fastest
```
### Batch Size Effects
**Larger batch size:**
- More stable gradients
- Better GPU utilization
- Can use higher LR
- Slower wall-clock time per iteration
- May generalize worse
**Smaller batch size:**
- Noisier gradients (implicit regularization)
- Less GPU efficient
- Lower LR needed
- Faster iterations
**nanochat choice:** 524K tokens/batch (very large for stability)
## 7. Comparison: Different Optimization Strategies
| Strategy | Training Speed | Final Loss | Complexity |
|----------|----------------|------------|------------|
| SGD | Slow | Good | Simple |
| Adam | Fast | Good | Medium |
| AdamW | Fast | Better | Medium |
| Muon (matrices only) | Very Fast | Best | High |
| **Hybrid (AdamW + Muon)** | **Very Fast** | **Best** | **High** |
nanochat's hybrid approach is cutting-edge!
## 8. Memory Optimization
### Gradient Checkpointing (Not used in nanochat)
Trade compute for memory (see the sketch below):
- Don't store intermediate activations
- Recompute them during the backward pass
- Roughly one extra forward pass per step, in exchange for a large (often order-of-magnitude) drop in activation memory
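A minimal sketch of what this looks like with `torch.utils.checkpoint` (illustrative only; `blocks` and `x` are placeholders, and nanochat does not use this):
```python
import torch
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(blocks, x):
    # Activations inside each block are not stored; they are recomputed
    # during the backward pass, trading extra compute for lower peak memory.
    for block in blocks:
        x = checkpoint(block, x, use_reentrant=False)
    return x
```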
### Optimizer State Management
AdamW stores:
- First moment (m): same size as parameters
- Second moment (v): same size as parameters
**Memory:** ~2× parameter size
For 270M param model:
- Parameters: 270M × 2 bytes (BF16) = 540 MB
- AdamW states: 270M × 8 bytes (FP32) = 2.16 GB
- Total: ~2.7 GB
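The same arithmetic as a small script (assuming BF16 weights and FP32 optimizer states, as above):
```python
params = 270e6
weights_gb = params * 2 / 1e9            # BF16: 2 bytes per parameter
adamw_states_gb = params * 2 * 4 / 1e9   # two FP32 tensors (m and v): 8 bytes per parameter
print(f"weights:   {weights_gb:.2f} GB")                     # ~0.54 GB
print(f"optimizer: {adamw_states_gb:.2f} GB")                # ~2.16 GB
print(f"total:     {weights_gb + adamw_states_gb:.2f} GB")   # ~2.70 GB
```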
### Fused Optimizers
```python
AdamW(..., fused=True) # Uses fused CUDA kernel
```
**Benefits:**
- Faster updates (single kernel launch)
- Less memory traffic
- ~10-20% speedup
## Next Steps
We've covered optimization! Next, we'll explore **Implementation Details** - practical coding techniques used throughout nanochat.