mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-11 18:30:27 +00:00
try and fail relu^2 -> swiglu
This commit is contained in:
parent
d63b7ab9ac
commit
1144d186ed
19
dev/LOG.md
19
dev/LOG.md
|
|
@ -6,11 +6,24 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
|
|||
|
||||
## 2026-02-05: SwiGLU Activation (Negative Result)
|
||||
|
||||
Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). Implementation uses three projections (w1, w2, w3) with hidden_dim scaled to 8/3×n_embd to preserve both parameter count and FLOPs exactly (1.00x match on both).
|
||||
Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). SwiGLU uses three projections instead of two, so to match parameters and FLOPs we scale hidden_dim from 4× to 8/3×:
|
||||
|
||||
```python
|
||||
# Old: x = c_proj(relu(c_fc(x)).square())
|
||||
# New: x = w3(silu(w1(x)) * w2(x))
|
||||
# Old ReLU²: 2 matrices, 4x expansion
|
||||
# params: 2 × n × 4n = 8n²
|
||||
# flops: 2 × 2n × 4n = 16n² per token
|
||||
self.c_fc = Linear(n_embd, 4 * n_embd)
|
||||
self.c_proj = Linear(4 * n_embd, n_embd)
|
||||
x = c_proj(relu(c_fc(x)).square())
|
||||
|
||||
# New SwiGLU: 3 matrices, 8/3x expansion
|
||||
# params: 2 × n × (8n/3) + (8n/3) × n = 8n² ✓ matches
|
||||
# flops: 3 × 2n × (8n/3) = 16n² per token ✓ matches
|
||||
hidden_dim = (8 * n_embd) // 3
|
||||
self.w1 = Linear(n_embd, hidden_dim) # gate
|
||||
self.w2 = Linear(n_embd, hidden_dim) # up
|
||||
self.w3 = Linear(hidden_dim, n_embd) # down
|
||||
x = w3(silu(w1(x)) * w2(x))
|
||||
```
|
||||
|
||||
Tested at both d12 and d24 (GPT-2 scale). Worse on all measures — step efficiency, wall clock time, and FLOPs. ReLU² remains superior for nanochat. **Not adopted.**
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user