# Advanced Optimization Techniques

nanochat uses a **hybrid optimization** strategy: **Muon** for matrix parameters and **AdamW** for the embeddings and the LM head. This is more sophisticated than the standard single-optimizer approach.

## Why Different Optimizers?

Different parameter types have different optimization needs:
| Parameter Type | Examples | Characteristics | Best Optimizer |
|----------------|----------|-----------------|----------------|
| **Matrices** | Attention, MLP weights | Dense, high-dimensional | Muon |
| **Embeddings** | Token embeddings | Sparse, row-wise updates | AdamW |
| **LM head** | Unembedding (output projection) | Dense gradients, needs a low LR | AdamW |

**Traditional approach:** Use AdamW for everything.

**nanochat approach:** Use Muon for matrices, AdamW for the embeddings and LM head.

**Result:** Faster training, better convergence.
## 1. Muon Optimizer

Muon is a novel optimizer designed specifically for **matrix parameters** in neural networks.

### Core Idea

Standard optimizers (SGD, Adam) treat matrices as flat vectors:

```
Matrix [3×4] → Flatten to vector [12] → Update
```

Muon exploits the **matrix structure**:

```
Matrix [3×4] → Update using matrix operations → Keep matrix shape
```
### Mathematical Formulation

For a weight matrix $W \in \mathbb{R}^{m \times n}$:

**Standard momentum:**

$$v_t = \beta v_{t-1} + (1-\beta) g_t$$
$$W_t = W_{t-1} - \eta v_t$$

**Muon:**

1. Compute the gradient $G_t = \nabla_W \mathcal{L}$
2. Accumulate momentum: $B_t = \mu B_{t-1} + G_t$
3. Orthogonalize the momentum buffer with a Newton-Schulz iteration: $O_t = \text{NS}(B_t)$
4. Update: $W_t = W_{t-1} - \eta O_t$
### Implementation: `nanochat/muon.py:53`

```python
import torch


class Muon(torch.optim.Optimizer):
    def __init__(self, params, lr=0.02, momentum=0.95):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                g = p.grad  # raw gradient

                # Get (or lazily create) the per-parameter state
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)
                buf = state['momentum_buffer']

                # Momentum accumulation on the raw gradient
                buf.mul_(momentum).add_(g)
                update = buf

                # Matrix parameters get the Muon update: orthogonalize the
                # momentum buffer. Anything else falls back to momentum SGD.
                if g.ndim == 2 and g.size(0) >= 16 and g.size(1) >= 16:
                    update = newton_schulz_orthogonalize(buf, steps=5)

                # Parameter update
                p.data.add_(update, alpha=-lr)
```
### Newton-Schulz Orthogonalization: `nanochat/muon.py:16`

```python
def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    """
    Approximately orthogonalize a (possibly rectangular) matrix using the
    cubic Newton-Schulz iteration. Converges to the semi-orthogonal factor
    U V^T of the SVD G = U S V^T.
    """
    assert G.ndim == 2
    a, b = G.size()

    # Work in the orientation with rows <= cols so the Gram matrix X X^T is small
    X = G if a <= b else G.t()

    # X_0 = G / ||G||_F puts every singular value in (0, 1],
    # which guarantees the iteration converges
    X = X / (X.norm() + eps)

    # Newton-Schulz iteration: X_{k+1} = X_k (3I - X_k^T X_k) / 2
    # (equivalently 1.5 X - 0.5 (X X^T) X); drives all singular values towards 1
    for _ in range(steps):
        X = 1.5 * X - 0.5 * (X @ X.t()) @ X

    # Undo the transpose if one was applied
    if a > b:
        X = X.t()

    return X
```
**What does this do?**

For a matrix $G$, it finds the orthogonal matrix $Q$ closest to $G$:

$$Q = \arg\min_{\tilde{Q}^T\tilde{Q}=I} \|G - \tilde{Q}\|_F$$

using the iterative formula

$$X_{k+1} = X_k \left(\frac{3I - X_k^TX_k}{2}\right)$$

which converges to $Q = G(G^TG)^{-1/2}$, the orthogonal (polar) factor of $G$.
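To make this concrete, here is a small self-contained check (illustrative only, not nanochat code) that the iteration really does produce an approximately semi-orthogonal matrix and matches the SVD-based reference $UV^T$:

```python
import torch

torch.manual_seed(0)
G = torch.randn(64, 256)  # a rectangular "gradient"

# Muon itself uses only ~5 steps (a rough orthogonalization is enough);
# more steps make this numerical check tighter.
Q = newton_schulz_orthogonalize(G, steps=20)

# Rows of Q should be approximately orthonormal: Q Q^T ≈ I
I = torch.eye(64)
print((Q @ Q.t() - I).abs().max())   # small, and shrinks as steps grows

# Reference answer from the SVD: U V^T = G (G^T G)^{-1/2}
U, S, Vh = torch.linalg.svd(G, full_matrices=False)
print((Q - U @ Vh).abs().max())      # Q approximates the polar factor
```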
**Why orthogonalize?**

- Keeps the update (approximately) orthogonal, i.e. on the Stiefel manifold
- Better-conditioned geometry for optimization
- Prevents exploding or vanishing update magnitudes
- Faster convergence in practice
### Distributed Muon: `nanochat/muon.py:155`

For multi-GPU training:

```python
class DistMuon(Muon):
    def step(self):
        # First, average gradients across all GPUs
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG)

        # Then apply the standard Muon update locally
        super().step()
```

**Key:** All-reducing gradients before the Muon update ensures all ranks stay synchronized.
### Muon Learning Rate Scaling

```python
# From scripts/base_train.py:238
dmodel_lr_scale = (model_dim / 768) ** -0.5
lr_scaled = matrix_lr  # No scaling for Muon (handles it internally)
```

Muon is **scale-invariant**, so no need to scale its LR by model dimension!
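One way to see why: the orthogonalized update has all of its singular values near 1, so its magnitude is set by the matrix shape rather than by how large the raw gradient is. A small illustrative check (not nanochat code), reusing the `newton_schulz_orthogonalize` sketch above:

```python
import torch

torch.manual_seed(0)
G = torch.randn(64, 64)

Q1 = newton_schulz_orthogonalize(G, steps=20)
Q2 = newton_schulz_orthogonalize(100.0 * G, steps=20)

# Scaling the gradient barely changes the update Muon would apply
print((Q1 - Q2).abs().max())  # ~0 (the initial normalization removes the scale)
print(Q1.norm())              # close to sqrt(64) = 8: all singular values ≈ 1
```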
### Momentum Schedule for Muon: `scripts/base_train.py:160`

```python
def get_muon_momentum(it):
    """Warm up momentum from 0.85 to 0.95 over the first 300 iterations"""
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum
```

Start with lower momentum (more responsive), then increase to higher momentum (more stable).
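A few sample values from this schedule, just to make its shape concrete:

```python
for it in (0, 150, 300, 1000):
    print(it, round(get_muon_momentum(it), 3))
# 0 -> 0.85, 150 -> 0.9, 300 -> 0.95, 1000 -> 0.95 (stays at 0.95 after warmup)
```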
## 2. AdamW Optimizer

AdamW is used for the embedding and language model head parameters.

### Standard Adam

Adam combines **momentum** and **adaptive learning rates**:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t \quad \text{(first moment)}$$
$$v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2 \quad \text{(second moment)}$$

Bias correction:

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t}$$

Update:

$$\theta_t = \theta_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
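As a quick sanity check on the bias correction: at $t=1$ we have $m_1 = (1-\beta_1)g_1$, so dividing by $1-\beta_1^1$ recovers $\hat{m}_1 = g_1$; without the correction, early steps would be shrunk by roughly $1-\beta_1$. A minimal single-parameter Adam step following the equations above (illustrative, not nanochat code):

```python
import math

beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 1e-3
theta, m, v, t = 0.0, 0.0, 0.0, 0

g = 2.0                                   # gradient at this step
t += 1
m = beta1 * m + (1 - beta1) * g           # first moment:  0.2
v = beta2 * v + (1 - beta2) * g * g       # second moment: 0.004
m_hat = m / (1 - beta1 ** t)              # bias-corrected: 2.0
v_hat = v / (1 - beta2 ** t)              # bias-corrected: 4.0
theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
print(theta)                              # ≈ -0.001, i.e. a step of size ~lr
```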
### AdamW: Decoupled Weight Decay

**Adam with L2 regularization:**

$$\mathcal{L}' = \mathcal{L} + \frac{\lambda}{2}\|\theta\|^2$$

**Problem:** The decay term is added to the gradient, so it passes through Adam's adaptive rescaling and interacts with the learning rate in unintended ways.

**AdamW solution:** Decouple weight decay from the gradient:

$$\theta_t = (1 - \lambda \eta) \theta_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

**Benefits:**

- Cleaner regularization
- Better generalization
- Less hyperparameter interaction
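The difference is easiest to see numerically. In the toy step below (an illustrative sketch, not the nanochat implementation), the loss gradient happens to be zero: Adam+L2 still takes a full `lr`-sized step because the decay term is renormalized by $\sqrt{\hat{v}_t}$, while AdamW shrinks the weights by exactly `lr * weight_decay`.

```python
import torch

p_l2 = torch.ones(3)      # parameter trained with Adam + L2
p_adamw = torch.ones(3)   # parameter trained with AdamW
lr, beta1, beta2, eps, wd = 0.1, 0.9, 0.999, 1e-8, 0.1
grad = torch.zeros(3)     # suppose the loss gradient is exactly zero this step

# Adam + L2: the decay term is folded into the gradient, so it is
# rescaled by sqrt(v_hat) like everything else
g = grad + wd * p_l2
m = (1 - beta1) * g                      # first step, m_0 = 0
v = (1 - beta2) * g**2                   # first step, v_0 = 0
p_l2 = p_l2 - lr * (m / (1 - beta1)) / ((v / (1 - beta2)).sqrt() + eps)

# AdamW: decay is applied directly to the weights, outside the adaptive update
g = grad
m = (1 - beta1) * g
v = (1 - beta2) * g**2
p_adamw = p_adamw * (1 - lr * wd)
p_adamw = p_adamw - lr * (m / (1 - beta1)) / ((v / (1 - beta2)).sqrt() + eps)

print(p_l2)     # ~0.90: a full lr-sized decay step, regardless of wd's size
print(p_adamw)  # 0.99: weights shrink by exactly lr * wd
```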
### Implementation: `nanochat/adamw.py:53`

```python
import math

import torch


class DistAdamW(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        # First, all-reduce gradients across GPUs
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG)

        # Then apply the AdamW update
        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                # Initialize state
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)     # m_t
                    state['exp_avg_sq'] = torch.zeros_like(p)  # v_t

                state['step'] += 1
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                step = state['step']

                # Update biased first and second moments
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias correction
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step
                step_size = lr / bias_correction1

                # Denominator (with bias correction)
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                # Weight decay (decoupled)
                if weight_decay != 0:
                    p.data.mul_(1 - lr * weight_decay)

                # Update parameters
                p.data.addcdiv_(exp_avg, denom, value=-step_size)
```
### AdamW Hyperparameters in nanochat

```python
# From scripts/base_train.py:228
adam_groups = [
    dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),  # base lr 0.004
    dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),  # base lr 0.2
]

adamw_kwargs = dict(
    betas=(0.8, 0.95),          # instead of the default (0.9, 0.999)
    eps=1e-10,
    weight_decay=weight_decay,  # usually 0.0 for small models
)
```

**Why different betas?**

- $\beta_1 = 0.8$: slightly less momentum (more responsive)
- $\beta_2 = 0.95$: much less variance accumulation (adapts faster)

These values are better tuned for LLM training than the PyTorch defaults.
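For a quick standalone experiment, the same hyperparameters can be passed to PyTorch's stock `torch.optim.AdamW` (nanochat itself uses its own `DistAdamW` under DDP); a hypothetical sketch:

```python
import torch

model = torch.nn.Linear(768, 50304)   # stand-in module, not the real nanochat model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.004,              # unembedding-style base learning rate
    betas=(0.8, 0.95),     # nanochat's betas instead of (0.9, 0.999)
    eps=1e-10,
    weight_decay=0.0,
)
```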
### Learning Rate Scaling by Model Dimension

```python
dmodel_lr_scale = (model_dim / 768) ** -0.5

# Examples:
# model_dim = 1280 → scale = (1280/768)^{-0.5} ≈ 0.77
# model_dim = 384  → scale = (384/768)^{-0.5} ≈ 1.41
```

**Why $\propto 1/\sqrt{d_{model}}$?**

Larger models have larger gradients (sums over more dimensions). Scaling the LR prevents instability.
## 3. Hybrid Optimizer Setup: `nanochat/gpt.py:228`

```python
# (assumes module-level imports: functools.partial, get_dist_info, Muon, DistMuon, DistAdamW)
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
    model_dim = self.config.n_embd
    ddp, rank, _, _ = get_dist_info()

    # Separate parameters into groups
    matrix_params = list(self.transformer.h.parameters())      # all transformer blocks
    embedding_params = list(self.transformer.wte.parameters()) # token embeddings
    lm_head_params = list(self.lm_head.parameters())           # output layer

    # Scale learning rates by model dimension
    dmodel_lr_scale = (model_dim / 768) ** -0.5

    # AdamW for embeddings and LM head
    adam_groups = [
        dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
        dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
    ]
    AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
    adamw_optimizer = AdamWFactory(adam_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)

    # Muon for transformer matrices
    MuonFactory = DistMuon if ddp else Muon
    muon_optimizer = MuonFactory(matrix_params, lr=matrix_lr, momentum=0.95)

    # Return both optimizers
    optimizers = [adamw_optimizer, muon_optimizer]
    return optimizers
```
**Why different learning rates?**

| Parameter | LR | Reasoning |
|-----------|-----|-----------|
| Embeddings | 0.2 | Sparse updates, can handle a high LR |
| LM head | 0.004 | Dense gradients, needs a lower LR |
| Matrices | 0.02 | Muon handles the geometry, moderate LR |
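The schedule code in the next subsection rescales each param group's `lr` from a stored `initial_lr`, so the training script has to record the starting value once after building the optimizers. A minimal sketch of that wiring (assumed here, inferred from the scheduler code below):

```python
optimizers = model.setup_optimizers(unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02)
adamw_optimizer, muon_optimizer = optimizers

# Remember each group's starting LR so the schedule can rescale it every step
for opt in optimizers:
    for group in opt.param_groups:
        group["initial_lr"] = group["lr"]
```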
### Stepping Multiple Optimizers: `scripts/base_train.py:269`

```python
# Update learning rates for all optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["initial_lr"] * lrm

# Update Muon momentum
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
    group["momentum"] = muon_momentum

# Step all optimizers
for opt in optimizers:
    opt.step()

# Clear gradients
model.zero_grad(set_to_none=True)
```

**Important:** `set_to_none=True` saves memory compared to zeroing the gradient tensors.
## 4. Gradient Clipping

Prevents exploding gradients during training.

### Global Norm Clipping: `scripts/base_train.py:265`

```python
if grad_clip > 0.0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
```

**How it works:**

1. Compute the global gradient norm:

$$\|\mathbf{g}\|_{global} = \sqrt{\sum_{\theta \in \Theta} \|\nabla_\theta \mathcal{L}\|^2}$$

2. If it is too large, scale all gradients:

$$\mathbf{g}_\theta \leftarrow \frac{\text{max\_norm}}{\|\mathbf{g}\|_{global}} \mathbf{g}_\theta$$

**Effect:** Limits the maximum gradient magnitude without changing its direction.
### Implementation Details

A simplified version of what `torch.nn.utils.clip_grad_norm_` does internally:

```python
import torch


def clip_grad_norm_(parameters, max_norm, norm_type=2):
    parameters = [p for p in parameters if p.grad is not None]

    # Compute the total norm over all gradients
    total_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]),
        norm_type
    )

    # Compute the clipping coefficient
    clip_coef = max_norm / (total_norm + 1e-6)

    # Clip only if the total norm exceeds max_norm
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef)

    return total_norm
```
## 5. Warmup and Decay Schedules

### Why Warmup?

At initialization:

- Weights are random
- Gradients can be very large
- Adam's second-moment estimate is inaccurate

**Solution:** Start with a low LR and gradually increase it.

### Why Decay?

Near the end of training:

- The model is close to an optimum
- Small steps refine the solution
- Prevents oscillation

**Solution:** Gradually decrease the LR to 0.
### Schedule Implementation: `scripts/base_train.py:148`

```python
warmup_ratio = 0.0    # Skip warmup for simplicity
warmdown_ratio = 0.2  # Last 20% of training
final_lr_frac = 0.0   # Decay to 0

def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)

    if it < warmup_iters:
        # Linear warmup
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        # Constant LR
        return 1.0
    else:
        # Linear warmdown
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac
```
**Alternative schedules:**

- Cosine decay: smoother than linear (see the sketch below)
- Exponential decay: aggressive reduction
- Step decay: discrete jumps
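For comparison, a cosine variant of the same multiplier function might look like this (illustrative only; nanochat uses the linear warmup/warmdown schedule above):

```python
import math

def get_lr_multiplier_cosine(it, num_iterations, warmup_iters=0, final_lr_frac=0.0):
    # Linear warmup, then cosine decay from 1.0 down to final_lr_frac
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    progress = (it - warmup_iters) / max(1, num_iterations - warmup_iters)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))  # goes from 1 to 0
    return final_lr_frac + (1 - final_lr_frac) * cosine
```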
## 6. Optimization Best Practices

### Learning Rate Tuning

**Too high:**

- Training is unstable
- Loss oscillates or diverges
- NaN loss

**Too low:**

- Training is very slow
- Gets stuck in local minima
- Underfits

**Good LR:**

- Steady loss decrease
- Occasional small oscillations
- Converges smoothly
### Finding Good LR: Learning Rate Range Test

```python
# Start with a very low LR and increase it multiplicatively each step
lrs = []
losses = []

lr = 1e-8
for step in range(1000):
    loss = train_step(lr)   # placeholder: one training step at this LR
    lrs.append(lr)
    losses.append(loss)
    lr *= 1.02              # ~2% per step sweeps roughly 1e-8 up to ~4 over 1000 steps

# Plot losses vs. LRs (log x-axis);
# a good LR is where the loss decreases fastest
```
### Batch Size Effects

**Larger batch size:**

- More stable gradients
- Better GPU utilization
- Can use a higher LR
- Slower wall-clock time per iteration
- May generalize worse

**Smaller batch size:**

- Noisier gradients (implicit regularization)
- Less GPU-efficient
- Needs a lower LR
- Faster iterations

**nanochat choice:** 524K tokens per batch (very large, for stability)
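As a concrete illustration of how a batch that size is assembled across devices (the specific per-device batch size, sequence length, and GPU count below are assumed for the example, not quoted from the config):

```python
# One plausible decomposition of a 524,288-token batch (illustrative numbers)
rows_per_gpu = 32       # sequences processed per device per step
sequence_len = 2048     # tokens per sequence
num_gpus = 8            # data-parallel ranks

tokens_per_step = rows_per_gpu * sequence_len * num_gpus
print(tokens_per_step)  # 524288 = 2**19
```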
## 7. Comparison: Different Optimization Strategies

| Strategy | Training Speed | Final Loss | Complexity |
|----------|----------------|------------|------------|
| SGD | Slow | Good | Simple |
| Adam | Fast | Good | Medium |
| AdamW | Fast | Better | Medium |
| Muon (matrices only) | Very Fast | Best | High |
| **Hybrid (AdamW + Muon)** | **Very Fast** | **Best** | **High** |

nanochat's hybrid approach is cutting-edge!
## 8. Memory Optimization

### Gradient Checkpointing (Not used in nanochat)

Trades compute for memory:

- Don't store all intermediate activations
- Recompute them during the backward pass
- Costs extra compute (roughly one additional forward pass) in exchange for much lower activation memory

### Optimizer State Management

AdamW stores:

- First moment (m): same size as the parameters
- Second moment (v): same size as the parameters

**Memory:** two extra tensors the size of the parameters (kept in FP32)
For a 270M-parameter model:

- Parameters: 270M × 2 bytes (BF16) = 540 MB
- AdamW states: 270M × 8 bytes (two FP32 tensors) = 2.16 GB
- Total: ~2.7 GB
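A tiny helper reproducing this arithmetic for any parameter count (illustrative; assumes BF16 parameters and two FP32 moment buffers, as above):

```python
def optimizer_memory_gb(num_params, param_bytes=2, adam_state_bytes=8):
    """Rough parameter + AdamW-state memory: BF16 weights, FP32 m and v."""
    params_gb = num_params * param_bytes / 1e9
    states_gb = num_params * adam_state_bytes / 1e9
    return params_gb, states_gb, params_gb + states_gb

print(optimizer_memory_gb(270e6))  # (0.54, 2.16, 2.7)
```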
### Fused Optimizers

```python
torch.optim.AdamW(..., fused=True)  # uses a fused CUDA kernel
```

**Benefits:**

- Faster updates (one fused kernel instead of many small per-tensor ops)
- Less memory traffic
- ~10-20% faster optimizer step
## Next Steps

We've covered optimization! Next, we'll explore **Implementation Details** - practical coding techniques used throughout nanochat.