fuse qkv linear and qk rotary + norm

Authored by Matthew Murphy on 2025-10-13 22:55:55 -07:00, committed by Matt Murphy
parent a1de1f46ad
commit 7e87fa8a71


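This change collapses the three separate query/key/value projections (c_q, c_k, c_v) into a single fused c_qkv linear layer, and applies the rotary embedding and RMSNorm once over the concatenated query/key heads rather than once each. Below is a minimal standalone sketch, not part of the commit and with illustrative sizes, of why a fused projection is equivalent to the separate ones: stacking the weight matrices row-wise and splitting the output reproduces the per-projection results.

```python
import torch
import torch.nn as nn

n_embd, n_head, n_kv_head, head_dim = 768, 6, 6, 128
c_q = nn.Linear(n_embd, n_head * head_dim, bias=False)
c_k = nn.Linear(n_embd, n_kv_head * head_dim, bias=False)
c_v = nn.Linear(n_embd, n_kv_head * head_dim, bias=False)

# Fused layer: stack the three weight matrices along the out_features axis
c_qkv = nn.Linear(n_embd, (n_head + 2 * n_kv_head) * head_dim, bias=False)
with torch.no_grad():
    c_qkv.weight.copy_(torch.cat([c_q.weight, c_k.weight, c_v.weight], dim=0))

x = torch.randn(2, 16, n_embd)
q, k, v = c_qkv(x).split(
    [n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-1
)
assert torch.allclose(q, c_q(x)) and torch.allclose(k, c_k(x)) and torch.allclose(v, c_v(x))
```

One matmul with (n_head + 2 * n_kv_head) * head_dim output features is generally friendlier to the GPU than three smaller kernel launches, which is the motivation for the fusion.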
@@ -14,14 +14,17 @@ Notable features:
 import math
 from functools import partial
 from dataclasses import dataclass
+from typing import Iterator

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from nanochat.common import get_dist_info, print0
-from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.common import get_dist_info
+from nanochat.engine import KVCache
+from nanochat.muon import Muon, DistMuon

 @dataclass
 class GPTConfig:
@@ -33,12 +36,12 @@ class GPTConfig:
     n_embd: int = 768

-def norm(x):
+def norm(x: torch.Tensor) -> torch.Tensor:
     # Purely functional rmsnorm with no learnable params
     return F.rms_norm(x, (x.size(-1),))

-def apply_rotary_emb(x, cos, sin):
+def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
     assert x.ndim == 4 # multihead attention
     d = x.shape[3] // 2
     x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
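Both apply_rotary_emb and norm act elementwise per head over the last (head_dim) axis, so running them once over queries and keys concatenated along the head dimension matches running them separately; that is what makes the fusion in the attention forward below safe. A self-contained sketch under that assumption (the rotation here is a generic half-split RoPE for illustration; the repo's exact formula may differ):

```python
import torch
import torch.nn.functional as F

B, T, H, Hkv, D = 2, 8, 6, 6, 128
q, k = torch.randn(B, T, H, D), torch.randn(B, T, Hkv, D)
cos = torch.randn(1, T, 1, D // 2)  # same broadcast shape as the precomputed cache
sin = torch.randn(1, T, 1, D // 2)

def rope(x):  # generic half-split rotation, for illustration only
    x1, x2 = x[..., : D // 2], x[..., D // 2 :]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

qk = torch.cat([q, k], dim=2)                      # concat along the head dim
fused = F.rms_norm(rope(qk), (D,))                 # one rotary + one norm call
split = torch.cat([F.rms_norm(rope(q), (D,)), F.rms_norm(rope(k), (D,))], dim=2)
assert torch.allclose(fused, split)
```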
@@ -48,8 +51,9 @@ def apply_rotary_emb(x, cos, sin):
     out = out.to(x.dtype) # ensure input/output dtypes match
     return out

 class CausalSelfAttention(nn.Module):
-    def __init__(self, config, layer_idx):
+    def __init__(self, config: GPTConfig, layer_idx: int):
         super().__init__()
         self.layer_idx = layer_idx
         self.n_head = config.n_head
@@ -58,24 +62,36 @@ class CausalSelfAttention(nn.Module):
         self.head_dim = self.n_embd // self.n_head
         assert self.n_embd % self.n_head == 0
         assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
-        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
-        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
-        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.c_qkv = nn.Linear(
+            self.n_embd,
+            (self.n_head + 2 * self.n_kv_head) * self.head_dim,
+            bias=False,
+        )
         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

-    def forward(self, x, cos_sin, kv_cache):
-        B, T, C = x.size()
+    def forward(
+        self,
+        x: torch.Tensor,
+        cos_sin: tuple[torch.Tensor, torch.Tensor],
+        kv_cache: KVCache,
+    ) -> torch.Tensor:
+        B, T, _ = x.size()
         # Project the input to get queries, keys, and values
-        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
-        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
-        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+        qk, v = (
+            self.c_qkv(x)
+            .view(B, T, self.n_head + 2 * self.n_kv_head, self.head_dim)
+            .split([self.n_head + self.n_kv_head, self.n_kv_head], dim=2)
+        )
         # Apply Rotary Embeddings to queries and keys to get relative positional encoding
         cos, sin = cos_sin
-        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
-        q, k = norm(q), norm(k) # QK norm
-        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
+        qk = apply_rotary_emb(qk, cos, sin)
+        qk = norm(qk)
+        # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
+        q, k = qk.transpose(1, 2).split([self.n_head, self.n_kv_head], dim=1)
+        v = v.transpose(1, 2)
         # Apply KV cache: insert current k,v into cache, get the full view so far
         if kv_cache is not None:
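For reference, a shape walk-through of the fused forward path above, with illustrative sizes (B=2, T=8, n_head=n_kv_head=6, head_dim=128):

```python
import torch

B, T, H, Hkv, D = 2, 8, 6, 6, 128
qkv = torch.randn(B, T, (H + 2 * Hkv) * D)          # output of the fused linear
qk, v = qkv.view(B, T, H + 2 * Hkv, D).split([H + Hkv, Hkv], dim=2)
q, k = qk.transpose(1, 2).split([H, Hkv], dim=1)    # heads become the batch dim
v = v.transpose(1, 2)
assert q.shape == (B, H, T, D)
assert k.shape == (B, Hkv, T, D) and v.shape == (B, Hkv, T, D)
```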
@@ -111,12 +127,12 @@ class CausalSelfAttention(nn.Module):

 class MLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: GPTConfig):
         super().__init__()
         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.c_fc(x)
         x = F.relu(x).square()
         x = self.c_proj(x)
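The MLP body is unchanged here: its activation is squared ReLU, F.relu(x).square(), which zeroes negatives and squares positives. A tiny example:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, 0.5, 2.0])
print(F.relu(x).square())  # tensor([0.0000, 0.2500, 4.0000])
```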
@@ -124,19 +140,24 @@ class MLP(nn.Module):

 class Block(nn.Module):
-    def __init__(self, config, layer_idx):
+    def __init__(self, config: GPTConfig, layer_idx: int):
         super().__init__()
         self.attn = CausalSelfAttention(config, layer_idx)
         self.mlp = MLP(config)

-    def forward(self, x, cos_sin, kv_cache):
+    def forward(
+        self,
+        x: torch.Tensor,
+        cos_sin: tuple[torch.Tensor, torch.Tensor],
+        kv_cache: KVCache,
+    ) -> torch.Tensor:
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
         return x

 class GPT(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: GPTConfig):
         super().__init__()
         self.config = config
         self.transformer = nn.ModuleDict({
@@ -170,7 +191,7 @@ class GPT(nn.Module):
         if self.transformer.wte.weight.device.type == "cuda":
             self.transformer.wte.to(dtype=torch.bfloat16)

-    def _init_weights(self, module):
+    def _init_weights(self, module: nn.Module):
         if isinstance(module, nn.Linear):
             # https://arxiv.org/pdf/2310.17813
             fan_out = module.weight.size(0)
@@ -183,7 +204,13 @@ class GPT(nn.Module):
             torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

     # TODO: bump base theta more, e.g. 100K is more common more recently
-    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+    def _precompute_rotary_embeddings(
+        self,
+        seq_len: int,
+        head_dim: int,
+        base: float = 10000,
+        device: torch.device | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # autodetect the device from model embeddings
         if device is None:
             device = self.transformer.wte.weight.device
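The hunk above only retypes the signature; the body is elided from the diff. For orientation, a generic sketch of what a rotary cos/sin cache of this shape typically looks like (the repo's exact frequency schedule, and the larger base theta the TODO mentions, may differ):

```python
import torch

def precompute_rotary(seq_len: int, head_dim: int, base: float = 10000.0):
    # Standard RoPE frequency schedule; shown for illustration only
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)                     # (seq_len, head_dim // 2)
    cos, sin = freqs.cos(), freqs.sin()
    return cos[None, :, None, :], sin[None, :, None, :]  # (1, seq_len, 1, head_dim // 2)
```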
@@ -199,10 +226,10 @@ class GPT(nn.Module):
         cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
         return cos, sin

-    def get_device(self):
+    def get_device(self) -> torch.device:
         return self.transformer.wte.weight.device

-    def estimate_flops(self):
+    def estimate_flops(self) -> float:
         """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
         nparams = sum(p.numel() for p in self.parameters())
         nparams_embedding = self.transformer.wte.weight.numel()
@@ -210,9 +237,9 @@ class GPT(nn.Module):
         num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
         return num_flops_per_token

-    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
+    def setup_optimizers(self, unembedding_lr: float = 0.004, embedding_lr: float = 0.2, matrix_lr: float = 0.02, weight_decay: float = 0.0) -> list[torch.optim.Optimizer]:
         model_dim = self.config.n_embd
-        ddp, rank, local_rank, world_size = get_dist_info()
+        ddp, rank, _, _ = get_dist_info()
         # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
         matrix_params = list(self.transformer.h.parameters())
         embedding_params = list(self.transformer.wte.parameters())
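For intuition, a worked instance of the num_flops_per_token estimate above (6 FLOPs per non-embedding parameter plus an attention term); every number here is hypothetical rather than a repo default:

```python
l, h, q, t = 12, 6, 128, 1024                          # layers, heads, head_dim, seq_len
nparams, nparams_embedding = 123_000_000, 38_000_000   # made-up parameter counts
dense_flops = 6 * (nparams - nparams_embedding)        # 510,000,000
attn_flops = 12 * l * h * q * t                        # 113,246,208
print(dense_flops + attn_flops)                        # 623,246,208 FLOPs per token
```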
@@ -241,8 +268,8 @@ class GPT(nn.Module):
             group["initial_lr"] = group["lr"]
         return optimizers

-    def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
-        B, T = idx.size()
+    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache: KVCache | None = None, loss_reduction: str = 'mean') -> torch.Tensor | None:
+        _, T = idx.size()
         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
         assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
@@ -276,7 +303,14 @@ class GPT(nn.Module):
         return logits

     @torch.inference_mode()
-    def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
+    def generate(
+        self,
+        tokens: list[int],
+        max_tokens: int,
+        temperature: float = 1.0,
+        top_k: int | None = None,
+        seed: int = 42,
+    ) -> Iterator[int]:
         """
         Naive autoregressive streaming inference.
         To make it super simple, let's assume: