fuse qkv linear and qk rotary + norm

Matthew Murphy 2025-10-13 22:55:55 -07:00
parent dd6ff9a1cc
commit 24ed569055

@@ -14,14 +14,17 @@ Notable features:
 import math
 from functools import partial
 from dataclasses import dataclass
+from typing import Iterator
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from nanochat.common import get_dist_info, print0
-from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.common import get_dist_info
+from nanochat.engine import KVCache
+from nanochat.muon import Muon, DistMuon
 
 @dataclass
 class GPTConfig:
@@ -33,12 +36,12 @@ class GPTConfig:
     n_embd: int = 768
 
-def norm(x):
+def norm(x: torch.Tensor) -> torch.Tensor:
     # Purely functional rmsnorm with no learnable params
     return F.rms_norm(x, (x.size(-1),))
 
-def apply_rotary_emb(x, cos, sin):
+def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
     assert x.ndim == 4 # multihead attention
     d = x.shape[3] // 2
     x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
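For reference, the parameter-free rms norm above is just a rescaling by the root mean square of the last dimension. A small standalone check, with eps pinned explicitly rather than relying on the PyTorch default:

# Manual RMSNorm vs. F.rms_norm (illustrative check, eps chosen arbitrarily)
import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 64)
eps = 1e-6
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
assert torch.allclose(F.rms_norm(x, (x.size(-1),), eps=eps), manual, atol=1e-6)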
@@ -49,7 +52,7 @@ def apply_rotary_emb(x, cos, sin):
     return out
 
-def repeat_kv(x, n_rep):
+def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
     if n_rep == 1:
         return x
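The docstring above says repeat_kv matches torch.repeat_interleave along the head dimension; this is the usual grouped-query-attention expansion, where each of the n_kv_head key/value heads is repeated n_head // n_kv_head times to line up with its group of query heads. A quick shape check with assumed head counts (not values from this diff):

# GQA head expansion via repeat_interleave (illustrative shapes only)
import torch

B, T, n_head, n_kv_head, head_dim = 2, 5, 8, 4, 16
k = torch.randn(B, n_kv_head, T, head_dim)                # keys after transpose: (B, H_kv, T, D)
n_rep = n_head // n_kv_head                               # 2 query heads share each kv head
k_rep = torch.repeat_interleave(k, repeats=n_rep, dim=1)
assert k_rep.shape == (B, n_head, T, head_dim)
assert torch.equal(k_rep[:, 0], k[:, 0]) and torch.equal(k_rep[:, 1], k[:, 0])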
@@ -62,7 +65,7 @@ def repeat_kv(x, n_rep):
 
 class CausalSelfAttention(nn.Module):
-    def __init__(self, config, layer_idx):
+    def __init__(self, config: GPTConfig, layer_idx: int):
         super().__init__()
         self.layer_idx = layer_idx
         self.n_head = config.n_head
@@ -71,24 +74,36 @@ class CausalSelfAttention(nn.Module):
         self.head_dim = self.n_embd // self.n_head
         assert self.n_embd % self.n_head == 0
         assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
-        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
-        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
-        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.c_qkv = nn.Linear(
+            self.n_embd,
+            (self.n_head + 2 * self.n_kv_head) * self.head_dim,
+            bias=False,
+        )
         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
 
-    def forward(self, x, cos_sin, kv_cache):
-        B, T, C = x.size()
+    def forward(
+        self,
+        x: torch.Tensor,
+        cos_sin: tuple[torch.Tensor, torch.Tensor],
+        kv_cache: KVCache,
+    ) -> torch.Tensor:
+        B, T, _ = x.size()
         # Project the input to get queries, keys, and values
-        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
-        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
-        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+        qk, v = (
+            self.c_qkv(x)
+            .view(B, T, self.n_head + 2 * self.n_kv_head, self.head_dim)
+            .split([self.n_head + self.n_kv_head, self.n_kv_head], dim=2)
+        )
         # Apply Rotary Embeddings to queries and keys to get relative positional encoding
         cos, sin = cos_sin
-        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
-        q, k = norm(q), norm(k) # QK norm
-        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
+        qk = apply_rotary_emb(qk, cos, sin)
+        qk = norm(qk)
+        # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
+        q, k = qk.transpose(1, 2).split([self.n_head, self.n_kv_head], dim=1)
+        v = v.transpose(1, 2)
         # Apply KV cache: insert current k,v into cache, get the full view so far
         if kv_cache is not None:
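This hunk is the core of the commit: one fused c_qkv matmul replaces the three separate q/k/v projections, and because the rotary cos/sin tables and the functional rms norm both broadcast over the head dimension, rotary + norm run once over the concatenated query and key heads before the split. A standalone equivalence check, with all sizes assumed for illustration:

# Fused QKV projection vs. separate c_q/c_k/c_v (sanity check, not part of the commit)
import torch
import torch.nn as nn

B, T, n_embd, n_head, n_kv_head = 2, 5, 64, 8, 4
head_dim = n_embd // n_head

c_q = nn.Linear(n_embd, n_head * head_dim, bias=False)
c_k = nn.Linear(n_embd, n_kv_head * head_dim, bias=False)
c_v = nn.Linear(n_embd, n_kv_head * head_dim, bias=False)
c_qkv = nn.Linear(n_embd, (n_head + 2 * n_kv_head) * head_dim, bias=False)
with torch.no_grad():
    # the fused weight is just the row-wise concatenation of the three weights
    c_qkv.weight.copy_(torch.cat([c_q.weight, c_k.weight, c_v.weight], dim=0))

x = torch.randn(B, T, n_embd)
qk, v = (
    c_qkv(x)
    .view(B, T, n_head + 2 * n_kv_head, head_dim)
    .split([n_head + n_kv_head, n_kv_head], dim=2)
)
q, k = qk.split([n_head, n_kv_head], dim=2)
assert torch.allclose(q, c_q(x).view(B, T, n_head, head_dim), atol=1e-5)
assert torch.allclose(k, c_k(x).view(B, T, n_kv_head, head_dim), atol=1e-5)
assert torch.allclose(v, c_v(x).view(B, T, n_kv_head, head_dim), atol=1e-5)

The payoff is presumably fewer kernel launches per attention block: one larger GEMM plus one rotary and one norm call instead of three and two.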
@@ -127,12 +142,12 @@ class CausalSelfAttention(nn.Module):
 
 class MLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: GPTConfig):
         super().__init__()
         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.c_fc(x)
         x = F.relu(x).square()
         x = self.c_proj(x)
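The MLP only gains type annotations; it remains the squared-ReLU variant over a 4x expansion. A minimal shape sketch with an assumed width:

# Squared-ReLU MLP flow (illustrative; n_embd=64 assumed)
import torch
import torch.nn as nn
import torch.nn.functional as F

n_embd = 64
c_fc = nn.Linear(n_embd, 4 * n_embd, bias=False)
c_proj = nn.Linear(4 * n_embd, n_embd, bias=False)
x = torch.randn(2, 5, n_embd)
h = F.relu(c_fc(x)).square()   # elementwise max(0, .)**2, shape (2, 5, 256)
y = c_proj(h)                  # back to (2, 5, 64)
assert y.shape == x.shape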
@@ -140,19 +155,24 @@ class MLP(nn.Module):
 
 class Block(nn.Module):
-    def __init__(self, config, layer_idx):
+    def __init__(self, config: GPTConfig, layer_idx: int):
         super().__init__()
         self.attn = CausalSelfAttention(config, layer_idx)
         self.mlp = MLP(config)
 
-    def forward(self, x, cos_sin, kv_cache):
+    def forward(
+        self,
+        x: torch.Tensor,
+        cos_sin: tuple[torch.Tensor, torch.Tensor],
+        kv_cache: KVCache,
+    ) -> torch.Tensor:
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
         return x
 
 class GPT(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: GPTConfig):
         super().__init__()
         self.config = config
         self.transformer = nn.ModuleDict({
@@ -185,7 +205,7 @@ class GPT(nn.Module):
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.cos, self.sin = cos, sin
 
-    def _init_weights(self, module):
+    def _init_weights(self, module: nn.Module):
         if isinstance(module, nn.Linear):
             # https://arxiv.org/pdf/2310.17813
             fan_out = module.weight.size(0)
@@ -198,7 +218,13 @@ class GPT(nn.Module):
             torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
 
     # TODO: bump base theta more, e.g. 100K is more common more recently
-    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+    def _precompute_rotary_embeddings(
+        self,
+        seq_len: int,
+        head_dim: int,
+        base: float = 10000,
+        device: torch.device | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # autodetect the device from model embeddings
         if device is None:
             device = self.transformer.wte.weight.device
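For context, the table being precomputed here follows the standard RoPE recipe: inverse frequencies 1 / base**(2i/head_dim) per channel pair and angle t * inv_freq per position. A compact sketch with assumed sizes (the real method also handles device placement and dtype):

# Rotary cos/sin precompute sketch (seq_len/head_dim assumed; base=10000 as above)
import torch

seq_len, head_dim, base = 16, 8, 10000.0
channel = torch.arange(0, head_dim, 2, dtype=torch.float32)   # 0, 2, 4, ... (pairs of dims)
inv_freq = 1.0 / (base ** (channel / head_dim))               # (head_dim // 2,)
t = torch.arange(seq_len, dtype=torch.float32)                # positions 0..seq_len-1
freqs = torch.outer(t, inv_freq)                              # (seq_len, head_dim // 2)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos[None, :, None, :], sin[None, :, None, :]       # broadcast over batch and heads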
@@ -214,10 +240,10 @@ class GPT(nn.Module):
         cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
         return cos, sin
 
-    def get_device(self):
+    def get_device(self) -> torch.device:
         return self.transformer.wte.weight.device
 
-    def estimate_flops(self):
+    def estimate_flops(self) -> float:
         """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
         nparams = sum(p.numel() for p in self.parameters())
         nparams_embedding = self.transformer.wte.weight.numel()
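estimate_flops (continued in the next hunk) uses the PaLM-style count: 6 FLOPs per non-embedding parameter per token, plus 12 * l * h * q * t for attention, with l, h, q, t being the layer count, head count, head dimension, and sequence length. A worked example with made-up numbers:

# Worked FLOPs-per-token example (all values assumed, not this model's config)
n_layer, n_head, head_dim, seq_len = 12, 6, 128, 1024
nparams, nparams_embedding = 125_000_000, 38_000_000
attention_flops = 12 * n_layer * n_head * head_dim * seq_len              # ~1.13e8
num_flops_per_token = 6 * (nparams - nparams_embedding) + attention_flops
print(num_flops_per_token)                                                # ~6.35e8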
@@ -225,9 +251,9 @@ class GPT(nn.Module):
         num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
         return num_flops_per_token
 
-    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
+    def setup_optimizers(self, unembedding_lr: float = 0.004, embedding_lr: float = 0.2, matrix_lr: float = 0.02, weight_decay: float = 0.0) -> list[torch.optim.Optimizer]:
         model_dim = self.config.n_embd
-        ddp, rank, local_rank, world_size = get_dist_info()
+        ddp, rank, _, _ = get_dist_info()
         # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
         matrix_params = list(self.transformer.h.parameters())
         embedding_params = list(self.transformer.wte.parameters())
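setup_optimizers splits parameters into three groups with very different learning rates; the matrix weights presumably go to the imported Muon, while the embedding and unembedding groups go to AdamW. A simplified stand-in using only torch.optim.AdamW, just to show the grouping pattern (Muon's constructor is not shown in this diff, so it is left out):

# Three-way parameter grouping (illustrative; the real code hands matrix_params to Muon)
import torch
import torch.nn as nn

transformer = nn.ModuleDict({
    "wte": nn.Embedding(1000, 64),
    "h": nn.ModuleList([nn.Linear(64, 64, bias=False) for _ in range(4)]),
})
lm_head = nn.Linear(64, 1000, bias=False)

matrix_params = list(transformer["h"].parameters())       # -> Muon at matrix_lr
embedding_params = list(transformer["wte"].parameters())  # -> AdamW at embedding_lr
lm_head_params = list(lm_head.parameters())               # -> AdamW at unembedding_lr

adamw = torch.optim.AdamW([
    {"params": lm_head_params, "lr": 0.004},
    {"params": embedding_params, "lr": 0.2},
], weight_decay=0.0)
for group in adamw.param_groups:
    group["initial_lr"] = group["lr"]  # same bookkeeping the real method does at the end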
@@ -256,8 +282,8 @@ class GPT(nn.Module):
                 group["initial_lr"] = group["lr"]
         return optimizers
 
-    def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
-        B, T = idx.size()
+    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache: KVCache | None = None, loss_reduction: str = 'mean') -> torch.Tensor | None:
+        _, T = idx.size()
         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
         assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
@@ -291,7 +317,14 @@ class GPT(nn.Module):
             return logits
 
     @torch.inference_mode()
-    def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
+    def generate(
+        self,
+        tokens: list[int],
+        max_tokens: int,
+        temperature: float = 1.0,
+        top_k: int | None = None,
+        seed: int = 42,
+    ) -> Iterator[int]:
         """
         Naive autoregressive streaming inference.
         To make it super simple, let's assume: