remove type annotations

This commit is contained in:
Matthew Murphy 2025-10-29 01:28:19 -07:00
parent 69e3cd410d
commit 964d459d9b

View File

@ -14,15 +14,14 @@ Notable features:
import math
from functools import partial
from dataclasses import dataclass
from typing import Iterator
import torch
import torch.nn as nn
import torch.nn.functional as F
from nanochat.adamw import DistAdamW
from nanochat.common import get_dist_info
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
@dataclass
@ -35,12 +34,12 @@ class GPTConfig:
n_embd: int = 768
def norm(x: torch.Tensor) -> torch.Tensor:
def norm(x):
# Purely functional rmsnorm with no learnable params
return F.rms_norm(x, (x.size(-1),))
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
@ -52,7 +51,7 @@ def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> t
class CausalSelfAttention(nn.Module):
def __init__(self, config: GPTConfig, layer_idx: int):
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
self.n_head = config.n_head
@ -68,12 +67,7 @@ class CausalSelfAttention(nn.Module):
)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(
self,
x: torch.Tensor,
cos_sin: tuple[torch.Tensor, torch.Tensor],
kv_cache,
) -> torch.Tensor:
def forward(self, x, cos_sin, kv_cache):
B, T, _ = x.size()
# Project the input to get queries, keys, and values
@ -126,12 +120,12 @@ class CausalSelfAttention(nn.Module):
class MLP(nn.Module):
def __init__(self, config: GPTConfig):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
@ -139,24 +133,19 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, config: GPTConfig, layer_idx: int):
def __init__(self, config, layer_idx):
super().__init__()
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(
self,
x: torch.Tensor,
cos_sin: tuple[torch.Tensor, torch.Tensor],
kv_cache,
) -> torch.Tensor:
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache)
x = x + self.mlp(norm(x))
return x
class GPT(nn.Module):
def __init__(self, config: GPTConfig):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict({
@ -190,7 +179,7 @@ class GPT(nn.Module):
if self.transformer.wte.weight.device.type == "cuda":
self.transformer.wte.to(dtype=torch.bfloat16)
def _init_weights(self, module: nn.Module):
def _init_weights(self, module):
if isinstance(module, nn.Linear):
# https://arxiv.org/pdf/2310.17813
fan_out = module.weight.size(0)
@ -203,13 +192,7 @@ class GPT(nn.Module):
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
# TODO: bump base theta more, e.g. 100K is more common more recently
def _precompute_rotary_embeddings(
self,
seq_len: int,
head_dim: int,
base: float = 10000,
device: torch.device | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# autodetect the device from model embeddings
if device is None:
device = self.transformer.wte.weight.device
@ -225,10 +208,10 @@ class GPT(nn.Module):
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
def get_device(self) -> torch.device:
def get_device(self):
return self.transformer.wte.weight.device
def estimate_flops(self) -> float:
def estimate_flops(self):
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
nparams = sum(p.numel() for p in self.parameters())
nparams_embedding = self.transformer.wte.weight.numel()
@ -236,7 +219,7 @@ class GPT(nn.Module):
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
return num_flops_per_token
def setup_optimizers(self, unembedding_lr: float = 0.004, embedding_lr: float = 0.2, matrix_lr: float = 0.02, weight_decay: float = 0.0) -> list[torch.optim.Optimizer]:
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
model_dim = self.config.n_embd
ddp, rank, _, _ = get_dist_info()
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
@ -267,13 +250,7 @@ class GPT(nn.Module):
group["initial_lr"] = group["lr"]
return optimizers
def forward(
self,
idx: torch.Tensor,
targets: torch.Tensor | None = None,
kv_cache = None,
loss_reduction: str = 'mean',
) -> torch.Tensor | None:
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
_, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
@ -308,14 +285,7 @@ class GPT(nn.Module):
return logits
@torch.inference_mode()
def generate(
self,
tokens: list[int],
max_tokens: int,
temperature: float = 1.0,
top_k: int | None = None,
seed: int = 42,
) -> Iterator[int]:
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
"""
Naive autoregressive streaming inference.
To make it super simple, let's assume: