diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 5a066b2..6ab213c 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -4,16 +4,18 @@ Notable features:
 - rotary embeddings (and no positional embeddings)
 - QK norm
 - untied weights for token embedding and lm_head
-- relu^2 activation in MLP
+- relu^2 / gated MLPs with width scaling
 - norm after token embedding
 - no learnable params in rmsnorm
 - no bias in linear layers
 - Multi-Query Attention (MQA) support for more efficient inference
+- Optional fused QKV projection for fewer matmuls
 """
 
 import math
 from functools import partial
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -31,6 +33,10 @@ class GPTConfig:
     n_head: int = 6 # number of query heads
     n_kv_head: int = 6 # number of key/value heads (MQA)
     n_embd: int = 768
+    use_fused_qkv: bool = True
+    mlp_type: str = "relu2" # choices: relu2, swiglu, geglu
+    mlp_width_mult: float = 4.0
+    mlp_glu_width_mult: Optional[float] = None
 
 
 def norm(x):
@@ -69,20 +75,31 @@ class CausalSelfAttention(nn.Module):
         self.n_kv_head = config.n_kv_head
         self.n_embd = config.n_embd
         self.head_dim = self.n_embd // self.n_head
+        self.total_qkv_heads = self.n_head + 2 * self.n_kv_head
         assert self.n_embd % self.n_head == 0
         assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
-        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
-        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
-        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.use_fused_qkv = config.use_fused_qkv
+        if self.use_fused_qkv:
+            # Single projection keeps matmul count minimal; split happens in forward pass.
+            self.c_attn = nn.Linear(self.n_embd, self.total_qkv_heads * self.head_dim, bias=False)
+        else:
+            self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
+            self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+            self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
 
     def forward(self, x, cos_sin, kv_cache):
         B, T, C = x.size()
 
         # Project the input to get queries, keys, and values
-        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
-        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
-        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+        if self.use_fused_qkv:
+            qkv = self.c_attn(x)
+            qkv = qkv.view(B, T, self.total_qkv_heads, self.head_dim)
+            q, k, v = torch.split(qkv, [self.n_head, self.n_kv_head, self.n_kv_head], dim=2)
+        else:
+            q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
+            k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
+            v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
 
         # Apply Rotary Embeddings to queries and keys to get relative positional encoding
         cos, sin = cos_sin
@@ -125,16 +142,63 @@ class CausalSelfAttention(nn.Module):
         y = self.c_proj(y)
         return y
 
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        fused_key = prefix + "c_attn.weight"
+        q_key = prefix + "c_q.weight"
+        k_key = prefix + "c_k.weight"
+        v_key = prefix + "c_v.weight"
+        has_fused = fused_key in state_dict
+        has_split = all(key in state_dict for key in (q_key, k_key, v_key))
+
+        if self.use_fused_qkv and not has_fused and has_split:
+            # Convert legacy split weights into the fused layout expected by this module.
+            q = state_dict.pop(q_key)
+            k = state_dict.pop(k_key)
+            v = state_dict.pop(v_key)
+            state_dict[fused_key] = torch.cat([q, k, v], dim=0)
+        elif not self.use_fused_qkv and has_fused and not has_split:
+            fused = state_dict.pop(fused_key)
+            q_out = self.n_head * self.head_dim
+            kv_out = self.n_kv_head * self.head_dim
+            state_dict[q_key] = fused[:q_out]
+            state_dict[k_key] = fused[q_out:q_out + kv_out]
+            state_dict[v_key] = fused[q_out + kv_out:]
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                      missing_keys, unexpected_keys, error_msgs)
+
 
 class MLP(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
-        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
+        self.mlp_type = config.mlp_type.lower()
+        base_hidden = max(1, int(config.mlp_width_mult * config.n_embd))
+        if self.mlp_type in {"swiglu", "geglu"}:
+            width_mult = config.mlp_glu_width_mult
+            if width_mult is None:
+                # 2/3 factor keeps parameter count aligned with the ReLU^2 baseline.
+                width_mult = config.mlp_width_mult * (2.0 / 3.0)
+            hidden_dim = max(1, int(width_mult * config.n_embd))
+            self.c_fc = nn.Linear(config.n_embd, 2 * hidden_dim, bias=False)
+            self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
+            self.hidden_dim = hidden_dim
+        else:
+            hidden_dim = base_hidden
+            self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
+            self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
+            self.hidden_dim = hidden_dim
 
     def forward(self, x):
         x = self.c_fc(x)
-        x = F.relu(x).square()
+        if self.mlp_type == "swiglu":
+            x1, x2 = torch.chunk(x, 2, dim=-1)
+            x = F.silu(x1) * x2
+        elif self.mlp_type == "geglu":
+            x1, x2 = torch.chunk(x, 2, dim=-1)
+            x = F.gelu(x1, approximate="tanh") * x2
+        else:
+            x = F.relu(x).square()
         x = self.c_proj(x)
         return x
@@ -169,8 +233,9 @@ class GPT(nn.Module):
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
-        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
-        self.transformer.wte.to(dtype=torch.bfloat16)
+        # Cast embeddings to bf16 only when CUDA is available to keep CPU inference/tests simple.
+        if torch.cuda.is_available():
+            self.transformer.wte.to(dtype=torch.bfloat16)
 
     def init_weights(self):
         self.apply(self._init_weights)
diff --git a/tests/test_attention_fusion.py b/tests/test_attention_fusion.py
new file mode 100644
index 0000000..3209f81
--- /dev/null
+++ b/tests/test_attention_fusion.py
@@ -0,0 +1,59 @@
+import torch
+from dataclasses import replace
+
+from nanochat.gpt import GPT, GPTConfig
+
+
+def _tiny_config(**kwargs):
+    base = GPTConfig(
+        sequence_len=16,
+        vocab_size=128,
+        n_layer=1,
+        n_head=4,
+        n_kv_head=2,
+        n_embd=64,
+        use_fused_qkv=False,
+        mlp_type="relu2",
+    )
+    return replace(base, **kwargs)
+
+
+def test_fused_qkv_matches_legacy_split_projection():
+    torch.manual_seed(0)
+    split_cfg = _tiny_config(use_fused_qkv=False)
+    fused_cfg = replace(split_cfg, use_fused_qkv=True)
+
+    split_model = GPT(split_cfg)
+    split_model.init_weights()
+    state = split_model.state_dict()
+
+    fused_model = GPT(fused_cfg)
+    fused_model.init_weights()
+    fused_model.load_state_dict(state, strict=True)
+
+    tokens = torch.randint(0, split_cfg.vocab_size, (2, 5))
+    with torch.no_grad():
+        logits_split = split_model(tokens)
+        logits_fused = fused_model(tokens)
+
+    assert torch.allclose(logits_split, logits_fused, atol=1e-5)
+
+
+def test_split_loads_from_fused_state_dict():
+    torch.manual_seed(1)
+    fused_cfg = _tiny_config(use_fused_qkv=True)
+    fused_model = GPT(fused_cfg)
+    fused_model.init_weights()
+    state = fused_model.state_dict()
+
+    split_cfg = replace(fused_cfg, use_fused_qkv=False)
+    split_model = GPT(split_cfg)
+    split_model.init_weights()
+    split_model.load_state_dict(state, strict=True)
+
+    tokens = torch.randint(0, split_cfg.vocab_size, (1, 7))
+    with torch.no_grad():
+        logits_split = split_model(tokens)
+        logits_fused = fused_model(tokens)
+
+    assert torch.allclose(logits_split, logits_fused, atol=1e-5)
diff --git a/tests/test_mlp_variants.py b/tests/test_mlp_variants.py
new file mode 100644
index 0000000..3e407e1
--- /dev/null
+++ b/tests/test_mlp_variants.py
@@ -0,0 +1,48 @@
+import torch
+from dataclasses import replace
+
+from nanochat.gpt import GPTConfig, MLP
+
+
+def _mlp_config(**kwargs):
+    cfg = GPTConfig(
+        n_embd=32,
+        mlp_type="relu2",
+        mlp_width_mult=4.0,
+        mlp_glu_width_mult=None,
+        n_layer=1,
+        n_head=4,
+        n_kv_head=4,
+        sequence_len=8,
+        vocab_size=128,
+    )
+    return replace(cfg, **kwargs)
+
+
+def test_relu2_mlp_shape():
+    cfg = _mlp_config(mlp_type="relu2")
+    mlp = MLP(cfg)
+    assert mlp.c_fc.weight.shape[0] == int(cfg.mlp_width_mult * cfg.n_embd)
+    x = torch.randn(2, 5, cfg.n_embd)
+    out = mlp(x)
+    assert out.shape == x.shape
+
+
+def test_swiglu_width_scaling_defaults_to_two_thirds():
+    cfg = _mlp_config(mlp_type="swiglu")
+    mlp = MLP(cfg)
+    expected_hidden = int(cfg.n_embd * cfg.mlp_width_mult * (2.0 / 3.0))
+    assert mlp.c_fc.weight.shape[0] == 2 * expected_hidden
+    x = torch.randn(4, 3, cfg.n_embd)
+    out = mlp(x)
+    assert out.shape == x.shape
+
+
+def test_geglu_respects_custom_width():
+    cfg = _mlp_config(mlp_type="geglu", mlp_glu_width_mult=0.5)
+    mlp = MLP(cfg)
+    expected_hidden = int(cfg.n_embd * 0.5)
+    assert mlp.c_fc.weight.shape[0] == 2 * expected_hidden
+    x = torch.randn(3, 2, cfg.n_embd)
+    out = mlp(x)
+    assert out.shape == x.shape
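Usage note (a minimal sketch, not part of the patch): the snippet below shows how the new GPTConfig knobs compose and why the default GLU width of mlp_width_mult * 2/3 keeps the gated variants parameter-matched with the relu^2 baseline (2 * 4n * n weights for c_fc + c_proj versus 3 * h * n for a GLU block, so h = 8n/3). Field names follow the diff above; the numeric values are illustrative only.

from nanochat.gpt import GPT, GPTConfig

# Illustrative values only; any field not shown keeps its dataclass default.
cfg = GPTConfig(
    sequence_len=256,
    vocab_size=1024,
    n_layer=2,
    n_head=4,
    n_kv_head=2,              # MQA/GQA: fewer KV heads than query heads
    n_embd=128,
    use_fused_qkv=True,       # one c_attn matmul instead of separate c_q/c_k/c_v
    mlp_type="swiglu",        # or "geglu" / "relu2"
    mlp_glu_width_mult=None,  # None -> mlp_width_mult * 2/3 (parameter parity)
)
model = GPT(cfg)
model.init_weights()

# Parameter parity of the MLP weights, up to int() rounding of the hidden width.
n = cfg.n_embd
relu2_mlp_params = 2 * (4 * n) * n                      # c_fc + c_proj
glu_hidden = int(cfg.mlp_width_mult * (2.0 / 3.0) * n)
glu_mlp_params = (2 * glu_hidden) * n + glu_hidden * n  # fused gate/up proj + down proj
assert abs(relu2_mlp_params - glu_mlp_params) <= 3 * n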