Vilhelm Toivonen 2025-10-15 13:36:07 -04:00 committed by GitHub
commit 8deb27996f
6 changed files with 280 additions and 29 deletions


@@ -4,16 +4,18 @@ Notable features:
- rotary embeddings (and no positional embeddings)
- QK norm
- untied weights for token embedding and lm_head
- relu^2 activation in MLP
- relu^2 / gated MLPs with width scaling
- norm after token embedding
- no learnable params in rmsnorm
- no bias in linear layers
- Multi-Query Attention (MQA) support for more efficient inference
- Optional fused QKV projection for fewer matmuls
"""
import math
from functools import partial
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
@@ -31,6 +33,10 @@ class GPTConfig:
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (MQA)
n_embd: int = 768
use_fused_qkv: bool = True # single fused c_attn projection instead of separate c_q/c_k/c_v
mlp_type: str = "relu2" # choices: relu2, swiglu, geglu
mlp_width_mult: float = 4.0 # hidden dim = mlp_width_mult * n_embd for the relu2 MLP
mlp_glu_width_mult: Optional[float] = None # None -> mlp_width_mult * 2/3 for gated MLPs
def norm(x):
@@ -69,20 +75,31 @@ class CausalSelfAttention(nn.Module):
self.n_kv_head = config.n_kv_head
self.n_embd = config.n_embd
self.head_dim = self.n_embd // self.n_head
self.total_qkv_heads = self.n_head + 2 * self.n_kv_head
assert self.n_embd % self.n_head == 0
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.use_fused_qkv = config.use_fused_qkv
if self.use_fused_qkv:
# Single projection keeps matmul count minimal; split happens in forward pass.
self.c_attn = nn.Linear(self.n_embd, self.total_qkv_heads * self.head_dim, bias=False)
else:
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
if self.use_fused_qkv:
qkv = self.c_attn(x)
qkv = qkv.view(B, T, self.total_qkv_heads, self.head_dim)
q, k, v = torch.split(qkv, [self.n_head, self.n_kv_head, self.n_kv_head], dim=2)
else:
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
@@ -125,16 +142,63 @@ class CausalSelfAttention(nn.Module):
y = self.c_proj(y)
return y
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
fused_key = prefix + "c_attn.weight"
q_key = prefix + "c_q.weight"
k_key = prefix + "c_k.weight"
v_key = prefix + "c_v.weight"
has_fused = fused_key in state_dict
has_split = all(key in state_dict for key in (q_key, k_key, v_key))
if self.use_fused_qkv and not has_fused and has_split:
# Convert legacy split weights into the fused layout expected by this module.
q = state_dict.pop(q_key)
k = state_dict.pop(k_key)
v = state_dict.pop(v_key)
state_dict[fused_key] = torch.cat([q, k, v], dim=0)
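# The fused weight rows are laid out as [q | k | v] along dim 0, matching the split order used in forward()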
elif not self.use_fused_qkv and has_fused and not has_split:
fused = state_dict.pop(fused_key)
q_out = self.n_head * self.head_dim
kv_out = self.n_kv_head * self.head_dim
state_dict[q_key] = fused[:q_out]
state_dict[k_key] = fused[q_out:q_out + kv_out]
state_dict[v_key] = fused[q_out + kv_out:]
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
self.mlp_type = config.mlp_type.lower()
base_hidden = max(1, int(config.mlp_width_mult * config.n_embd))
if self.mlp_type in {"swiglu", "geglu"}:
width_mult = config.mlp_glu_width_mult
if width_mult is None:
# 2/3 factor keeps parameter count aligned with the ReLU^2 baseline.
width_mult = config.mlp_width_mult * (2.0 / 3.0)
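# relu2 MLP: 2*m*d^2 params (d -> m*d -> d, m = mlp_width_mult); gated MLP: 3*h*d (d -> 2h, then h -> d), so h = (2m/3)*d matches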
hidden_dim = max(1, int(width_mult * config.n_embd))
self.c_fc = nn.Linear(config.n_embd, 2 * hidden_dim, bias=False)
self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
self.hidden_dim = hidden_dim
else:
hidden_dim = base_hidden
self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
self.hidden_dim = hidden_dim
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
if self.mlp_type == "swiglu":
x1, x2 = torch.chunk(x, 2, dim=-1)
x = F.silu(x1) * x2
elif self.mlp_type == "geglu":
x1, x2 = torch.chunk(x, 2, dim=-1)
x = F.gelu(x1, approximate="tanh") * x2
else:
x = F.relu(x).square()
x = self.c_proj(x)
return x
@@ -169,8 +233,9 @@ class GPT(nn.Module):
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
self.transformer.wte.to(dtype=torch.bfloat16)
# Cast embeddings to bf16 only when CUDA is available to keep CPU inference/tests simple.
if torch.cuda.is_available():
self.transformer.wte.to(dtype=torch.bfloat16)
def init_weights(self):
self.apply(self._init_weights)
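
Taken together, the new options are all surfaced on GPTConfig. Below is a minimal sketch of constructing a model with the fused QKV projection and a SwiGLU MLP; the sizes are illustrative placeholders, not tuned defaults.

from nanochat.gpt import GPT, GPTConfig

# Illustrative config: fused QKV, MQA-style KV heads, and a SwiGLU MLP.
config = GPTConfig(
    sequence_len=1024,
    vocab_size=50304,
    n_layer=12,
    n_head=12,
    n_kv_head=4,              # fewer KV heads than query heads (MQA-style)
    n_embd=768,
    use_fused_qkv=True,       # one c_attn matmul instead of c_q/c_k/c_v
    mlp_type="swiglu",        # relu2, swiglu, or geglu
    mlp_width_mult=4.0,
    mlp_glu_width_mult=None,  # None -> 4.0 * 2/3, keeping params close to the relu2 baseline
)
model = GPT(config)
model.init_weights()

Checkpoints saved with the older split projections still load into a fused model (and vice versa) via the _load_from_state_dict conversion above.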

nanochat/schedules.py Normal file

@@ -0,0 +1,50 @@
"""Learning rate schedule utilities."""
def compute_lr_multiplier(
step: int,
total_steps: int,
*,
warmup_ratio: float = 0.0,
warmdown_ratio: float = 0.0,
final_lr_frac: float = 0.0,
) -> float:
"""Compute LR multiplier with linear warmup and warmdown phases.
The multiplier ramps linearly from 0 -> 1 during warmup, stays at 1, then
decays linearly to ``final_lr_frac`` during warmdown. Ratios are expressed
as fractions of ``total_steps``.
"""
if total_steps <= 0:
raise ValueError("total_steps must be positive")
step = min(step, total_steps)
warmup_steps = int(round(warmup_ratio * total_steps))
warmdown_steps = int(round(warmdown_ratio * total_steps))
if warmup_steps > 0 and step < warmup_steps:
return (step + 1) / warmup_steps
if warmdown_steps > 0 and step >= total_steps - warmdown_steps:
progress = (total_steps - step) / max(1, warmdown_steps)
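# progress falls linearly from 1 to 0 across the warmdown window, so the multiplier decays from 1.0 to final_lr_frac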
return progress + (1 - progress) * final_lr_frac
return 1.0
def apply_lr_multiplier(
optimizer,
multiplier: float,
*,
base_key: str = "initial_lr",
) -> float:
"""Apply ``multiplier`` to an optimizer in-place using ``base_key`` as base LR."""
for group in optimizer.param_groups:
base_lr = group.get(base_key)
if base_lr is None:
base_lr = group["lr"]
group[base_key] = base_lr
group["lr"] = base_lr * multiplier
return multiplier
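
For orientation, here is a minimal sketch of how the two helpers compose in a training loop; the optimizer, base LR, and step count are placeholders rather than values used in this repo.

import torch
from nanochat.schedules import compute_lr_multiplier, apply_lr_multiplier

# Toy setup: a single parameter, plain SGD, 1000 iterations.
param = torch.nn.Parameter(torch.zeros(10))
optimizer = torch.optim.SGD([param], lr=0.02)
num_iterations = 1000

for step in range(num_iterations + 1):
    lrm = compute_lr_multiplier(
        step,
        num_iterations,
        warmup_ratio=0.05,   # linear warmup over the first 5% of steps
        warmdown_ratio=0.2,  # linear decay over the last 20% of steps
        final_lr_frac=0.0,
    )
    apply_lr_multiplier(optimizer, lrm)  # rescales lr from each group's stored initial_lr
    # ...forward, backward, and optimizer.step() would go here...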


@@ -21,6 +21,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.schedules import compute_lr_multiplier, apply_lr_multiplier
from scripts.base_eval import evaluate_model
print_banner()
@@ -142,19 +143,12 @@ x, y = next(train_loader) # kick off load of the very first batch of data
# Learning rate scheduler
# TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
adamw_use_lr_warmup = False
adamw_warmup_ratio = 0.0
muon_use_lr_warmup = False
muon_warmup_ratio = 0.0
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
def get_lr_multiplier(it):
warmup_iters = round(warmup_ratio * num_iterations)
warmdown_iters = round(warmdown_ratio * num_iterations)
if it < warmup_iters:
return (it + 1) / warmup_iters
elif it <= num_iterations - warmdown_iters:
return 1.0
else:
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * final_lr_frac
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
@@ -265,10 +259,22 @@ for step in range(num_iterations + 1):
if grad_clip > 0.0:
torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
# step the optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
adamw_lrm = compute_lr_multiplier(
step,
num_iterations,
warmup_ratio=adamw_warmup_ratio if adamw_use_lr_warmup else 0.0,
warmdown_ratio=warmdown_ratio,
final_lr_frac=final_lr_frac,
)
muon_lrm = compute_lr_multiplier(
step,
num_iterations,
warmup_ratio=muon_warmup_ratio if muon_use_lr_warmup else 0.0,
warmdown_ratio=warmdown_ratio,
final_lr_frac=final_lr_frac,
)
apply_lr_multiplier(adamw_optimizer, adamw_lrm)
apply_lr_multiplier(muon_optimizer, muon_lrm)
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
group["momentum"] = muon_momentum
@@ -290,14 +296,15 @@ for step in range(num_iterations + 1):
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm_adamw: {adamw_lrm:.2f} | lrm_muon: {muon_lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
if step % 100 == 0:
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"train/loss": debiased_smooth_loss,
"train/lrm": lrm,
"train/adamw_lrm": adamw_lrm,
"train/muon_lrm": muon_lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,


@@ -0,0 +1,59 @@
import torch
from dataclasses import replace
from nanochat.gpt import GPT, GPTConfig
def _tiny_config(**kwargs):
base = GPTConfig(
sequence_len=16,
vocab_size=128,
n_layer=1,
n_head=4,
n_kv_head=2,
n_embd=64,
use_fused_qkv=False,
mlp_type="relu2",
)
return replace(base, **kwargs)
def test_fused_qkv_matches_legacy_split_projection():
torch.manual_seed(0)
split_cfg = _tiny_config(use_fused_qkv=False)
fused_cfg = replace(split_cfg, use_fused_qkv=True)
split_model = GPT(split_cfg)
split_model.init_weights()
state = split_model.state_dict()
fused_model = GPT(fused_cfg)
fused_model.init_weights()
fused_model.load_state_dict(state, strict=True)
tokens = torch.randint(0, split_cfg.vocab_size, (2, 5))
with torch.no_grad():
logits_split = split_model(tokens)
logits_fused = fused_model(tokens)
assert torch.allclose(logits_split, logits_fused, atol=1e-5)
def test_split_loads_from_fused_state_dict():
torch.manual_seed(1)
fused_cfg = _tiny_config(use_fused_qkv=True)
fused_model = GPT(fused_cfg)
fused_model.init_weights()
state = fused_model.state_dict()
split_cfg = replace(fused_cfg, use_fused_qkv=False)
split_model = GPT(split_cfg)
split_model.init_weights()
split_model.load_state_dict(state, strict=True)
tokens = torch.randint(0, split_cfg.vocab_size, (1, 7))
with torch.no_grad():
logits_split = split_model(tokens)
logits_fused = fused_model(tokens)
assert torch.allclose(logits_split, logits_fused, atol=1e-5)


@@ -0,0 +1,48 @@
import torch
from dataclasses import replace
from nanochat.gpt import GPTConfig, MLP
def _mlp_config(**kwargs):
cfg = GPTConfig(
n_embd=32,
mlp_type="relu2",
mlp_width_mult=4.0,
mlp_glu_width_mult=None,
n_layer=1,
n_head=4,
n_kv_head=4,
sequence_len=8,
vocab_size=128,
)
return replace(cfg, **kwargs)
def test_relu2_mlp_shape():
cfg = _mlp_config(mlp_type="relu2")
mlp = MLP(cfg)
assert mlp.c_fc.weight.shape[0] == int(cfg.mlp_width_mult * cfg.n_embd)
x = torch.randn(2, 5, cfg.n_embd)
out = mlp(x)
assert out.shape == x.shape
def test_swiglu_width_scaling_defaults_to_two_thirds():
cfg = _mlp_config(mlp_type="swiglu")
mlp = MLP(cfg)
expected_hidden = int(cfg.n_embd * cfg.mlp_width_mult * (2.0 / 3.0))
assert mlp.c_fc.weight.shape[0] == 2 * expected_hidden
x = torch.randn(4, 3, cfg.n_embd)
out = mlp(x)
assert out.shape == x.shape
def test_geglu_respects_custom_width():
cfg = _mlp_config(mlp_type="geglu", mlp_glu_width_mult=0.5)
mlp = MLP(cfg)
expected_hidden = int(cfg.n_embd * 0.5)
assert mlp.c_fc.weight.shape[0] == 2 * expected_hidden
x = torch.randn(3, 2, cfg.n_embd)
out = mlp(x)
assert out.shape == x.shape

tests/test_schedules.py Normal file

@@ -0,0 +1,22 @@
import torch
import pytest
from nanochat.schedules import compute_lr_multiplier, apply_lr_multiplier
def test_compute_lr_multiplier_handles_warmup():
multiplier = compute_lr_multiplier(0, 100, warmup_ratio=0.1)
assert multiplier == pytest.approx(0.1)
def test_compute_lr_multiplier_handles_warmdown():
multiplier = compute_lr_multiplier(95, 100, warmdown_ratio=0.1, final_lr_frac=0.1)
# progress = (100-95)/10 = 0.5 -> 0.5 + 0.5*0.1
assert multiplier == pytest.approx(0.55)
def test_apply_lr_multiplier_uses_initial_lr():
param = torch.nn.Parameter(torch.ones(()))
opt = torch.optim.SGD([param], lr=0.2)
apply_lr_multiplier(opt, 0.5)
assert opt.param_groups[0]["lr"] == pytest.approx(0.1)
assert opt.param_groups[0]["initial_lr"] == pytest.approx(0.2)
apply_lr_multiplier(opt, 1.0)
assert opt.param_groups[0]["lr"] == pytest.approx(0.2)