mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-23 17:28:49 +00:00
173 lines
7.4 KiB
Python
173 lines
7.4 KiB
Python
|
|
"""
|
|
Minimal GPT implementation for HF export (inference-only utilities).
|
|
"""
|
|
import math
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
@dataclass
class GPTConfig:
    """Architecture hyperparameters for :class:`GPT`.

    All fields have defaults, so ``GPTConfig()`` describes a 12-layer,
    768-wide model; override any subset by keyword.
    """

    # Context length the model is configured for.
    sequence_len: int = 1024
    # Size of the token vocabulary (rows of the embedding / columns of the output head).
    vocab_size: int = 50304
    # Number of stacked transformer blocks.
    n_layer: int = 12
    # Query heads per attention layer.
    n_head: int = 6
    # Key/value heads per attention layer; fewer than n_head enables grouped-query attention.
    n_kv_head: int = 6
    # Width of the residual stream.
    n_embd: int = 768
|
|
|
|
def norm(x):
|
|
# Purely functional rmsnorm with no learnable params
|
|
return F.rms_norm(x, (x.size(-1),))
|
|
|
|
def apply_rotary_emb(x, cos, sin):
    """Apply a rotary position rotation to a 4-D per-head tensor.

    The last dimension of ``x`` is split into a first and a second half, and each
    (first[i], second[i]) pair is rotated using ``cos``/``sin``, which must
    broadcast against one half. The output is cast back to ``x.dtype`` because
    mixing dtypes with ``cos``/``sin`` may promote the intermediate result.
    """
    assert x.ndim == 4  # multihead attention
    half = x.shape[3] // 2
    # Split the last dim into two halves that form the rotation pairs.
    first = x[..., :half]
    second = x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = first * (-sin) + second * cos
    return torch.cat([rotated_first, rotated_second], dim=3).to(x.dtype)
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Causal self-attention with rotary embeddings, QK-norm and optional GQA.

    Queries use ``config.n_head`` heads; keys and values use ``config.n_kv_head``
    heads. When the two counts differ, SDPA's ``enable_gqa`` lets each key/value
    head serve a group of query heads.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        # Index of this layer, used to address its slot in the KV cache.
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        q_width = self.n_head * self.head_dim
        kv_width = self.n_kv_head * self.head_dim
        # Creation order (q, k, v, proj) is kept stable for parameter naming/ordering.
        self.c_q = nn.Linear(self.n_embd, q_width, bias=False)
        self.c_k = nn.Linear(self.n_embd, kv_width, bias=False)
        self.c_v = nn.Linear(self.n_embd, kv_width, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, cos_sin, kv_cache):
        batch, seq, _ = x.size()

        # (B, T, C) -> (B, T, heads, head_dim)
        queries = self.c_q(x).view(batch, seq, self.n_head, self.head_dim)
        keys = self.c_k(x).view(batch, seq, self.n_kv_head, self.head_dim)
        values = self.c_v(x).view(batch, seq, self.n_kv_head, self.head_dim)

        # Relative position via rotary embeddings, followed by QK-norm.
        cos, sin = cos_sin
        queries = norm(apply_rotary_emb(queries, cos, sin))
        keys = norm(apply_rotary_emb(keys, cos, sin))

        # (B, T, H, D) -> (B, H, T, D), the layout SDPA expects.
        queries, keys, values = (t.transpose(1, 2) for t in (queries, keys, values))

        # Store this step's keys/values and continue with the full history so far.
        if kv_cache is not None:
            keys, values = kv_cache.insert_kv(self.layer_idx, keys, values)
        n_queries = queries.size(2)
        n_keys = keys.size(2)

        attn_mask = None
        is_causal = False
        if kv_cache is None or n_queries == n_keys:
            # No cached prefix: ordinary causal attention.
            is_causal = True
        elif n_queries != 1:
            # A chunk of queries after a cached prefix of (n_keys - n_queries) keys:
            # query i sees the whole prefix plus current keys 0..i. Shifting the
            # diagonal of a lower-triangular mask by the prefix length encodes both.
            # (Boolean mask: True = attend.)
            attn_mask = torch.ones((n_queries, n_keys), dtype=torch.bool, device=queries.device).tril(
                diagonal=n_keys - n_queries
            )
        # else: a single query may attend to every cached key, so no mask is needed.

        y = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=attn_mask,
            is_causal=is_causal,
            enable_gqa=self.n_head != self.n_kv_head,
        )

        # Merge heads back side by side and project into the residual stream.
        y = y.transpose(1, 2).contiguous().view(batch, seq, -1)
        return self.c_proj(y)
