mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 05:35:19 +00:00
Two rounds of WeCo-guided D12 optimization, validated on D24. Key changes: smaller sliding windows (seq/8), VE every 3rd layer, RoPE 200K, smear removed, exponential residual decay, optimizer buffer pre-allocation. Mean CORE=0.2591 across 3 D24 runs.
510 lines
25 KiB
Python
510 lines
25 KiB
Python
"""
|
|
GPT model (rewrite, a lot simpler)
|
|
Notable features:
|
|
- rotary embeddings (and no positional embeddings)
|
|
- QK norm
|
|
- untied weights for token embedding and lm_head
|
|
- relu^2 activation in MLP
|
|
- norm after token embedding
|
|
- no learnable params in rmsnorm
|
|
- no bias in linear layers
|
|
- Group-Query Attention (GQA) support for more efficient inference
|
|
- Flash Attention 3 integration
|
|
"""
|
|
|
|
from functools import partial
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
|
|
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
|
|
|
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
|
from nanochat.flash_attention import flash_attn
|
|
|
|
@dataclass
class GPTConfig:
    """Hyperparameters that define the shape of a GPT model.

    window_pattern selects the attention window of each layer. It is a string
    made of the characters below, tiled across the layers:
      L = long window (full context, i.e. sequence_len)
      S = short sliding window (roughly sequence_len / 8, see GPT._compute_window_sizes)
    Examples: "L" = all full context, "SL" = alternating, "SSL" = two short then one long.
    The final layer always uses the long window, whatever the pattern says.
    """
    sequence_len: int = 2048
    vocab_size: int = 32768
    n_layer: int = 12
    n_head: int = 6  # number of query heads
    n_kv_head: int = 6  # number of key/value heads (GQA)
    n_embd: int = 768
    window_pattern: str = "SSSL"  # see class docstring
|
|
|
|
|
|
def norm(x):
|
|
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
|
|
|
|
class Linear(nn.Linear):
    """nn.Linear that casts its parameters to match the input dtype in forward.

    Replaces autocast: master weights stay fp32 for optimizer precision,
    but matmuls run in the activation dtype (typically bf16 from embeddings).

    The bias, when the layer was constructed with one, is cast the same way and
    applied. Layers built with bias=False behave as a plain weight-only matmul.
    """

    def forward(self, x):
        weight, bias = self.weight, self.bias
        # Cast only when needed so the common same-dtype path makes no copy.
        if weight.dtype != x.dtype:
            weight = weight.to(dtype=x.dtype)
        if bias is not None and bias.dtype != x.dtype:
            bias = bias.to(dtype=x.dtype)
        return F.linear(x, weight, bias)
|
|
|
|
|
|
class EmbeddingLinear(nn.Module):
    """Bias-free linear projection used for lm_head.

    forward applies the stored weight exactly as it is, with no dtype cast of
    the weight to the input dtype.
    """

    def __init__(self, in_features, out_features, bias=False, device=None, dtype=None):
        super().__init__()
        assert not bias
        self.in_features, self.out_features = in_features, out_features
        weight_shape = (out_features, in_features)
        # torch.empty leaves the values uninitialized; they must be filled in later.
        self.weight = nn.Parameter(torch.empty(weight_shape, device=device, dtype=dtype))

    def forward(self, x):
        return F.linear(x, self.weight)
|
|
|
|
|
|
def has_ve(layer_idx, n_layer):
    """Return True if the layer at layer_idx gets a Value Embedding.

    Layers are selected every 3rd layer counting back from the last one, so the
    last layer (n_layer - 1) is always included.
    """
    distance_from_last = (n_layer - 1) - layer_idx
    return distance_from_last % 3 == 0
|
|
|
|
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dim of a 4D tensor x by (cos, sin).

    The last dimension is split in the middle; element i of the first half is
    paired with element i of the second half and the pair is rotated.
    """
    assert x.ndim == 4  # multihead attention
    half = x.shape[3] // 2
    lo = x[..., :half]
    hi = x[..., half:]
    rotated_lo = lo * cos + hi * sin
    rotated_hi = hi * cos - lo * sin
    return torch.cat([rotated_lo, rotated_hi], 3)
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with GQA, rotary embeddings and QK norm.

    Layers selected by has_ve() additionally own a small gate (ve_gate) that
    mixes a value embedding into the values before attention.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        q_dim = self.n_head * self.head_dim
        kv_dim = self.n_kv_head * self.head_dim
        self.c_q = Linear(self.n_embd, q_dim, bias=False)
        self.c_k = Linear(self.n_embd, kv_dim, bias=False)
        self.c_v = Linear(self.n_embd, kv_dim, bias=False)
        self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
        # The gate only looks at the first ve_gate_channels channels of the input.
        self.ve_gate_channels = 12
        if has_ve(layer_idx, config.n_layer):
            self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False)
        else:
            self.ve_gate = None

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        B, T, _ = x.size()
        head_dim = self.head_dim

        # Project to queries, keys and values, laid out as (B, T, H, D).
        # This is FA3's native layout, so no transpose is needed.
        q = self.c_q(x).view(B, T, self.n_head, head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, head_dim)

        # Value residual (ResFormer): add the value embedding, scaled by an
        # input-dependent gate per kv head.
        if ve is not None:
            ve = ve.view(B, T, self.n_kv_head, head_dim)
            gate_input = x[..., :self.ve_gate_channels]
            gate = 3 * torch.sigmoid(self.ve_gate(gate_input))  # (B, T, n_kv_head), range (0, 3)
            v = v + gate.unsqueeze(-1) * ve

        # Rotary embeddings give queries and keys relative positional encoding,
        # then QK norm is applied.
        cos, sin = cos_sin
        q = norm(apply_rotary_emb(q, cos, sin))
        k = norm(apply_rotary_emb(k, cos, sin))
        # sharper attention (the scale is split between Q and K), TODO think through better
        q = q * 1.2
        k = k * 1.2

        # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere).
        # window_size is a (left, right) tuple: (N, 0) for causal, (-1, 0) for full context.
        if kv_cache is None:
            # Training: causal attention with optional sliding window
            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
        else:
            # Inference: flash_attn_with_kvcache handles the cache management
            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
            y = flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache,
                k=k, v=v,
                cache_seqlens=kv_cache.cache_seqlens,
                causal=True,
                window_size=window_size,
            )
            # The cache position is advanced once per step, by the last layer only.
            is_last_layer = self.layer_idx == kv_cache.n_layers - 1
            if is_last_layer:
                kv_cache.advance(T)

        # Merge the heads again and project back to the residual stream
        merged = y.contiguous().view(B, T, -1)
        return self.c_proj(merged)
|
|
|
|
|
|
class MLP(nn.Module):
    """Feed-forward block: expand to 4x width, relu^2 activation, project back."""

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = Linear(config.n_embd, hidden_dim, bias=False)
        self.c_proj = Linear(hidden_dim, config.n_embd, bias=False)

    def forward(self, x):
        hidden = self.c_fc(x)
        activated = F.relu(hidden).square()  # relu^2
        return self.c_proj(activated)
|
|
|
|
|
|
class Block(nn.Module):
    """One Transformer layer: attention then MLP, each applied to the normed
    input and added back onto the residual stream."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        attn_out = self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        return x + mlp_out
|
|
|
|
|
|
class GPT(nn.Module):
|
|
def __init__(self, config, pad_vocab_size_to=64):
    """
    Build the module structure of the model (no real parameter values yet).

    NOTE a major footgun: this __init__ function runs in meta device context (!!)
    Therefore, any calculations inside here are shapes and dtypes only, no actual data.
    => All data (parameters, buffers, etc.) is actually initialized in init_weights() instead.
    """
    super().__init__()
    self.config = config
    # Per-layer (left, right) window tuples for sliding window attention
    self.window_sizes = self._compute_window_sizes(config)

    # Round the vocab size up to a multiple of pad_vocab_size_to for efficiency (DDP, tensor cores).
    # This is just an optimization - outputs are cropped back to vocab_size in forward().
    # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
    num_chunks = (config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to
    padded_vocab_size = num_chunks * pad_vocab_size_to
    if padded_vocab_size != config.vocab_size:
        print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")

    token_embedding = nn.Embedding(padded_vocab_size, config.n_embd)
    blocks = nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)])
    self.transformer = nn.ModuleDict({"wte": token_embedding, "h": blocks})
    self.lm_head = EmbeddingLinear(config.n_embd, padded_vocab_size, bias=False)

    # Per-layer learnable scalars (inspired by modded-nanogpt), used at the top of each layer in forward():
    #   resid_lambdas scale the residual stream, x0_lambdas blend the initial embedding back in.
    # They are separate parameters so they can have different optimizer treatment.
    # The values below are placeholders, the real init happens in init_weights().
    self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
    self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
    # Backout: subtract cached mid-layer residual before final norm to remove low-level features
    self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))

    # Value embeddings (ResFormer-style): one table per layer selected by has_ve()
    head_dim = config.n_embd // config.n_head
    kv_dim = config.n_kv_head * head_dim
    ve_layers = [i for i in range(config.n_layer) if has_ve(i, config.n_layer)]
    self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in ve_layers})

    # To support meta device initialization the rotary embeddings are created here, but as "fake" meta tensors only.
    # They are pretty small/cheap in memory, so over-compute them by 10X and assert fail in forward() if that
    # amount is ever reached. In the future the cache could grow dynamically, for now it's fine.
    self.rotary_seq_len = config.sequence_len * 10  # TODO make nicer?
    cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
    # persistent=False means the buffers are not saved to the checkpoint
    self.register_buffer("cos", cos, persistent=False)
    self.register_buffer("sin", sin, persistent=False)
|
|
|
|
@torch.no_grad()
def init_weights(self):
    """
    Initialize the full model in this one function for maximum clarity.

    With s = sqrt(3) / sqrt(n_embd), i.e. Uniform(-s, s) has std 1/sqrt(n_embd):

    wte (embedding):     normal, std=0.8
    lm_head:             normal, std=0.001
    for each block:
        attn.c_q:        uniform(-s, s)
        attn.c_k:        uniform(-s, s)
        attn.c_v:        uniform(-0.85*s, 0.85*s)
        attn.c_proj:     uniform(-0.008, 0.008)
        mlp.c_fc:        uniform(-0.4*s, 0.4*s)
        mlp.c_proj:      zeros
    resid_lambdas:       exponential decay from 1.18 (first layer) to 1.06 (last layer)
    x0_lambdas:          linear from 0.24 to 0.08 over the first half of the layers, then 0
    value embeddings:    uniform(-s, s)
    attn.ve_gate:        uniform(0, 0.02)

    Also recomputes the rotary cos/sin buffers and, unless COMPUTE_DTYPE is fp16,
    casts the embedding tables and lm_head to COMPUTE_DTYPE.
    """
    import math

    normal_ = torch.nn.init.normal_
    uniform_ = torch.nn.init.uniform_

    # Embedding and unembedding
    normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
    normal_(self.lm_head.weight, mean=0.0, std=0.001)

    # Transformer blocks. Weights use Uniform to avoid outliers; the sqrt(3)
    # multiplier makes sure Uniform(-s, s) achieves the same std as Normal(std=n_embd**-0.5).
    n_embd = self.config.n_embd
    s = 3**0.5 * n_embd**-0.5
    for block in self.transformer.h:
        attn, mlp = block.attn, block.mlp
        uniform_(attn.c_q.weight, -s, s)
        uniform_(attn.c_k.weight, -s, s)
        uniform_(attn.c_v.weight, -0.85 * s, 0.85 * s)
        uniform_(attn.c_proj.weight, -0.008, 0.008)  # small nonzero init
        uniform_(mlp.c_fc.weight, -s * 0.4, s * 0.4)  # 0.4x init scale for c_fc
        torch.nn.init.zeros_(mlp.c_proj.weight)

    # Per-layer scalars
    n_layer = self.config.n_layer
    # resid init: exponential decay, stronger at early layers
    resid_start, resid_end = 1.18, 1.06
    resid_decay = math.log(resid_start / resid_end) / max(n_layer - 1, 1)
    for i in range(n_layer):
        self.resid_lambdas.data[i] = resid_start * math.exp(-resid_decay * i)
    # x0 init: first-half only, linearly decaying, zero for deep layers
    half_depth = max(1, n_layer // 2)
    for i in range(n_layer):
        if i >= half_depth:
            self.x0_lambdas.data[i] = 0.0
            continue
        frac = i / max(half_depth - 1, 1)
        self.x0_lambdas.data[i] = 0.24 * (1.0 - frac) + 0.08 * frac

    # Value embeddings: uniform with the full bound s
    for ve in self.value_embeds.values():
        uniform_(ve.weight, -s, s)

    # Gate weights start as small positive values
    for block in self.transformer.h:
        gate = block.attn.ve_gate
        if gate is not None:
            uniform_(gate.weight, 0.0, 0.02)

    # Rotary embeddings
    head_dim = self.config.n_embd // self.config.n_head
    self.cos, self.sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)

    # Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
    # embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
    # because GradScaler cannot unscale fp16 gradients.
    if COMPUTE_DTYPE != torch.float16:
        for module in (self.transformer.wte, self.lm_head, *self.value_embeds.values()):
            module.to(dtype=COMPUTE_DTYPE)
