mirror of
https://github.com/karpathy/nanochat.git
synced 2026-02-24 20:40:23 +00:00
remove type annotations
This commit is contained in:
parent
69e3cd410d
commit
964d459d9b
|
|
@ -14,15 +14,14 @@ Notable features:
|
|||
import math
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterator
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanochat.adamw import DistAdamW
|
||||
from nanochat.common import get_dist_info
|
||||
from nanochat.muon import Muon, DistMuon
|
||||
from nanochat.adamw import DistAdamW
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -35,12 +34,12 @@ class GPTConfig:
|
|||
n_embd: int = 768
|
||||
|
||||
|
||||
def norm(x: torch.Tensor) -> torch.Tensor:
|
||||
def norm(x):
|
||||
# Purely functional rmsnorm with no learnable params
|
||||
return F.rms_norm(x, (x.size(-1),))
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
||||
def apply_rotary_emb(x, cos, sin):
|
||||
assert x.ndim == 4 # multihead attention
|
||||
d = x.shape[3] // 2
|
||||
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
|
||||
|
|
@ -52,7 +51,7 @@ def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> t
|
|||
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
def __init__(self, config: GPTConfig, layer_idx: int):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.n_head = config.n_head
|
||||
|
|
@ -68,12 +67,7 @@ class CausalSelfAttention(nn.Module):
|
|||
)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos_sin: tuple[torch.Tensor, torch.Tensor],
|
||||
kv_cache,
|
||||
) -> torch.Tensor:
|
||||
def forward(self, x, cos_sin, kv_cache):
|
||||
B, T, _ = x.size()
|
||||
|
||||
# Project the input to get queries, keys, and values
|
||||
|
|
@ -126,12 +120,12 @@ class CausalSelfAttention(nn.Module):
|
|||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: GPTConfig):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x):
|
||||
x = self.c_fc(x)
|
||||
x = F.relu(x).square()
|
||||
x = self.c_proj(x)
|
||||
|
|
@ -139,24 +133,19 @@ class MLP(nn.Module):
|
|||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, config: GPTConfig, layer_idx: int):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.attn = CausalSelfAttention(config, layer_idx)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos_sin: tuple[torch.Tensor, torch.Tensor],
|
||||
kv_cache,
|
||||
) -> torch.Tensor:
|
||||
def forward(self, x, cos_sin, kv_cache):
|
||||
x = x + self.attn(norm(x), cos_sin, kv_cache)
|
||||
x = x + self.mlp(norm(x))
|
||||
return x
|
||||
|
||||
|
||||
class GPT(nn.Module):
|
||||
def __init__(self, config: GPTConfig):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = nn.ModuleDict({
|
||||
|
|
@ -190,7 +179,7 @@ class GPT(nn.Module):
|
|||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
|
||||
def _init_weights(self, module: nn.Module):
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
# https://arxiv.org/pdf/2310.17813
|
||||
fan_out = module.weight.size(0)
|
||||
|
|
@ -203,13 +192,7 @@ class GPT(nn.Module):
|
|||
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
|
||||
|
||||
# TODO: bump base theta more, e.g. 100K is more common more recently
|
||||
def _precompute_rotary_embeddings(
|
||||
self,
|
||||
seq_len: int,
|
||||
head_dim: int,
|
||||
base: float = 10000,
|
||||
device: torch.device | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
# autodetect the device from model embeddings
|
||||
if device is None:
|
||||
device = self.transformer.wte.weight.device
|
||||
|
|
@ -225,10 +208,10 @@ class GPT(nn.Module):
|
|||
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||
return cos, sin
|
||||
|
||||
def get_device(self) -> torch.device:
|
||||
def get_device(self):
|
||||
return self.transformer.wte.weight.device
|
||||
|
||||
def estimate_flops(self) -> float:
|
||||
def estimate_flops(self):
|
||||
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
nparams_embedding = self.transformer.wte.weight.numel()
|
||||
|
|
@ -236,7 +219,7 @@ class GPT(nn.Module):
|
|||
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
|
||||
return num_flops_per_token
|
||||
|
||||
def setup_optimizers(self, unembedding_lr: float = 0.004, embedding_lr: float = 0.2, matrix_lr: float = 0.02, weight_decay: float = 0.0) -> list[torch.optim.Optimizer]:
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, _, _ = get_dist_info()
|
||||
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
|
||||
|
|
@ -267,13 +250,7 @@ class GPT(nn.Module):
|
|||
group["initial_lr"] = group["lr"]
|
||||
return optimizers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
idx: torch.Tensor,
|
||||
targets: torch.Tensor | None = None,
|
||||
kv_cache = None,
|
||||
loss_reduction: str = 'mean',
|
||||
) -> torch.Tensor | None:
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
_, T = idx.size()
|
||||
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
|
||||
|
|
@ -308,14 +285,7 @@ class GPT(nn.Module):
|
|||
return logits
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
tokens: list[int],
|
||||
max_tokens: int,
|
||||
temperature: float = 1.0,
|
||||
top_k: int | None = None,
|
||||
seed: int = 42,
|
||||
) -> Iterator[int]:
|
||||
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
|
||||
"""
|
||||
Naive autoregressive streaming inference.
|
||||
To make it super simple, let's assume:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user