mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00
fuse qkv linear and qk rotary + norm
This commit is contained in:
parent
a1de1f46ad
commit
7e87fa8a71
|
|
@ -14,14 +14,17 @@ Notable features:
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from nanochat.common import get_dist_info, print0
|
|
||||||
from nanochat.muon import Muon, DistMuon
|
|
||||||
from nanochat.adamw import DistAdamW
|
from nanochat.adamw import DistAdamW
|
||||||
|
from nanochat.common import get_dist_info
|
||||||
|
from nanochat.engine import KVCache
|
||||||
|
from nanochat.muon import Muon, DistMuon
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GPTConfig:
|
class GPTConfig:
|
||||||
|
|
@ -33,12 +36,12 @@ class GPTConfig:
|
||||||
n_embd: int = 768
|
n_embd: int = 768
|
||||||
|
|
||||||
|
|
||||||
def norm(x):
|
def norm(x: torch.Tensor) -> torch.Tensor:
|
||||||
# Purely functional rmsnorm with no learnable params
|
# Purely functional rmsnorm with no learnable params
|
||||||
return F.rms_norm(x, (x.size(-1),))
|
return F.rms_norm(x, (x.size(-1),))
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x, cos, sin):
|
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
||||||
assert x.ndim == 4 # multihead attention
|
assert x.ndim == 4 # multihead attention
|
||||||
d = x.shape[3] // 2
|
d = x.shape[3] // 2
|
||||||
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
|
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
|
||||||
|
|
@ -48,8 +51,9 @@ def apply_rotary_emb(x, cos, sin):
|
||||||
out = out.to(x.dtype) # ensure input/output dtypes match
|
out = out.to(x.dtype) # ensure input/output dtypes match
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class CausalSelfAttention(nn.Module):
|
class CausalSelfAttention(nn.Module):
|
||||||
def __init__(self, config, layer_idx):
|
def __init__(self, config: GPTConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
self.n_head = config.n_head
|
self.n_head = config.n_head
|
||||||
|
|
@ -58,24 +62,36 @@ class CausalSelfAttention(nn.Module):
|
||||||
self.head_dim = self.n_embd // self.n_head
|
self.head_dim = self.n_embd // self.n_head
|
||||||
assert self.n_embd % self.n_head == 0
|
assert self.n_embd % self.n_head == 0
|
||||||
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
||||||
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
self.c_qkv = nn.Linear(
|
||||||
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
self.n_embd,
|
||||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
(self.n_head + 2 * self.n_kv_head) * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||||
|
|
||||||
def forward(self, x, cos_sin, kv_cache):
|
def forward(
|
||||||
B, T, C = x.size()
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos_sin: tuple[torch.Tensor, torch.Tensor],
|
||||||
|
kv_cache: KVCache,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
B, T, _ = x.size()
|
||||||
|
|
||||||
# Project the input to get queries, keys, and values
|
# Project the input to get queries, keys, and values
|
||||||
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
qk, v = (
|
||||||
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
self.c_qkv(x)
|
||||||
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
.view(B, T, self.n_head + 2 * self.n_kv_head, self.head_dim)
|
||||||
|
.split([self.n_head + self.n_kv_head, self.n_kv_head], dim=2)
|
||||||
|
)
|
||||||
|
|
||||||
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
||||||
cos, sin = cos_sin
|
cos, sin = cos_sin
|
||||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
|
qk = apply_rotary_emb(qk, cos, sin)
|
||||||
q, k = norm(q), norm(k) # QK norm
|
qk = norm(qk)
|
||||||
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
|
|
||||||
|
# make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
|
||||||
|
q, k = qk.transpose(1, 2).split([self.n_head, self.n_kv_head], dim=1)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
# Apply KV cache: insert current k,v into cache, get the full view so far
|
# Apply KV cache: insert current k,v into cache, get the full view so far
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
|
|
@ -111,12 +127,12 @@ class CausalSelfAttention(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config: GPTConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x = self.c_fc(x)
|
x = self.c_fc(x)
|
||||||
x = F.relu(x).square()
|
x = F.relu(x).square()
|
||||||
x = self.c_proj(x)
|
x = self.c_proj(x)
|
||||||
|
|
@ -124,19 +140,24 @@ class MLP(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(self, config, layer_idx):
|
def __init__(self, config: GPTConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attn = CausalSelfAttention(config, layer_idx)
|
self.attn = CausalSelfAttention(config, layer_idx)
|
||||||
self.mlp = MLP(config)
|
self.mlp = MLP(config)
|
||||||
|
|
||||||
def forward(self, x, cos_sin, kv_cache):
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos_sin: tuple[torch.Tensor, torch.Tensor],
|
||||||
|
kv_cache: KVCache,
|
||||||
|
) -> torch.Tensor:
|
||||||
x = x + self.attn(norm(x), cos_sin, kv_cache)
|
x = x + self.attn(norm(x), cos_sin, kv_cache)
|
||||||
x = x + self.mlp(norm(x))
|
x = x + self.mlp(norm(x))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class GPT(nn.Module):
|
class GPT(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config: GPTConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.transformer = nn.ModuleDict({
|
self.transformer = nn.ModuleDict({
|
||||||
|
|
@ -170,7 +191,7 @@ class GPT(nn.Module):
|
||||||
if self.transformer.wte.weight.device.type == "cuda":
|
if self.transformer.wte.weight.device.type == "cuda":
|
||||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module: nn.Module):
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
# https://arxiv.org/pdf/2310.17813
|
# https://arxiv.org/pdf/2310.17813
|
||||||
fan_out = module.weight.size(0)
|
fan_out = module.weight.size(0)
|
||||||
|
|
@ -183,7 +204,13 @@ class GPT(nn.Module):
|
||||||
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
|
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
|
||||||
|
|
||||||
# TODO: bump base theta more, e.g. 100K is more common more recently
|
# TODO: bump base theta more, e.g. 100K is more common more recently
|
||||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
def _precompute_rotary_embeddings(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
head_dim: int,
|
||||||
|
base: float = 10000,
|
||||||
|
device: torch.device | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# autodetect the device from model embeddings
|
# autodetect the device from model embeddings
|
||||||
if device is None:
|
if device is None:
|
||||||
device = self.transformer.wte.weight.device
|
device = self.transformer.wte.weight.device
|
||||||
|
|
@ -199,10 +226,10 @@ class GPT(nn.Module):
|
||||||
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||||
return cos, sin
|
return cos, sin
|
||||||
|
|
||||||
def get_device(self):
|
def get_device(self) -> torch.device:
|
||||||
return self.transformer.wte.weight.device
|
return self.transformer.wte.weight.device
|
||||||
|
|
||||||
def estimate_flops(self):
|
def estimate_flops(self) -> float:
|
||||||
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
|
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
|
||||||
nparams = sum(p.numel() for p in self.parameters())
|
nparams = sum(p.numel() for p in self.parameters())
|
||||||
nparams_embedding = self.transformer.wte.weight.numel()
|
nparams_embedding = self.transformer.wte.weight.numel()
|
||||||
|
|
@ -210,9 +237,9 @@ class GPT(nn.Module):
|
||||||
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
|
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
|
||||||
return num_flops_per_token
|
return num_flops_per_token
|
||||||
|
|
||||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
|
def setup_optimizers(self, unembedding_lr: float = 0.004, embedding_lr: float = 0.2, matrix_lr: float = 0.02, weight_decay: float = 0.0) -> list[torch.optim.Optimizer]:
|
||||||
model_dim = self.config.n_embd
|
model_dim = self.config.n_embd
|
||||||
ddp, rank, local_rank, world_size = get_dist_info()
|
ddp, rank, _, _ = get_dist_info()
|
||||||
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
|
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
|
||||||
matrix_params = list(self.transformer.h.parameters())
|
matrix_params = list(self.transformer.h.parameters())
|
||||||
embedding_params = list(self.transformer.wte.parameters())
|
embedding_params = list(self.transformer.wte.parameters())
|
||||||
|
|
@ -241,8 +268,8 @@ class GPT(nn.Module):
|
||||||
group["initial_lr"] = group["lr"]
|
group["initial_lr"] = group["lr"]
|
||||||
return optimizers
|
return optimizers
|
||||||
|
|
||||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache: KVCache | None = None, loss_reduction: str = 'mean') -> torch.Tensor | None:
|
||||||
B, T = idx.size()
|
_, T = idx.size()
|
||||||
|
|
||||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
|
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
|
||||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||||
|
|
@ -276,7 +303,14 @@ class GPT(nn.Module):
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
|
def generate(
|
||||||
|
self,
|
||||||
|
tokens: list[int],
|
||||||
|
max_tokens: int,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_k: int | None = None,
|
||||||
|
seed: int = 42,
|
||||||
|
) -> Iterator[int]:
|
||||||
"""
|
"""
|
||||||
Naive autoregressive streaming inference.
|
Naive autoregressive streaming inference.
|
||||||
To make it super simple, let's assume:
|
To make it super simple, let's assume:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user