wire up fused qkv and glu toggles

This commit is contained in:
Vilhelm Toivonen 2025-10-13 22:05:53 +03:00
parent 5fd0b13886
commit 47f7ffa25d
No known key found for this signature in database
GPG Key ID: 587AD4B7CF588708
3 changed files with 184 additions and 12 deletions

View File

@ -4,16 +4,18 @@ Notable features:
- rotary embeddings (and no positional embeddings)
- QK norm
- untied weights for token embedding and lm_head
- relu^2 activation in MLP
- relu^2 / gated MLPs with width scaling
- norm after token embedding
- no learnable params in rmsnorm
- no bias in linear layers
- Multi-Query Attention (MQA) support for more efficient inference
- Optional fused QKV projection for fewer matmuls
"""
import math
from functools import partial
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
@ -31,6 +33,10 @@ class GPTConfig:
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (MQA)
n_embd: int = 768
use_fused_qkv: bool = True
mlp_type: str = "relu2" # choices: relu2, swiglu, geglu
mlp_width_mult: float = 4.0
mlp_glu_width_mult: Optional[float] = None
def norm(x):
@ -69,20 +75,31 @@ class CausalSelfAttention(nn.Module):
self.n_kv_head = config.n_kv_head
self.n_embd = config.n_embd
self.head_dim = self.n_embd // self.n_head
self.total_qkv_heads = self.n_head + 2 * self.n_kv_head
assert self.n_embd % self.n_head == 0
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.use_fused_qkv = config.use_fused_qkv
if self.use_fused_qkv:
# Single projection keeps matmul count minimal; split happens in forward pass.
self.c_attn = nn.Linear(self.n_embd, self.total_qkv_heads * self.head_dim, bias=False)
else:
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
if self.use_fused_qkv:
qkv = self.c_attn(x)
qkv = qkv.view(B, T, self.total_qkv_heads, self.head_dim)
q, k, v = torch.split(qkv, [self.n_head, self.n_kv_head, self.n_kv_head], dim=2)
else:
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
@ -125,16 +142,63 @@ class CausalSelfAttention(nn.Module):
y = self.c_proj(y)
return y
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
fused_key = prefix + "c_attn.weight"
q_key = prefix + "c_q.weight"
k_key = prefix + "c_k.weight"
v_key = prefix + "c_v.weight"
has_fused = fused_key in state_dict
has_split = all(key in state_dict for key in (q_key, k_key, v_key))
if self.use_fused_qkv and not has_fused and has_split:
# Convert legacy split weights into the fused layout expected by this module.
q = state_dict.pop(q_key)
k = state_dict.pop(k_key)
v = state_dict.pop(v_key)
state_dict[fused_key] = torch.cat([q, k, v], dim=0)
elif not self.use_fused_qkv and has_fused and not has_split:
fused = state_dict.pop(fused_key)
q_out = self.n_head * self.head_dim
kv_out = self.n_kv_head * self.head_dim
state_dict[q_key] = fused[:q_out]
state_dict[k_key] = fused[q_out:q_out + kv_out]
state_dict[v_key] = fused[q_out + kv_out:]
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
self.mlp_type = config.mlp_type.lower()
base_hidden = max(1, int(config.mlp_width_mult * config.n_embd))
if self.mlp_type in {"swiglu", "geglu"}:
width_mult = config.mlp_glu_width_mult
if width_mult is None:
# 2/3 factor keeps parameter count aligned with the ReLU^2 baseline.
width_mult = config.mlp_width_mult * (2.0 / 3.0)
hidden_dim = max(1, int(width_mult * config.n_embd))
self.c_fc = nn.Linear(config.n_embd, 2 * hidden_dim, bias=False)
self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
self.hidden_dim = hidden_dim
else:
hidden_dim = base_hidden
self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
self.hidden_dim = hidden_dim
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
if self.mlp_type == "swiglu":
x1, x2 = torch.chunk(x, 2, dim=-1)
x = F.silu(x1) * x2
elif self.mlp_type == "geglu":
x1, x2 = torch.chunk(x, 2, dim=-1)
x = F.gelu(x1, approximate="tanh") * x2
else:
x = F.relu(x).square()
x = self.c_proj(x)
return x
@ -169,8 +233,9 @@ class GPT(nn.Module):
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
self.transformer.wte.to(dtype=torch.bfloat16)
# Cast embeddings to bf16 only when CUDA is available to keep CPU inference/tests simple.
if torch.cuda.is_available():
self.transformer.wte.to(dtype=torch.bfloat16)
def init_weights(self):
self.apply(self._init_weights)

View File

@ -0,0 +1,59 @@
import torch
from dataclasses import replace
from nanochat.gpt import GPT, GPTConfig
def _tiny_config(**kwargs):
base = GPTConfig(
sequence_len=16,
vocab_size=128,
n_layer=1,
n_head=4,
n_kv_head=2,
n_embd=64,
use_fused_qkv=False,
mlp_type="relu2",
)
return replace(base, **kwargs)
def test_fused_qkv_matches_legacy_split_projection():
torch.manual_seed(0)
split_cfg = _tiny_config(use_fused_qkv=False)
fused_cfg = replace(split_cfg, use_fused_qkv=True)
split_model = GPT(split_cfg)
split_model.init_weights()
state = split_model.state_dict()
fused_model = GPT(fused_cfg)
fused_model.init_weights()
fused_model.load_state_dict(state, strict=True)
tokens = torch.randint(0, split_cfg.vocab_size, (2, 5))
with torch.no_grad():
logits_split = split_model(tokens)
logits_fused = fused_model(tokens)
assert torch.allclose(logits_split, logits_fused, atol=1e-5)
def test_split_loads_from_fused_state_dict():
torch.manual_seed(1)
fused_cfg = _tiny_config(use_fused_qkv=True)
fused_model = GPT(fused_cfg)
fused_model.init_weights()
state = fused_model.state_dict()
split_cfg = replace(fused_cfg, use_fused_qkv=False)
split_model = GPT(split_cfg)
split_model.init_weights()
split_model.load_state_dict(state, strict=True)
tokens = torch.randint(0, split_cfg.vocab_size, (1, 7))
with torch.no_grad():
logits_split = split_model(tokens)
logits_fused = fused_model(tokens)
assert torch.allclose(logits_split, logits_fused, atol=1e-5)

View File

@ -0,0 +1,48 @@
import torch
from dataclasses import replace
from nanochat.gpt import GPTConfig, MLP
def _mlp_config(**kwargs):
cfg = GPTConfig(
n_embd=32,
mlp_type="relu2",
mlp_width_mult=4.0,
mlp_glu_width_mult=None,
n_layer=1,
n_head=4,
n_kv_head=4,
sequence_len=8,
vocab_size=128,
)
return replace(cfg, **kwargs)
def test_relu2_mlp_shape():
cfg = _mlp_config(mlp_type="relu2")
mlp = MLP(cfg)
assert mlp.c_fc.weight.shape[0] == int(cfg.mlp_width_mult * cfg.n_embd)
x = torch.randn(2, 5, cfg.n_embd)
out = mlp(x)
assert out.shape == x.shape
def test_swiglu_width_scaling_defaults_to_two_thirds():
cfg = _mlp_config(mlp_type="swiglu")
mlp = MLP(cfg)
expected_hidden = int(cfg.n_embd * cfg.mlp_width_mult * (2.0 / 3.0))
assert mlp.c_fc.weight.shape[0] == 2 * expected_hidden
x = torch.randn(4, 3, cfg.n_embd)
out = mlp(x)
assert out.shape == x.shape
def test_geglu_respects_custom_width():
cfg = _mlp_config(mlp_type="geglu", mlp_glu_width_mult=0.5)
mlp = MLP(cfg)
expected_hidden = int(cfg.n_embd * 0.5)
assert mlp.c_fc.weight.shape[0] == 2 * expected_hidden
x = torch.randn(3, 2, cfg.n_embd)
out = mlp(x)
assert out.shape == x.shape