wire up fused qkv and glu toggles

Vilhelm Toivonen 2025-10-13 22:05:53 +03:00
parent 5fd0b13886
commit 47f7ffa25d
3 changed files with 184 additions and 12 deletions

View File

@@ -4,16 +4,18 @@ Notable features:
 - rotary embeddings (and no positional embeddings)
 - QK norm
 - untied weights for token embedding and lm_head
-- relu^2 activation in MLP
+- relu^2 / gated MLPs with width scaling
 - norm after token embedding
 - no learnable params in rmsnorm
 - no bias in linear layers
 - Multi-Query Attention (MQA) support for more efficient inference
+- Optional fused QKV projection for fewer matmuls
 """
 
 import math
 from functools import partial
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -31,6 +33,10 @@ class GPTConfig:
     n_head: int = 6 # number of query heads
     n_kv_head: int = 6 # number of key/value heads (MQA)
     n_embd: int = 768
+    use_fused_qkv: bool = True
+    mlp_type: str = "relu2" # choices: relu2, swiglu, geglu
+    mlp_width_mult: float = 4.0
+    mlp_glu_width_mult: Optional[float] = None
 
 
 def norm(x):
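
As a usage sketch (not part of this diff), the new fields are ordinary GPTConfig arguments, so a fused-QKV model with a gated MLP can be requested directly from the config; the sizes below are illustrative, not defaults:

from nanochat.gpt import GPT, GPTConfig

# Illustrative sizes; use_fused_qkv / mlp_type / mlp_glu_width_mult are the new toggles.
# mlp_glu_width_mult=None falls back to mlp_width_mult * (2/3) for the gated variants.
cfg = GPTConfig(
    sequence_len=1024,
    vocab_size=50304,
    n_layer=12,
    n_head=6,
    n_kv_head=2,
    n_embd=768,
    use_fused_qkv=True,
    mlp_type="swiglu",
    mlp_glu_width_mult=None,
)
model = GPT(cfg)
model.init_weights()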
@@ -69,8 +75,14 @@ class CausalSelfAttention(nn.Module):
         self.n_kv_head = config.n_kv_head
         self.n_embd = config.n_embd
         self.head_dim = self.n_embd // self.n_head
+        self.total_qkv_heads = self.n_head + 2 * self.n_kv_head
         assert self.n_embd % self.n_head == 0
         assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
+        self.use_fused_qkv = config.use_fused_qkv
+        if self.use_fused_qkv:
+            # Single projection keeps matmul count minimal; split happens in forward pass.
+            self.c_attn = nn.Linear(self.n_embd, self.total_qkv_heads * self.head_dim, bias=False)
+        else:
             self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
             self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
             self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
@@ -80,6 +92,11 @@ class CausalSelfAttention(nn.Module):
         B, T, C = x.size()
 
         # Project the input to get queries, keys, and values
+        if self.use_fused_qkv:
+            qkv = self.c_attn(x)
+            qkv = qkv.view(B, T, self.total_qkv_heads, self.head_dim)
+            q, k, v = torch.split(qkv, [self.n_head, self.n_kv_head, self.n_kv_head], dim=2)
+        else:
             q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
             k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
             v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
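
For intuition, the fused path trades three matmuls for one wide projection plus a split. A standalone shape check mirroring the layout above (the dimensions here are arbitrary, not tied to this repo):

import torch
import torch.nn as nn

B, T, n_embd, n_head, n_kv_head = 2, 5, 64, 4, 2
head_dim = n_embd // n_head                   # 16
total_qkv_heads = n_head + 2 * n_kv_head      # 4 query heads + 2 key + 2 value = 8
c_attn = nn.Linear(n_embd, total_qkv_heads * head_dim, bias=False)

x = torch.randn(B, T, n_embd)
qkv = c_attn(x).view(B, T, total_qkv_heads, head_dim)
q, k, v = torch.split(qkv, [n_head, n_kv_head, n_kv_head], dim=2)
assert q.shape == (B, T, n_head, head_dim)
assert k.shape == v.shape == (B, T, n_kv_head, head_dim)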
@@ -125,15 +142,62 @@ class CausalSelfAttention(nn.Module):
         y = self.c_proj(y)
         return y
 
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        fused_key = prefix + "c_attn.weight"
+        q_key = prefix + "c_q.weight"
+        k_key = prefix + "c_k.weight"
+        v_key = prefix + "c_v.weight"
+        has_fused = fused_key in state_dict
+        has_split = all(key in state_dict for key in (q_key, k_key, v_key))
+        if self.use_fused_qkv and not has_fused and has_split:
+            # Convert legacy split weights into the fused layout expected by this module.
+            q = state_dict.pop(q_key)
+            k = state_dict.pop(k_key)
+            v = state_dict.pop(v_key)
+            state_dict[fused_key] = torch.cat([q, k, v], dim=0)
+        elif not self.use_fused_qkv and has_fused and not has_split:
+            fused = state_dict.pop(fused_key)
+            q_out = self.n_head * self.head_dim
+            kv_out = self.n_kv_head * self.head_dim
+            state_dict[q_key] = fused[:q_out]
+            state_dict[k_key] = fused[q_out:q_out + kv_out]
+            state_dict[v_key] = fused[q_out + kv_out:]
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                      missing_keys, unexpected_keys, error_msgs)
+
 
 class MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
-        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
+        self.mlp_type = config.mlp_type.lower()
+        base_hidden = max(1, int(config.mlp_width_mult * config.n_embd))
+        if self.mlp_type in {"swiglu", "geglu"}:
+            width_mult = config.mlp_glu_width_mult
+            if width_mult is None:
+                # 2/3 factor keeps parameter count aligned with the ReLU^2 baseline.
+                width_mult = config.mlp_width_mult * (2.0 / 3.0)
+            hidden_dim = max(1, int(width_mult * config.n_embd))
+            self.c_fc = nn.Linear(config.n_embd, 2 * hidden_dim, bias=False)
+            self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
+            self.hidden_dim = hidden_dim
+        else:
+            hidden_dim = base_hidden
+            self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
+            self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
+            self.hidden_dim = hidden_dim
 
     def forward(self, x):
         x = self.c_fc(x)
+        if self.mlp_type == "swiglu":
+            x1, x2 = torch.chunk(x, 2, dim=-1)
+            x = F.silu(x1) * x2
+        elif self.mlp_type == "geglu":
+            x1, x2 = torch.chunk(x, 2, dim=-1)
+            x = F.gelu(x1, approximate="tanh") * x2
+        else:
             x = F.relu(x).square()
         x = self.c_proj(x)
         return x
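
The 2/3 default above is the usual GLU parameter bookkeeping: a relu^2 MLP with hidden width 4*n_embd spends 8*n_embd^2 weights across c_fc and c_proj, while a gated MLP with hidden width h spends 3*h*n_embd (two input projections plus one output), so matching the baseline gives h = (8/3)*n_embd = 4 * (2/3) * n_embd. A quick sketch of that accounting in exact integers (the module itself goes through float math and int(), so it can land a unit lower):

n_embd = 768
relu2_params = 2 * n_embd * (4 * n_embd)   # c_fc (n_embd -> 4n) + c_proj (4n -> n_embd)
h = 8 * n_embd // 3                        # 2048: the 4 * (2/3) width for gated variants
glu_params = 3 * h * n_embd                # c_fc (n_embd -> 2h) + c_proj (h -> n_embd)
assert relu2_params == glu_params == 8 * n_embd ** 2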
@@ -169,7 +233,8 @@ class GPT(nn.Module):
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
-        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
+        # Cast embeddings to bf16 only when CUDA is available to keep CPU inference/tests simple.
+        if torch.cuda.is_available():
             self.transformer.wte.to(dtype=torch.bfloat16)
 
     def init_weights(self):

View File

@@ -0,0 +1,59 @@
import torch
from dataclasses import replace

from nanochat.gpt import GPT, GPTConfig


def _tiny_config(**kwargs):
    base = GPTConfig(
        sequence_len=16,
        vocab_size=128,
        n_layer=1,
        n_head=4,
        n_kv_head=2,
        n_embd=64,
        use_fused_qkv=False,
        mlp_type="relu2",
    )
    return replace(base, **kwargs)


def test_fused_qkv_matches_legacy_split_projection():
    torch.manual_seed(0)
    split_cfg = _tiny_config(use_fused_qkv=False)
    fused_cfg = replace(split_cfg, use_fused_qkv=True)

    split_model = GPT(split_cfg)
    split_model.init_weights()
    state = split_model.state_dict()

    fused_model = GPT(fused_cfg)
    fused_model.init_weights()
    fused_model.load_state_dict(state, strict=True)

    tokens = torch.randint(0, split_cfg.vocab_size, (2, 5))
    with torch.no_grad():
        logits_split = split_model(tokens)
        logits_fused = fused_model(tokens)
    assert torch.allclose(logits_split, logits_fused, atol=1e-5)


def test_split_loads_from_fused_state_dict():
    torch.manual_seed(1)
    fused_cfg = _tiny_config(use_fused_qkv=True)
    fused_model = GPT(fused_cfg)
    fused_model.init_weights()
    state = fused_model.state_dict()

    split_cfg = replace(fused_cfg, use_fused_qkv=False)
    split_model = GPT(split_cfg)
    split_model.init_weights()
    split_model.load_state_dict(state, strict=True)

    tokens = torch.randint(0, split_cfg.vocab_size, (1, 7))
    with torch.no_grad():
        logits_split = split_model(tokens)
        logits_fused = fused_model(tokens)
    assert torch.allclose(logits_split, logits_fused, atol=1e-5)

View File

@@ -0,0 +1,48 @@
import torch
from dataclasses import replace

from nanochat.gpt import GPTConfig, MLP


def _mlp_config(**kwargs):
    cfg = GPTConfig(
        n_embd=32,
        mlp_type="relu2",
        mlp_width_mult=4.0,
        mlp_glu_width_mult=None,
        n_layer=1,
        n_head=4,
        n_kv_head=4,
        sequence_len=8,
        vocab_size=128,
    )
    return replace(cfg, **kwargs)


def test_relu2_mlp_shape():
    cfg = _mlp_config(mlp_type="relu2")
    mlp = MLP(cfg)
    assert mlp.c_fc.weight.shape[0] == int(cfg.mlp_width_mult * cfg.n_embd)
    x = torch.randn(2, 5, cfg.n_embd)
    out = mlp(x)
    assert out.shape == x.shape


def test_swiglu_width_scaling_defaults_to_two_thirds():
    cfg = _mlp_config(mlp_type="swiglu")
    mlp = MLP(cfg)
    expected_hidden = int(cfg.n_embd * cfg.mlp_width_mult * (2.0 / 3.0))
    assert mlp.c_fc.weight.shape[0] == 2 * expected_hidden
    x = torch.randn(4, 3, cfg.n_embd)
    out = mlp(x)
    assert out.shape == x.shape


def test_geglu_respects_custom_width():
    cfg = _mlp_config(mlp_type="geglu", mlp_glu_width_mult=0.5)
    mlp = MLP(cfg)
    expected_hidden = int(cfg.n_embd * 0.5)
    assert mlp.c_fc.weight.shape[0] == 2 * expected_hidden
    x = torch.randn(3, 2, cfg.n_embd)
    out = mlp(x)
    assert out.shape == x.shape