|
|
|
|
class MLP(nn.Module):
    """Position-wise feed-forward layer with a squared-ReLU activation.

    Expands the residual width by 4x, applies ``relu(.)**2``, and projects back.
    Neither linear layer has a bias.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=False)

    def forward(self, x):
        activated = F.relu(self.c_fc(x)).square()
        return self.c_proj(activated)
|
|
|
|
class Block(nn.Module):
    """One pre-norm transformer layer.

    The input is RMS-normalized before the attention sublayer and again before
    the MLP sublayer. Each sublayer's output is added back to the residual stream.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin, kv_cache):
        attn_out = self.attn(norm(x), cos_sin, kv_cache)
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        return x + mlp_out
|
|
|
|
class GPT(nn.Module):
|
|
def __init__(self, config):
|
|
super().__init__()
|
|
self.config = config
|
|
self.transformer = nn.ModuleDict({
|
|
"wte": nn.Embedding(config.vocab_size, config.n_embd),
|
|
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
|
})
|
|
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
|
# Precompute rotary embeddings (small overhead, avoids realloc each forward)
|
|
self.rotary_seq_len = config.sequence_len * 10
|
|
head_dim = config.n_embd // config.n_head
|
|
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
|
self.register_buffer("cos", cos, persistent=False)
|
|
self.register_buffer("sin", sin, persistent=False)
|
|
|
|
def get_device(self):
|
|
return self.transformer.wte.weight.device
|
|
|
|
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
|
if device is None:
|
|
device = self.transformer.wte.weight.device
|
|
# When transformers initializes the model under `init_empty_weights`, parameters live on the
|
|
# meta device. Buffers created on meta cannot be moved with `.to()`. To keep HF/accelerate
|
|
# happy, build these buffers on CPU in that case.
|
|
if getattr(device, "type", None) == "meta":
|
|
device = torch.device("cpu")
|
|
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
|
|
inv_freq = 1.0 / (base ** (channel_range / head_dim))
|
|
t = torch.arange(seq_len, dtype=torch.float32, device=device)
|
|
freqs = torch.outer(t, inv_freq)
|
|
cos, sin = freqs.cos(), freqs.sin()
|
|
cos, sin = cos.bfloat16(), sin.bfloat16()
|
|
cos, sin = cos[None, :, None, :], sin[None, :, None, :]
|
|
return cos, sin
|
|
|
|
def forward(self, idx, targets=None, kv_cache=None, loss_reduction="mean"):
|
|
B, T = idx.size()
|
|
|
|
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
|
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
|
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
|
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
|
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
|
|
|
|
x = self.transformer.wte(idx)
|
|
x = norm(x)
|
|
for block in self.transformer.h:
|
|
x = block(x, cos_sin, kv_cache)
|
|
x = norm(x)
|
|
|
|
softcap = 15
|
|
if targets is not None:
|
|
logits = self.lm_head(x)
|
|
logits = softcap * torch.tanh(logits / softcap)
|
|
logits = logits.float()
|
|
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
|
return loss
|
|
logits = self.lm_head(x)
|
|
logits = softcap * torch.tanh(logits / softcap)
|
|
return logits
|