|
|
|
|
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=200000, device=None):
    """
    Build the rotary cos/sin tables for positions 0 .. seq_len-1.

    Each returned tensor has one column per even channel index of head_dim and
    is shaped (1, seq_len, 1, channels) so it broadcasts over batch and heads.
    The tables are computed in fp32 and then cast to COMPUTE_DTYPE.
    If device is None, the device of the token embedding weights is used.
    """
    # autodetect the device from model embeddings
    if device is None:
        device = self.transformer.wte.weight.device
    # one inverse frequency per strided channel: base ** -(channel / head_dim)
    even_channels = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (base ** (even_channels / head_dim))
    # rotation angle at each (time, channel) pair
    positions = torch.arange(seq_len, dtype=torch.float32, device=device)
    angles = torch.outer(positions, inv_freq)
    tables = []
    for table in (angles.cos(), angles.sin()):
        table = table.to(COMPUTE_DTYPE)
        tables.append(table[None, :, None, :])  # add batch and head dims for later broadcasting
    cos, sin = tables
    return cos, sin
|
|
|
|
def _compute_window_sizes(self, config):
    """
    Compute per-layer window sizes for sliding window attention.

    Returns a list with one (left, right) tuple per layer, in the format of
    FA3's window_size parameter:
    - left: how many tokens before current position to attend to
    - right: how many tokens after current position to attend to (always 0 here, causal)

    The pattern string (case-insensitive) is tiled across layers.
    Final layer always gets L (full context).
    Characters:
    - L = long: the full context, left = sequence_len
    - S = short: sequence_len / 8 rounded up to a multiple of 128, but at least 256

    Raises AssertionError if the pattern is empty or contains other characters.
    Returns an empty list when config.n_layer is 0.
    """
    pattern = config.window_pattern.upper()
    # An empty pattern cannot be tiled (it would divide by zero below), so reject it up front.
    assert pattern, "Invalid window_pattern: it must not be empty. Use only S and L."
    assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
    # Map characters to window sizes
    long_window = config.sequence_len
    short_window = max(256, -(-long_window // 8 // 128) * 128)  # ceil to FA3 tile size (2048 -> 256)
    char_to_window = {
        "L": (long_window, 0),
        "S": (short_window, 0),
    }
    # Tile pattern across layers
    window_sizes = [char_to_window[pattern[layer_idx % len(pattern)]] for layer_idx in range(config.n_layer)]
    # Final layer always gets full context (nothing to override when there are no layers)
    if window_sizes:
        window_sizes[-1] = (long_window, 0)
    return window_sizes
|
|
|
|
def get_device(self):
    """Return the device that holds the token embedding weights."""
    token_embedding = self.transformer.wte
    return token_embedding.weight.device
|
|
|
|
def estimate_flops(self):
    """
    Return the estimated FLOPs per token for the model (forward + backward).

    Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
    Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
    On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
    With sliding windows, effective_seq_len varies per layer (capped by window size).
    Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
    This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
    - Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
    - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
    """
    total_params = sum(p.numel() for p in self.parameters())
    # Non-matmul params are left out: the embedding tables and the per-layer scalars
    embedding_numel = self.transformer.wte.weight.numel()
    embedding_numel += sum(ve.weight.numel() for ve in self.value_embeds.values())
    scalar_numel = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.backout_lambda.numel()
    matmul_params = total_params - (embedding_numel + scalar_numel)

    cfg = self.config
    n_head = cfg.n_head
    head_dim = cfg.n_embd // cfg.n_head
    seq_len = cfg.sequence_len
    # Total attended length over all layers. Each entry of window_sizes is a
    # (left, right) tuple; only left is used, and a negative left means the full sequence.
    attended = sum(seq_len if ws[0] < 0 else min(ws[0], seq_len) for ws in self.window_sizes)
    attn_flops = 12 * n_head * head_dim * attended
    return 6 * matmul_params + attn_flops
|
|
|
|
def num_scaling_params(self):
    """
    Return detailed parameter counts for scaling law analysis.

    Different papers use different conventions:
    - Kaplan et al. excluded embedding parameters
    - Chinchilla included all parameters
    Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
    Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)

    Returns a dict with the keys 'wte', 'value_embeds', 'lm_head',
    'transformer_matrices', 'scalars' and 'total', so downstream analysis
    can experiment with which combination gives the cleanest scaling laws.
    Asserts that the groups add up to the total number of parameters.
    """
    def count(params):
        return sum(p.numel() for p in params)

    # Count each group separately (mirrors the grouping in setup_optimizers)
    counts = {
        'wte': count(self.transformer.wte.parameters()),
        'value_embeds': count(self.value_embeds.parameters()),
        'lm_head': count(self.lm_head.parameters()),
        'transformer_matrices': count(self.transformer.h.parameters()),
        'scalars': self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.backout_lambda.numel(),
    }
    counts['total'] = sum(counts.values())
    assert counts['total'] == count(self.parameters()), "Parameter count mismatch"
    return counts
|
|
|
|
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
    """
    Build the combined optimizer over all model parameters.

    AdamW groups: lm_head, token embedding, value embeddings and the per-layer scalars.
    Muon groups: every parameter inside the transformer blocks, one group per distinct shape.
    The weight_decay argument is applied to the Muon groups only; the AdamW groups
    use the fixed values written below.
    Every group of the returned optimizer also gets an "initial_lr" entry equal to its lr.
    """
    model_dim = self.config.n_embd
    ddp, _rank, _local_rank, _world_size = get_dist_info()

    # Separate out all parameters into groups
    matrix_params = list(self.transformer.h.parameters())
    value_embeds_params = list(self.value_embeds.parameters())
    embedding_params = list(self.transformer.wte.parameters())
    lm_head_params = list(self.lm_head.parameters())
    resid_params = [self.resid_lambdas]
    x0_params = [self.x0_lambdas]
    backout_params = [self.backout_lambda]
    # Every parameter has to land in exactly one of the lists above
    assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(backout_params)

    # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
    dmodel_lr_scale = (model_dim / 768) ** -0.5
    print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")

    def adamw_group(params, lr, betas, weight_decay):
        # All AdamW groups share eps; every required field is spelled out explicitly.
        return dict(kind='adamw', params=params, lr=lr, betas=betas, eps=1e-10, weight_decay=weight_decay)

    # AdamW groups (embeddings, lm_head, scalars)
    param_groups = [
        adamw_group(lm_head_params, unembedding_lr * dmodel_lr_scale, (0.8, 0.96), 0.01),
        adamw_group(embedding_params, embedding_lr * dmodel_lr_scale, (0.8, 0.995), 0.001),
        adamw_group(value_embeds_params, embedding_lr * dmodel_lr_scale * 0.5, (0.8, 0.995), 0.01),
        adamw_group(resid_params, scalar_lr * 0.01, (0.8, 0.95), 0.05),
        adamw_group(x0_params, scalar_lr, (0.96, 0.95), 0.0),  # higher beta1 for x0
        adamw_group(backout_params, 0.15, (0.8, 0.95), 0.0),
    ]

    # Muon groups (matrix params, grouped by shape for stacking)
    params_by_shape = {}
    for p in matrix_params:
        params_by_shape.setdefault(p.shape, []).append(p)
    for shape in sorted(params_by_shape):
        param_groups.append(dict(
            kind='muon', params=params_by_shape[shape], lr=matrix_lr,
            momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
        ))

    Factory = DistMuonAdamW if ddp else MuonAdamW
    optimizer = Factory(param_groups)
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]
    return optimizer
|
|
|
|
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
    """
    Run the model on a (B, T) tensor of token ids.

    If targets is given, returns the cross-entropy loss between the logits and
    targets (entries equal to -1 are ignored, reduction is loss_reduction).
    Otherwise returns the fp32, softcapped logits cropped to config.vocab_size.

    If kv_cache is given, the rotary embeddings are offset by kv_cache.get_pos()
    and the cache is handed to every block.

    Raises AssertionError if the positions do not fit in the rotary cache, or if
    the rotary buffers are on the wrong device or in the wrong dtype.
    """
    B, T = idx.size()

    # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
    T0 = 0 if kv_cache is None else kv_cache.get_pos()
    rotary_len = self.cos.size(1)
    # The slice taken below ends at T0+T, so the bound has to include the cache offset:
    # slicing past the end of the buffer would silently return fewer than T positions.
    assert T0 + T <= rotary_len, f"Sequence length grew beyond the rotary embeddings cache: {T0 + T} > {rotary_len}"
    assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
    assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
    # Grab the rotary embeddings for the current positions
    cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]

    # Embed the tokens
    x = self.transformer.wte(idx)
    x = x.to(COMPUTE_DTYPE)  # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
    x = norm(x)

    # Forward the trunk of the Transformer
    x0 = x  # save initial normalized embedding for x0 residual
    n_layer = self.config.n_layer
    backout_layer = n_layer // 2  # cache at halfway point
    x_backout = None
    for i, block in enumerate(self.transformer.h):
        # rescale the residual stream and blend the initial embedding back in
        x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
        # only some layers have a value embedding table
        ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
        x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
        if i == backout_layer:
            x_backout = x
    # Subtract mid-layer residual to remove low-level features before logit projection
    if x_backout is not None:
        x = x - self.backout_lambda.to(x.dtype) * x_backout
    x = norm(x)

    # Forward the lm_head (compute logits)
    softcap = 15  # smoothly cap the logits to the range [-softcap, softcap]
    logits = self.lm_head(x)  # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
    logits = logits[..., :self.config.vocab_size]  # slice to remove padding
    logits = logits.float()  # switch to fp32 for logit softcap and loss computation
    logits = softcap * torch.tanh(logits / softcap)  # squash the logits

    if targets is not None:
        # training: given the targets, compute and return the loss
        # TODO experiment with chunked cross-entropy?
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
        return loss
    else:
        # inference: just return the logits directly
        return logits
|
|
|
|
@torch.inference_mode()
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
    """
    Naive autoregressive streaming inference.

    To make it super simple, let's assume:
    - batch size is 1
    - ids and the yielded tokens are simple Python lists and ints

    Yields max_tokens token ids one at a time. Each step runs forward() on the
    whole sequence so far. With temperature > 0 the next token is sampled
    (seeded by seed, optionally restricted to the top_k logits); otherwise the
    argmax token is taken.
    """
    assert isinstance(tokens, list)
    device = self.get_device()
    sampling = temperature > 0
    rng = None
    if sampling:
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
    ids = torch.tensor([tokens], dtype=torch.long, device=device)  # add batch dim
    for _ in range(max_tokens):
        logits = self.forward(ids)[:, -1, :]  # (B, vocab_size), last position only
        if top_k is not None and top_k > 0:
            # everything below the k-th largest logit is masked out
            kth_largest = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits[logits < kth_largest] = -float('Inf')
        if sampling:
            probs = F.softmax(logits / temperature, dim=-1)
            next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
        else:
            next_ids = torch.argmax(logits, dim=-1, keepdim=True)
        ids = torch.cat((ids, next_ids), dim=1)
        yield next_ids.item()
|