From 1144d186ed4bd7ea949bddca03612922402ab198 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Thu, 5 Feb 2026 02:42:46 +0000
Subject: [PATCH] try and fail relu^2 -> swiglu

---
 dev/LOG.md | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/dev/LOG.md b/dev/LOG.md
index b344b238..02561ac7 100644
--- a/dev/LOG.md
+++ b/dev/LOG.md
@@ -6,11 +6,24 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
 
 ## 2026-02-05: SwiGLU Activation (Negative Result)
 
-Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). Implementation uses three projections (w1, w2, w3) with hidden_dim scaled to 8/3×n_embd to preserve both parameter count and FLOPs exactly (1.00x match on both).
+Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). SwiGLU uses three projections instead of two, so to match parameters and FLOPs we scale hidden_dim from 4× to 8/3×:
 
 ```python
-# Old: x = c_proj(relu(c_fc(x)).square())
-# New: x = w3(silu(w1(x)) * w2(x))
+# Old ReLU²: 2 matrices, 4x expansion
+# params: 2 × n × 4n = 8n²
+# flops: 2 × 2n × 4n = 16n² per token
+self.c_fc = Linear(n_embd, 4 * n_embd)
+self.c_proj = Linear(4 * n_embd, n_embd)
+x = c_proj(relu(c_fc(x)).square())
+
+# New SwiGLU: 3 matrices, 8/3x expansion
+# params: 2 × n × (8n/3) + (8n/3) × n = 8n² ✓ matches
+# flops: 3 × 2n × (8n/3) = 16n² per token ✓ matches
+hidden_dim = (8 * n_embd) // 3
+self.w1 = Linear(n_embd, hidden_dim) # gate
+self.w2 = Linear(n_embd, hidden_dim) # up
+self.w3 = Linear(hidden_dim, n_embd) # down
+x = w3(silu(w1(x)) * w2(x))
 ```
 
 Tested at both d12 and d24 (GPT-2 scale). Worse on all measures — step efficiency, wall clock time, and FLOPs. ReLU² remains superior for nanochat. **Not adopted.**
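
---

Editor's note: for context, here is a self-contained sketch of the two MLP variants diffed above, assuming plain bias-free `torch.nn.Linear` in place of nanochat's own `Linear` (an assumption; the repo's module and init details may differ). It checks the 8n² parameter match empirically:

```python
# Minimal, hedged sketch of the two MLP variants compared in the patch.
# Assumes bias-free torch.nn.Linear; nanochat's actual Linear may differ.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReluSquaredMLP(nn.Module):
    # Old: 2 matrices, 4x expansion -> 8n² params
    def __init__(self, n_embd: int):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=False)
        self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=False)

    def forward(self, x):
        return self.c_proj(F.relu(self.c_fc(x)).square())

class SwiGLUMLP(nn.Module):
    # New: 3 matrices, 8/3x expansion -> 8n² params
    def __init__(self, n_embd: int):
        super().__init__()
        hidden_dim = (8 * n_embd) // 3
        self.w1 = nn.Linear(n_embd, hidden_dim, bias=False)  # gate
        self.w2 = nn.Linear(n_embd, hidden_dim, bias=False)  # up
        self.w3 = nn.Linear(hidden_dim, n_embd, bias=False)  # down

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

if __name__ == "__main__":
    n_embd = 768  # assumption: GPT-2 small width for d12
    a, b = ReluSquaredMLP(n_embd), SwiGLUMLP(n_embd)
    count = lambda m: sum(p.numel() for p in m.parameters())
    print(count(a), count(b))  # both 4718592 = 8 * 768²
    x = torch.randn(2, 16, n_embd)
    assert a(x).shape == b(x).shape == x.shape
```

Since 768 is divisible by 3, `(8 * n_embd) // 3` rounds to exactly 2048 here, which is where the deleted line's "1.00x match on both" claim comes from; widths not divisible by 3 would round the hidden dim down slightly.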