diff --git a/nanochat/gpt.py b/nanochat/gpt.py index b640f1e..d3f7859 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -14,14 +14,17 @@ Notable features: import math from functools import partial from dataclasses import dataclass +from typing import Iterator import torch import torch.nn as nn import torch.nn.functional as F -from nanochat.common import get_dist_info, print0 -from nanochat.muon import Muon, DistMuon from nanochat.adamw import DistAdamW +from nanochat.common import get_dist_info +from nanochat.engine import KVCache +from nanochat.muon import Muon, DistMuon + @dataclass class GPTConfig: @@ -33,12 +36,12 @@ class GPTConfig: n_embd: int = 768 -def norm(x): +def norm(x: torch.Tensor) -> torch.Tensor: # Purely functional rmsnorm with no learnable params return F.rms_norm(x, (x.size(-1),)) -def apply_rotary_emb(x, cos, sin): +def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: assert x.ndim == 4 # multihead attention d = x.shape[3] // 2 x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves @@ -48,8 +51,9 @@ def apply_rotary_emb(x, cos, sin): out = out.to(x.dtype) # ensure input/output dtypes match return out + class CausalSelfAttention(nn.Module): - def __init__(self, config, layer_idx): + def __init__(self, config: GPTConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx self.n_head = config.n_head @@ -58,24 +62,36 @@ class CausalSelfAttention(nn.Module): self.head_dim = self.n_embd // self.n_head assert self.n_embd % self.n_head == 0 assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0 - self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False) - self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) - self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) + self.c_qkv = nn.Linear( + self.n_embd, + (self.n_head + 2 * self.n_kv_head) * self.head_dim, + bias=False, + ) self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) - def forward(self, x, cos_sin, kv_cache): - B, T, C = x.size() + def forward( + self, + x: torch.Tensor, + cos_sin: tuple[torch.Tensor, torch.Tensor], + kv_cache: KVCache, + ) -> torch.Tensor: + B, T, _ = x.size() # Project the input to get queries, keys, and values - q = self.c_q(x).view(B, T, self.n_head, self.head_dim) - k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) - v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) + qk, v = ( + self.c_qkv(x) + .view(B, T, self.n_head + 2 * self.n_kv_head, self.head_dim) + .split([self.n_head + self.n_kv_head, self.n_kv_head], dim=2) + ) # Apply Rotary Embeddings to queries and keys to get relative positional encoding cos, sin = cos_sin - q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding - q, k = norm(q), norm(k) # QK norm - q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D) + qk = apply_rotary_emb(qk, cos, sin) + qk = norm(qk) + + # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D) + q, k = qk.transpose(1, 2).split([self.n_head, self.n_kv_head], dim=1) + v = v.transpose(1, 2) # Apply KV cache: insert current k,v into cache, get the full view so far if kv_cache is not None: @@ -111,12 +127,12 @@ class CausalSelfAttention(nn.Module): class MLP(nn.Module): - def __init__(self, config): + def __init__(self, config: GPTConfig): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.c_fc(x) x = F.relu(x).square() x = self.c_proj(x) @@ -124,19 +140,24 @@ class MLP(nn.Module): class Block(nn.Module): - def __init__(self, config, layer_idx): + def __init__(self, config: GPTConfig, layer_idx: int): super().__init__() self.attn = CausalSelfAttention(config, layer_idx) self.mlp = MLP(config) - def forward(self, x, cos_sin, kv_cache): + def forward( + self, + x: torch.Tensor, + cos_sin: tuple[torch.Tensor, torch.Tensor], + kv_cache: KVCache, + ) -> torch.Tensor: x = x + self.attn(norm(x), cos_sin, kv_cache) x = x + self.mlp(norm(x)) return x class GPT(nn.Module): - def __init__(self, config): + def __init__(self, config: GPTConfig): super().__init__() self.config = config self.transformer = nn.ModuleDict({ @@ -170,7 +191,7 @@ class GPT(nn.Module): if self.transformer.wte.weight.device.type == "cuda": self.transformer.wte.to(dtype=torch.bfloat16) - def _init_weights(self, module): + def _init_weights(self, module: nn.Module): if isinstance(module, nn.Linear): # https://arxiv.org/pdf/2310.17813 fan_out = module.weight.size(0) @@ -183,7 +204,13 @@ class GPT(nn.Module): torch.nn.init.normal_(module.weight, mean=0.0, std=1.0) # TODO: bump base theta more, e.g. 100K is more common more recently - def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): + def _precompute_rotary_embeddings( + self, + seq_len: int, + head_dim: int, + base: float = 10000, + device: torch.device | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: # autodetect the device from model embeddings if device is None: device = self.transformer.wte.weight.device @@ -199,10 +226,10 @@ class GPT(nn.Module): cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting return cos, sin - def get_device(self): + def get_device(self) -> torch.device: return self.transformer.wte.weight.device - def estimate_flops(self): + def estimate_flops(self) -> float: """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """ nparams = sum(p.numel() for p in self.parameters()) nparams_embedding = self.transformer.wte.weight.numel() @@ -210,9 +237,9 @@ class GPT(nn.Module): num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t return num_flops_per_token - def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0): + def setup_optimizers(self, unembedding_lr: float = 0.004, embedding_lr: float = 0.2, matrix_lr: float = 0.02, weight_decay: float = 0.0) -> list[torch.optim.Optimizer]: model_dim = self.config.n_embd - ddp, rank, local_rank, world_size = get_dist_info() + ddp, rank, _, _ = get_dist_info() # Separate out all parameters into 3 groups (matrix, embedding, lm_head) matrix_params = list(self.transformer.h.parameters()) embedding_params = list(self.transformer.wte.parameters()) @@ -241,8 +268,8 @@ class GPT(nn.Module): group["initial_lr"] = group["lr"] return optimizers - def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): - B, T = idx.size() + def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache: KVCache | None = None, loss_reduction: str = 'mean') -> torch.Tensor | None: + _, T = idx.size() # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim)) assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" @@ -276,7 +303,14 @@ class GPT(nn.Module): return logits @torch.inference_mode() - def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42): + def generate( + self, + tokens: list[int], + max_tokens: int, + temperature: float = 1.0, + top_k: int | None = None, + seed: int = 42, + ) -> Iterator[int]: """ Naive autoregressive streaming inference. To make it super simple, let's assume: