diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index cf78adb..9466e27 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -14,15 +14,14 @@ Notable features:
 import math
 from functools import partial
 from dataclasses import dataclass
-from typing import Iterator
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from nanochat.adamw import DistAdamW
 from nanochat.common import get_dist_info
 from nanochat.muon import Muon, DistMuon
+from nanochat.adamw import DistAdamW
 
 
 @dataclass
@@ -35,12 +34,12 @@ class GPTConfig:
     n_embd: int = 768
 
 
-def norm(x: torch.Tensor) -> torch.Tensor:
+def norm(x):
     # Purely functional rmsnorm with no learnable params
     return F.rms_norm(x, (x.size(-1),))
 
 
-def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+def apply_rotary_emb(x, cos, sin):
     assert x.ndim == 4 # multihead attention
     d = x.shape[3] // 2
     x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
@@ -52,7 +51,7 @@ def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> t
 
 
 class CausalSelfAttention(nn.Module):
-    def __init__(self, config: GPTConfig, layer_idx: int):
+    def __init__(self, config, layer_idx):
         super().__init__()
         self.layer_idx = layer_idx
         self.n_head = config.n_head
@@ -68,12 +67,7 @@ class CausalSelfAttention(nn.Module):
         )
         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
 
-    def forward(
-        self,
-        x: torch.Tensor,
-        cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache,
-    ) -> torch.Tensor:
+    def forward(self, x, cos_sin, kv_cache):
         B, T, _ = x.size()
 
         # Project the input to get queries, keys, and values
@@ -126,12 +120,12 @@ class CausalSelfAttention(nn.Module):
 
 
 class MLP(nn.Module):
-    def __init__(self, config: GPTConfig):
+    def __init__(self, config):
         super().__init__()
         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x):
         x = self.c_fc(x)
         x = F.relu(x).square()
         x = self.c_proj(x)
@@ -139,24 +133,19 @@ class MLP(nn.Module):
 
 
 class Block(nn.Module):
-    def __init__(self, config: GPTConfig, layer_idx: int):
+    def __init__(self, config, layer_idx):
         super().__init__()
         self.attn = CausalSelfAttention(config, layer_idx)
         self.mlp = MLP(config)
 
-    def forward(
-        self,
-        x: torch.Tensor,
-        cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache,
-    ) -> torch.Tensor:
+    def forward(self, x, cos_sin, kv_cache):
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
         return x
 
 
 class GPT(nn.Module):
-    def __init__(self, config: GPTConfig):
+    def __init__(self, config):
         super().__init__()
         self.config = config
         self.transformer = nn.ModuleDict({
@@ -190,7 +179,7 @@ class GPT(nn.Module):
         if self.transformer.wte.weight.device.type == "cuda":
             self.transformer.wte.to(dtype=torch.bfloat16)
 
-    def _init_weights(self, module: nn.Module):
+    def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             # https://arxiv.org/pdf/2310.17813
             fan_out = module.weight.size(0)
@@ -203,13 +192,7 @@ class GPT(nn.Module):
             torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
 
     # TODO: bump base theta more, e.g. 100K is more common more recently
-    def _precompute_rotary_embeddings(
-        self,
-        seq_len: int,
-        head_dim: int,
-        base: float = 10000,
-        device: torch.device | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # autodetect the device from model embeddings
         if device is None:
             device = self.transformer.wte.weight.device
@@ -225,10 +208,10 @@ class GPT(nn.Module):
         cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
         return cos, sin
 
-    def get_device(self) -> torch.device:
+    def get_device(self):
         return self.transformer.wte.weight.device
 
-    def estimate_flops(self) -> float:
+    def estimate_flops(self):
         """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
         nparams = sum(p.numel() for p in self.parameters())
         nparams_embedding = self.transformer.wte.weight.numel()
@@ -236,7 +219,7 @@ class GPT(nn.Module):
         num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
         return num_flops_per_token
 
-    def setup_optimizers(self, unembedding_lr: float = 0.004, embedding_lr: float = 0.2, matrix_lr: float = 0.02, weight_decay: float = 0.0) -> list[torch.optim.Optimizer]:
+    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
         model_dim = self.config.n_embd
         ddp, rank, _, _ = get_dist_info()
         # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
@@ -267,13 +250,7 @@ class GPT(nn.Module):
             group["initial_lr"] = group["lr"]
         return optimizers
 
-    def forward(
-        self,
-        idx: torch.Tensor,
-        targets: torch.Tensor | None = None,
-        kv_cache = None,
-        loss_reduction: str = 'mean',
-    ) -> torch.Tensor | None:
+    def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
         _, T = idx.size()
 
         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
@@ -308,14 +285,7 @@ class GPT(nn.Module):
         return logits
 
     @torch.inference_mode()
-    def generate(
-        self,
-        tokens: list[int],
-        max_tokens: int,
-        temperature: float = 1.0,
-        top_k: int | None = None,
-        seed: int = 42,
-    ) -> Iterator[int]:
+    def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
         """
         Naive autoregressive streaming inference.
         To make it super simple, let's assume: