mirror of https://github.com/karpathy/nanochat.git
wire up fused qkv and glu toggles
parent 5fd0b13886
commit 47f7ffa25d
nanochat/gpt.py

@@ -4,16 +4,18 @@ Notable features:
 - rotary embeddings (and no positional embeddings)
 - QK norm
 - untied weights for token embedding and lm_head
-- relu^2 activation in MLP
+- relu^2 / gated MLPs with width scaling
 - norm after token embedding
 - no learnable params in rmsnorm
 - no bias in linear layers
 - Multi-Query Attention (MQA) support for more efficient inference
+- Optional fused QKV projection for fewer matmuls
 """

 import math
 from functools import partial
 from dataclasses import dataclass
+from typing import Optional

 import torch
 import torch.nn as nn
@@ -31,6 +33,10 @@ class GPTConfig:
     n_head: int = 6 # number of query heads
     n_kv_head: int = 6 # number of key/value heads (MQA)
     n_embd: int = 768
+    use_fused_qkv: bool = True
+    mlp_type: str = "relu2" # choices: relu2, swiglu, geglu
+    mlp_width_mult: float = 4.0
+    mlp_glu_width_mult: Optional[float] = None


 def norm(x):
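For reference, the new knobs compose like this. A minimal sketch using the tiny sizes from the tests added below; every other field keeps its default:

from nanochat.gpt import GPTConfig

cfg = GPTConfig(
    sequence_len=16, vocab_size=128, n_layer=1,
    n_head=4, n_kv_head=2, n_embd=64,
    use_fused_qkv=True,       # single fused c_attn matmul instead of c_q/c_k/c_v
    mlp_type="swiglu",        # or "geglu"; "relu2" keeps the baseline MLP
    mlp_glu_width_mult=None,  # None falls back to mlp_width_mult * 2/3
)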
@@ -69,8 +75,14 @@ class CausalSelfAttention(nn.Module):
         self.n_kv_head = config.n_kv_head
         self.n_embd = config.n_embd
         self.head_dim = self.n_embd // self.n_head
+        self.total_qkv_heads = self.n_head + 2 * self.n_kv_head
         assert self.n_embd % self.n_head == 0
         assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
-        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
-        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
-        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.use_fused_qkv = config.use_fused_qkv
+        if self.use_fused_qkv:
+            # Single projection keeps matmul count minimal; split happens in forward pass.
+            self.c_attn = nn.Linear(self.n_embd, self.total_qkv_heads * self.head_dim, bias=False)
+        else:
+            self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
+            self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+            self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
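A quick sanity check on the fused layout (my arithmetic, using the GPTConfig defaults):

n_embd, n_head, n_kv_head = 768, 6, 6     # GPTConfig defaults
head_dim = n_embd // n_head               # 128
total_qkv_heads = n_head + 2 * n_kv_head  # 18
fused_out = total_qkv_heads * head_dim    # 2304 = 768 (q) + 768 (k) + 768 (v)
# Same weight count as c_q + c_k + c_v combined, but one matmul instead of three.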
@@ -80,6 +92,11 @@ class CausalSelfAttention(nn.Module):
         B, T, C = x.size()

         # Project the input to get queries, keys, and values
-        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
-        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
-        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+        if self.use_fused_qkv:
+            qkv = self.c_attn(x)
+            qkv = qkv.view(B, T, self.total_qkv_heads, self.head_dim)
+            q, k, v = torch.split(qkv, [self.n_head, self.n_kv_head, self.n_kv_head], dim=2)
+        else:
+            q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
+            k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
+            v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
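The split sizes only work because the fused weight stacks its rows as [Wq; Wk; Wv], the same layout the state-dict converter below builds with torch.cat. A standalone sketch of that equivalence, with illustrative tensor names (plain tensors rather than the module itself):

import torch

B, T, C = 2, 5, 64
n_head, n_kv_head, head_dim = 4, 2, 16

x = torch.randn(B, T, C)
wq = torch.randn(n_head * head_dim, C)
wk = torch.randn(n_kv_head * head_dim, C)
wv = torch.randn(n_kv_head * head_dim, C)

# Split path: three matmuls.
q_ref = (x @ wq.T).view(B, T, n_head, head_dim)
k_ref = (x @ wk.T).view(B, T, n_kv_head, head_dim)
v_ref = (x @ wv.T).view(B, T, n_kv_head, head_dim)

# Fused path: one matmul, then split along the head dimension.
w_fused = torch.cat([wq, wk, wv], dim=0)
qkv = (x @ w_fused.T).view(B, T, n_head + 2 * n_kv_head, head_dim)
q, k, v = torch.split(qkv, [n_head, n_kv_head, n_kv_head], dim=2)

assert torch.allclose(q, q_ref, atol=1e-5)
assert torch.allclose(k, k_ref, atol=1e-5)
assert torch.allclose(v, v_ref, atol=1e-5)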
@@ -125,15 +142,62 @@ class CausalSelfAttention(nn.Module):
         y = self.c_proj(y)
         return y

+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        fused_key = prefix + "c_attn.weight"
+        q_key = prefix + "c_q.weight"
+        k_key = prefix + "c_k.weight"
+        v_key = prefix + "c_v.weight"
+        has_fused = fused_key in state_dict
+        has_split = all(key in state_dict for key in (q_key, k_key, v_key))
+
+        if self.use_fused_qkv and not has_fused and has_split:
+            # Convert legacy split weights into the fused layout expected by this module.
+            q = state_dict.pop(q_key)
+            k = state_dict.pop(k_key)
+            v = state_dict.pop(v_key)
+            state_dict[fused_key] = torch.cat([q, k, v], dim=0)
+        elif not self.use_fused_qkv and has_fused and not has_split:
+            fused = state_dict.pop(fused_key)
+            q_out = self.n_head * self.head_dim
+            kv_out = self.n_kv_head * self.head_dim
+            state_dict[q_key] = fused[:q_out]
+            state_dict[k_key] = fused[q_out:q_out + kv_out]
+            state_dict[v_key] = fused[q_out + kv_out:]
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                      missing_keys, unexpected_keys, error_msgs)
+

 class MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
-        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
+        self.mlp_type = config.mlp_type.lower()
+        base_hidden = max(1, int(config.mlp_width_mult * config.n_embd))
+        if self.mlp_type in {"swiglu", "geglu"}:
+            width_mult = config.mlp_glu_width_mult
+            if width_mult is None:
+                # 2/3 factor keeps parameter count aligned with the ReLU^2 baseline.
+                width_mult = config.mlp_width_mult * (2.0 / 3.0)
+            hidden_dim = max(1, int(width_mult * config.n_embd))
+            self.c_fc = nn.Linear(config.n_embd, 2 * hidden_dim, bias=False)
+            self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
+            self.hidden_dim = hidden_dim
+        else:
+            hidden_dim = base_hidden
+            self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
+            self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
+            self.hidden_dim = hidden_dim

     def forward(self, x):
         x = self.c_fc(x)
-        x = F.relu(x).square()
+        if self.mlp_type == "swiglu":
+            x1, x2 = torch.chunk(x, 2, dim=-1)
+            x = F.silu(x1) * x2
+        elif self.mlp_type == "geglu":
+            x1, x2 = torch.chunk(x, 2, dim=-1)
+            x = F.gelu(x1, approximate="tanh") * x2
+        else:
+            x = F.relu(x).square()
         x = self.c_proj(x)
         return x
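The 2/3 default is the standard GLU bookkeeping: the gate doubles c_fc's output width, so shrinking the hidden dimension to 2/3 keeps the parameter count level with the ReLU^2 baseline. Checking with the default sizes (my arithmetic, not code from the diff):

n_embd, width_mult = 768, 4.0
relu2_hidden = int(width_mult * n_embd)                       # 3072
relu2_params = 2 * n_embd * relu2_hidden                      # c_fc + c_proj = 4,718,592
glu_hidden = int(width_mult * (2.0 / 3.0) * n_embd)           # 2048
glu_params = n_embd * (2 * glu_hidden) + glu_hidden * n_embd  # also 4,718,592
assert relu2_params == glu_params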
@@ -169,7 +233,8 @@ class GPT(nn.Module):
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
-        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
-        self.transformer.wte.to(dtype=torch.bfloat16)
+        # Cast embeddings to bf16 only when CUDA is available to keep CPU inference/tests simple.
+        if torch.cuda.is_available():
+            self.transformer.wte.to(dtype=torch.bfloat16)

     def init_weights(self):

tests/test_attention_fusion.py (new file, 59 lines)

@@ -0,0 +1,59 @@
+import torch
+from dataclasses import replace
+
+from nanochat.gpt import GPT, GPTConfig
+
+
+def _tiny_config(**kwargs):
+    base = GPTConfig(
+        sequence_len=16,
+        vocab_size=128,
+        n_layer=1,
+        n_head=4,
+        n_kv_head=2,
+        n_embd=64,
+        use_fused_qkv=False,
+        mlp_type="relu2",
+    )
+    return replace(base, **kwargs)
+
+
+def test_fused_qkv_matches_legacy_split_projection():
+    torch.manual_seed(0)
+    split_cfg = _tiny_config(use_fused_qkv=False)
+    fused_cfg = replace(split_cfg, use_fused_qkv=True)
+
+    split_model = GPT(split_cfg)
+    split_model.init_weights()
+    state = split_model.state_dict()
+
+    fused_model = GPT(fused_cfg)
+    fused_model.init_weights()
+    fused_model.load_state_dict(state, strict=True)
+
+    tokens = torch.randint(0, split_cfg.vocab_size, (2, 5))
+    with torch.no_grad():
+        logits_split = split_model(tokens)
+        logits_fused = fused_model(tokens)
+
+    assert torch.allclose(logits_split, logits_fused, atol=1e-5)
+
+
+def test_split_loads_from_fused_state_dict():
+    torch.manual_seed(1)
+    fused_cfg = _tiny_config(use_fused_qkv=True)
+    fused_model = GPT(fused_cfg)
+    fused_model.init_weights()
+    state = fused_model.state_dict()
+
+    split_cfg = replace(fused_cfg, use_fused_qkv=False)
+    split_model = GPT(split_cfg)
+    split_model.init_weights()
+    split_model.load_state_dict(state, strict=True)
+
+    tokens = torch.randint(0, split_cfg.vocab_size, (1, 7))
+    with torch.no_grad():
+        logits_split = split_model(tokens)
+        logits_fused = fused_model(tokens)
+
+    assert torch.allclose(logits_split, logits_fused, atol=1e-5)

tests/test_mlp_variants.py (new file, 48 lines)

@@ -0,0 +1,48 @@
+import torch
+from dataclasses import replace
+
+from nanochat.gpt import GPTConfig, MLP
+
+
+def _mlp_config(**kwargs):
+    cfg = GPTConfig(
+        n_embd=32,
+        mlp_type="relu2",
+        mlp_width_mult=4.0,
+        mlp_glu_width_mult=None,
+        n_layer=1,
+        n_head=4,
+        n_kv_head=4,
+        sequence_len=8,
+        vocab_size=128,
+    )
+    return replace(cfg, **kwargs)
+
+
+def test_relu2_mlp_shape():
+    cfg = _mlp_config(mlp_type="relu2")
+    mlp = MLP(cfg)
+    assert mlp.c_fc.weight.shape[0] == int(cfg.mlp_width_mult * cfg.n_embd)
+    x = torch.randn(2, 5, cfg.n_embd)
+    out = mlp(x)
+    assert out.shape == x.shape
+
+
+def test_swiglu_width_scaling_defaults_to_two_thirds():
+    cfg = _mlp_config(mlp_type="swiglu")
+    mlp = MLP(cfg)
+    expected_hidden = int(cfg.n_embd * cfg.mlp_width_mult * (2.0 / 3.0))
+    assert mlp.c_fc.weight.shape[0] == 2 * expected_hidden
+    x = torch.randn(4, 3, cfg.n_embd)
+    out = mlp(x)
+    assert out.shape == x.shape
+
+
+def test_geglu_respects_custom_width():
+    cfg = _mlp_config(mlp_type="geglu", mlp_glu_width_mult=0.5)
+    mlp = MLP(cfg)
+    expected_hidden = int(cfg.n_embd * 0.5)
+    assert mlp.c_fc.weight.shape[0] == 2 * expected_hidden
+    x = torch.randn(3, 2, cfg.n_embd)
+    out = mlp(x)
+    assert out.shape == x.shape
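Both suites use tiny CPU-only configs, so (assuming the repo's usual pytest workflow) something like pytest tests/test_attention_fusion.py tests/test_mlp_variants.py should exercise the fused/split round-trip and the MLP width rules without a GPU.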