Advanced Optimization Techniques

nanochat uses a hybrid optimization strategy: combining Muon for matrix parameters and AdamW for embeddings. This is more sophisticated than standard approaches.

Why Different Optimizers?

Different parameter types have different optimization needs:

| Parameter Type | Examples | Characteristics | Best Optimizer |
|----------------|----------|-----------------|----------------|
| Matrices | Attention, MLP | Dense, high-dimensional | Muon |
| Embeddings | Token embeddings | Sparse updates, embedding-specific | AdamW |
| Vectors | LM head | Output layer, sparse | AdamW |

Traditional approach: Use AdamW for everything
nanochat approach: Use Muon for matrices, AdamW for embeddings/head

Result: Faster training, better convergence

1. Muon Optimizer

Muon is a novel optimizer designed specifically for matrix parameters in neural networks.

Core Idea

Standard optimizers (SGD, Adam) treat matrices as flat vectors:

Matrix [3×4] → Flatten to vector [12] → Update

Muon exploits matrix structure:

Matrix [3×4] → Update using matrix operations → Keep matrix shape

Mathematical Formulation

For weight matrix W \in \mathbb{R}^{m \times n}:

Standard momentum:

v_t = \beta v_{t-1} + (1-\beta) g_t
W_t = W_{t-1} - \eta v_t

Muon:

  1. Compute gradient G_t = \nabla_W \mathcal{L}
  2. Orthogonalize using Newton-Schulz iteration
  3. Apply momentum in tangent space
  4. Update with adaptive step size

Implementation: nanochat/muon.py:53

class Muon(torch.optim.Optimizer):
    def __init__(self, params, lr=0.02, momentum=0.95):
        defaults = dict(lr=lr, momentum=momentum)
        super(Muon, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                g = p.grad  # Gradient

                # Get state
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)

                buf = state['momentum_buffer']

                # Handle matrix vs non-matrix parameters
                if g.ndim == 2 and g.size(0) >= 16 and g.size(1) >= 16:
                    # Matrix parameter: use Muon update
                    g = newton_schulz_orthogonalize(g, steps=5)

                # Momentum update
                buf.mul_(momentum).add_(g)

                # Parameter update
                p.data.add_(buf, alpha=-lr)

Newton-Schulz Orthogonalization: nanochat/muon.py:16

def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    """
    Approximately orthogonalize a (possibly rectangular) gradient matrix
    using the Newton-Schulz iteration.
    """
    # X_0 = G / ||G||_F  (brings every singular value into (0, 1])
    A = G / (G.norm() + eps)

    # Newton-Schulz iteration: X_{k+1} = X_k * (3I - X_k^T X_k) / 2
    # Each singular value sigma is pushed toward 1 via sigma -> sigma * (3 - sigma^2) / 2,
    # so the iteration works directly on rectangular matrices -- no cropping or
    # zero-padding needed (which would discard part of the gradient).
    n = G.size(1)
    I = torch.eye(n, device=A.device, dtype=A.dtype)
    for _ in range(steps):
        A = A @ (1.5 * I - 0.5 * (A.t() @ A))

    return A

What does this do?

For a matrix G, find orthogonal matrix Q closest to G:

Q = \arg\min_{\tilde{Q}^T\tilde{Q}=I} \|G - \tilde{Q}\|_F

Uses iterative formula:

X_{k+1} = X_k \left(\frac{3I - X_k^TX_k}{2}\right)

Converges to Q = G(G^TG)^{-1/2}, the orthogonal (polar) factor of G.
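A quick sanity check (a standalone sketch, not part of nanochat): run the cubic iteration above on a random square matrix with extra steps and compare the result to the polar factor U V^T obtained from an SVD.

import torch

# Sanity-check sketch: the cubic Newton-Schulz iteration shown above should
# approach the orthogonal polar factor Q = U V^T of G.
torch.manual_seed(0)
G = torch.randn(8, 8, dtype=torch.float64)

# Reference polar factor via SVD: G = U diag(S) V^T  =>  Q = U V^T
U, S, Vh = torch.linalg.svd(G)
Q_ref = U @ Vh

# The iteration from the listing, with extra steps so it fully converges
A = G / (G.norm() + 1e-7)
I = torch.eye(8, dtype=torch.float64)
for _ in range(25):
    A = A @ (1.5 * I - 0.5 * (A.t() @ A))

print((A - Q_ref).abs().max().item())      # tiny: A matches the polar factor
print((A.t() @ A - I).abs().max().item())  # tiny: A is (numerically) orthogonal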

Why orthogonalize?

  • Keeps gradients on Stiefel manifold
  • Better geometry for optimization
  • Prevents gradient explosion/vanishing
  • Faster convergence

Distributed Muon: nanochat/muon.py:155

For multi-GPU training:

class DistMuon(Muon):
    def step(self):
        # First, average gradients across all GPUs
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG)

        # Then apply standard Muon update
        super().step()

Key: All-reduce gradients before Muon update ensures synchronization.

Muon Learning Rate Scaling

# From scripts/base_train.py:238
dmodel_lr_scale = (model_dim / 768) ** -0.5
lr_scaled = matrix_lr  # No scaling for Muon (handles it internally)

Muon is scale-invariant, so no need to scale LR by model dimension!
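One way to see part of this (a standalone illustration, not nanochat code; it re-implements the iteration above with more steps): the orthogonalized direction is unchanged when the gradient is rescaled, so the step size is set by the learning rate rather than by the raw gradient magnitude.

import torch

# Illustration: the orthogonalized update does not depend on the overall
# scale of the gradient.
def orthogonalize(G, steps=25, eps=1e-7):
    A = G / (G.norm() + eps)
    I = torch.eye(G.size(1), dtype=G.dtype)
    for _ in range(steps):
        A = A @ (1.5 * I - 0.5 * (A.t() @ A))
    return A

torch.manual_seed(0)
G = torch.randn(32, 32, dtype=torch.float64)
print((orthogonalize(G) - orthogonalize(100.0 * G)).abs().max().item())  # essentially zero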

Momentum Schedule for Muon: scripts/base_train.py:160

def get_muon_momentum(it):
    """Warmup momentum from 0.85 to 0.95"""
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum

Start with lower momentum (more responsive), increase to higher momentum (more stable).
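A quick check of the values this schedule produces (using the get_muon_momentum function defined above):

# Using get_muon_momentum as defined above
for it in (0, 150, 300, 1000):
    print(it, round(get_muon_momentum(it), 3))
# 0 -> 0.85, 150 -> 0.9, 300 -> 0.95, 1000 -> 0.95 (stays at 0.95 after warmup)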

2. AdamW Optimizer

AdamW is used for embedding and language model head parameters.

Standard Adam

Combines momentum and adaptive learning rates:

m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t \quad \text{(first moment)}
v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2 \quad \text{(second moment)}

Bias correction:

\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t}

Update:

\theta_t = \theta_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}

AdamW: Decoupled Weight Decay

Adam with L2 regularization:

\mathcal{L}' = \mathcal{L} + \frac{\lambda}{2}\|\theta\|^2

Problem: Weight decay interacts with adaptive learning rate in weird ways.

AdamW solution: Decouple weight decay from gradient:

\theta_t = (1 - \lambda \eta) \theta_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}

Benefits:

  • Cleaner regularization
  • Better generalization
  • Less hyperparameter interaction
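To see why the decoupling matters, here is a toy contrast (not from nanochat, and deliberately simplified to isolate the decay term): with L2-in-the-loss, the decay passes through the adaptive denominator, so parameters with a large gradient history get regularized less; AdamW shrinks every parameter at the same rate.

import torch

# Toy illustration (data gradient set to zero to isolate the decay term)
lr, wd, eps = 0.01, 0.1, 1e-8
theta = torch.tensor([1.0, 1.0])
v_hat = torch.tensor([100.0, 0.01])  # one large, one tiny gradient history
m_hat = torch.zeros(2)               # no data gradient

# Adam + L2: the decay term wd*theta is divided by the adaptive denominator
adam_l2 = theta - lr * (m_hat + wd * theta) / (v_hat.sqrt() + eps)

# AdamW: the decay is applied directly to the weights
adamw = theta * (1 - lr * wd) - lr * m_hat / (v_hat.sqrt() + eps)

print(adam_l2)  # tensor([0.9999, 0.9900]) -- shrinkage depends on gradient history
print(adamw)    # tensor([0.9990, 0.9990]) -- uniform shrinkage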

Implementation: nanochat/adamw.py:53

class DistAdamW(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        # First, all-reduce gradients across GPUs
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG)

        # Then apply AdamW update
        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                # Initialize state
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)       # m_t
                    state['exp_avg_sq'] = torch.zeros_like(p)    # v_t

                state['step'] += 1
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                step = state['step']

                # Update biased first and second moments
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias correction
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step
                step_size = lr / bias_correction1

                # Compute denominator (with bias correction)
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                # Weight decay (decoupled)
                if weight_decay != 0:
                    p.data.mul_(1 - lr * weight_decay)

                # Update parameters
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

AdamW Hyperparameters in nanochat

# From scripts/base_train.py:228
adam_groups = [
    dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),  # 0.004
    dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),   # 0.2
]

adamw_kwargs = dict(
    betas=(0.8, 0.95),  # Instead of default (0.9, 0.999)
    eps=1e-10,
    weight_decay=weight_decay  # Usually 0.0 for small models
)

Why different betas?

  • \beta_1 = 0.8: Slightly less momentum (more responsive)
  • \beta_2 = 0.95: Much less variance accumulation (adapts faster)

This is better tuned for LLM training than defaults.

Learning Rate Scaling by Model Dimension

dmodel_lr_scale = (model_dim / 768) ** -0.5

# Example:
# model_dim = 1280 → scale = (1280/768)^{-0.5} ≈ 0.77
# model_dim = 384  → scale = (384/768)^{-0.5} ≈ 1.41

Why \propto 1/\sqrt{d_{model}}?

Larger models have larger gradients (sum over more dimensions). Scaling LR prevents instability.

3. Hybrid Optimizer Setup: nanochat/gpt.py:228

def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
    model_dim = self.config.n_embd
    ddp, rank, _, _ = get_dist_info()

    # Separate parameters into groups
    matrix_params = list(self.transformer.h.parameters())        # All transformer blocks
    embedding_params = list(self.transformer.wte.parameters())   # Token embeddings
    lm_head_params = list(self.lm_head.parameters())            # Output layer

    # Scale learning rates
    dmodel_lr_scale = (model_dim / 768) ** -0.5

    # AdamW for embeddings and LM head
    adam_groups = [
        dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
        dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
    ]
    AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
    adamw_optimizer = AdamWFactory(adam_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)

    # Muon for transformer matrices
    MuonFactory = DistMuon if ddp else Muon
    muon_optimizer = MuonFactory(matrix_params, lr=matrix_lr, momentum=0.95)

    # Return both optimizers
    optimizers = [adamw_optimizer, muon_optimizer]
    return optimizers

Why different learning rates?

| Parameter | LR | Reasoning |
|-----------|----|-----------|
| Embeddings | 0.2 | Sparse updates, can handle high LR |
| LM head | 0.004 | Dense gradients, needs lower LR |
| Matrices | 0.02 | Muon handles geometry, moderate LR |

Stepping Multiple Optimizers: scripts/base_train.py:269

# Update learning rates for all optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["initial_lr"] * lrm

# Update Muon momentum
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
    group["momentum"] = muon_momentum

# Step all optimizers
for opt in optimizers:
    opt.step()

# Clear gradients
model.zero_grad(set_to_none=True)

Important: set_to_none=True saves memory compared to zeroing.

4. Gradient Clipping

Prevents exploding gradients during training.

Global Norm Clipping: scripts/base_train.py:265

if grad_clip > 0.0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

How it works:

  1. Compute the global gradient norm:
\|\mathbf{g}\|_{global} = \sqrt{\sum_{\theta \in \Theta} \|\nabla_\theta \mathcal{L}\|^2}
  2. If it exceeds max_norm, scale all gradients:
\mathbf{g}_\theta \leftarrow \frac{\text{max\_norm}}{\|\mathbf{g}\|_{global}} \mathbf{g}_\theta

Effect: Limits maximum gradient magnitude without changing direction.
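A tiny numeric check of this behavior with torch.nn.utils.clip_grad_norm_ (illustrative only):

import torch

# Two parameters with manually set gradients; global norm = sqrt(3^2 + 4^2) = 5
p1 = torch.nn.Parameter(torch.zeros(3)); p1.grad = torch.tensor([3.0, 0.0, 0.0])
p2 = torch.nn.Parameter(torch.zeros(3)); p2.grad = torch.tensor([0.0, 4.0, 0.0])

total = torch.nn.utils.clip_grad_norm_([p1, p2], max_norm=1.0)
print(total)             # tensor(5.) -- the norm before clipping is returned
print(p1.grad, p2.grad)  # gradients scaled by ~1/5, direction unchanged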

Implementation Details

def clip_grad_norm_(parameters, max_norm, norm_type=2):
    parameters = list(filter(lambda p: p.grad is not None, parameters))

    # Compute total norm
    total_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]),
        norm_type
    )

    # Compute clipping coefficient
    clip_coef = max_norm / (total_norm + 1e-6)

    # Clip if necessary
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef)

    return total_norm

5. Warmup and Decay Schedules

Why Warmup?

At initialization:

  • Weights are random
  • Gradients can be very large
  • Adam's second moment estimate is inaccurate

Solution: Start with low LR, gradually increase.

Why Decay?

Near end of training:

  • Model is close to optimum
  • Small steps refine solution
  • Prevents oscillation

Solution: Gradually decrease LR to 0.

Schedule Implementation: scripts/base_train.py:148

warmup_ratio = 0.0      # Skip warmup for simplicity
warmdown_ratio = 0.2    # Last 20% of training
final_lr_frac = 0.0     # Decay to 0

def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)

    if it < warmup_iters:
        # Linear warmup
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        # Constant LR
        return 1.0
    else:
        # Linear warmdown
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac

Alternative schedules:

  • Cosine decay: Smoother than linear
  • Exponential decay: Aggressive reduction
  • Step decay: Discrete jumps
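For example, a cosine-decay variant could look like the following sketch (an alternative for comparison, not what nanochat's get_lr_multiplier does):

import math

def get_lr_multiplier_cosine(it, num_iterations, warmup_ratio=0.0, final_lr_frac=0.0):
    """Linear warmup followed by cosine decay down to final_lr_frac."""
    warmup_iters = round(warmup_ratio * num_iterations)
    if warmup_iters > 0 and it < warmup_iters:
        return (it + 1) / warmup_iters
    progress = (it - warmup_iters) / max(1, num_iterations - warmup_iters)
    cosine = 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))
    return final_lr_frac + (1 - final_lr_frac) * cosine

# e.g. with num_iterations=1000: step 0 -> 1.0, step 500 -> 0.5, step 1000 -> 0.0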

6. Optimization Best Practices

Learning Rate Tuning

Too high:

  • Training unstable
  • Loss oscillates or diverges
  • NaN loss

Too low:

  • Training very slow
  • Gets stuck in local minima
  • Underfits

Good LR:

  • Steady loss decrease
  • Occasional small oscillations
  • Converges smoothly

Finding Good LR: Learning Rate Range Test

# Start with a very low LR and increase it geometrically each step.
# train_step is a placeholder for one forward/backward/update pass at the given LR.
lrs = []
losses = []

lr = 1e-8
for step in range(1000):
    loss = train_step(lr)  # placeholder: run one training step with this LR
    lrs.append(lr)
    losses.append(loss)
    lr *= 1.01  # increase by 1% per step (geometric sweep)

# Plot loss vs. LR (log-scale LR axis)
# A good LR is roughly where the loss decreases fastest, just before it blows up

Batch Size Effects

Larger batch size:

  • More stable gradients
  • Better GPU utilization
  • Can use higher LR
  • Slower wall-clock time per iteration
  • May generalize worse

Smaller batch size:

  • Noisier gradients (implicit regularization)
  • Less GPU efficient
  • Lower LR needed
  • Faster iterations

nanochat choice: 524K tokens/batch (very large for stability)

7. Comparison: Different Optimization Strategies

| Strategy | Training Speed | Final Loss | Complexity |
|----------|----------------|------------|------------|
| SGD | Slow | Good | Simple |
| Adam | Fast | Good | Medium |
| AdamW | Fast | Better | Medium |
| Muon (matrices only) | Very Fast | Best | High |
| Hybrid (AdamW + Muon) | Very Fast | Best | High |

nanochat's hybrid approach is cutting-edge!

8. Memory Optimization

Gradient Checkpointing (Not used in nanochat)

Trade compute for memory:

  • Don't store intermediate activations
  • Recompute during backward pass
  • Roughly one extra forward pass of compute per step, in exchange for a large (often ~10×) reduction in activation memory
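If one wanted to add it, a minimal sketch using torch.utils.checkpoint might look like this (not part of nanochat):

import torch
from torch.utils.checkpoint import checkpoint

# Sketch only: wrap an existing block so its intermediate activations are
# recomputed during the backward pass instead of being stored during forward.
class CheckpointedBlock(torch.nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        # use_reentrant=False is the recommended mode in recent PyTorch versions
        return checkpoint(self.block, x, use_reentrant=False)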

Optimizer State Management

AdamW stores:

  • First moment (m): same size as parameters
  • Second moment (v): same size as parameters

Memory: two extra buffers the same shape as the parameters, kept in FP32 (8 bytes per parameter)

For 270M param model:

  • Parameters: 270M × 2 bytes (BF16) = 540 MB
  • AdamW states: 270M × 8 bytes (FP32) = 2.16 GB
  • Total: ~2.7 GB
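The same back-of-the-envelope arithmetic as a snippet (assuming BF16 weights and FP32 optimizer states, as above):

# Rough optimizer-memory estimate (BF16 weights, two FP32 Adam buffers)
n_params = 270e6
weights_gb     = n_params * 2 / 1e9      # BF16: 2 bytes per parameter    -> ~0.54 GB
adam_states_gb = n_params * 2 * 4 / 1e9  # m and v in FP32: 8 bytes/param -> ~2.16 GB
print(weights_gb, adam_states_gb, weights_gb + adam_states_gb)  # ~0.54, 2.16, 2.7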

Fused Optimizers

torch.optim.AdamW(..., fused=True)  # Uses a fused CUDA kernel

Benefits:

  • Faster updates (single kernel launch)
  • Less memory traffic
  • ~10-20% speedup

Next Steps

We've covered optimization! Next, we'll explore Implementation Details - practical coding techniques used throughout nanochat.