This commit is contained in:
Jesse Clark 2026-04-28 04:14:37 +00:00 committed by GitHub
commit cde6560bee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 286 additions and 79 deletions

View File

@ -102,11 +102,14 @@ class KVCache:
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
# Previous token's normalized embedding for smear (set by model forward pass) # Previous token's normalized embedding for smear (set by model forward pass)
self.prev_embedding = None self.prev_embedding = None
# Previous token ids for bigram hash features during decoding
self.prev_token_ids = None
def reset(self):
    """Return the cache to its empty state.

    Drops the cached previous-token state (normalized embedding used by smear
    and token ids used by the bigram hash) and zeroes the per-sequence length
    counters in place.
    """
    self.prev_token_ids = None
    self.prev_embedding = None
    # In-place zeroing keeps the same tensor object (and device/dtype) alive.
    self.cache_seqlens.zero_()
def get_pos(self): def get_pos(self):
"""Get current position (assumes all batch elements at same position).""" """Get current position (assumes all batch elements at same position)."""
@ -135,6 +138,8 @@ class KVCache:
# Copy smear state: expand batch=1 prev_embedding to num_samples # Copy smear state: expand batch=1 prev_embedding to num_samples
if other.prev_embedding is not None: if other.prev_embedding is not None:
self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone() self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()
if other.prev_token_ids is not None:
self.prev_token_ids = other.prev_token_ids.expand(self.batch_size).clone()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@torch.inference_mode() @torch.inference_mode()

View File

@ -37,6 +37,37 @@ class GPTConfig:
# Characters: L=long (full context), S=short (quarter context) # Characters: L=long (full context), S=short (quarter context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "SSSL" window_pattern: str = "SSSL"
# Sparse funnel attention: local layers use a short window, global layers use full context.
use_sparse_funnel: bool = True
n_global: int = 0 # number of full-context (global) layers, 0 = auto depth//4
chirp_gamma: float = 0.7 # chirp exponent for global placement (deeper-biased)
local_window: int = 128 # local attention window size
global_layer_override: tuple = () # explicit global layer indices (overrides chirp)
rope_base: int = 200000 # RoPE base frequency
use_smear: bool = True # cheap bigram-like embedding mixing
smear_channels: int = 24 # number of early embedding channels used to predict smear gate
bigram_vocab_size: int = 4096 # hashed bigram embedding buckets (0 disables)
bigram_dim: int = 128 # bigram embedding dim before projection
use_gated_attn: bool = True # learn a lightweight per-head gate on attention output
attn_gate_channels: int = 12 # channels used to predict attention gate
ve_layers: tuple = () # explicit layer indices for value residuals (empty = auto global layers under sparse funnel)
def default_sparse_n_global(n_layer):
    """Default count of global (full-context) layers: a quarter of the depth, never below one."""
    quarter_depth = n_layer // 4
    return quarter_depth if quarter_depth > 1 else 1
def compute_sparse_global_layers(n_layer, n_global, chirp_gamma, global_layer_override=()):
    """Choose which layer indices use full-context (global) attention.

    Args:
        n_layer: total number of transformer layers.
        n_global: how many global layers to place; a value <= 0 falls back to
            default_sparse_n_global(n_layer).
        chirp_gamma: exponent of the placement curve; step i of G maps to
            index int(n_layer * (i / G) ** chirp_gamma) - 1, clamped to
            [0, n_layer - 1].
        global_layer_override: explicit layer indices; when non-empty the
            placement curve is skipped entirely.

    Returns:
        Sorted tuple of unique layer indices.
    """
    if global_layer_override:
        chosen = set(global_layer_override)
    else:
        count = n_global if n_global > 0 else default_sparse_n_global(n_layer)
        last = n_layer - 1
        # Several steps may land on the same index; the set collapses them.
        chosen = {
            max(0, min(last, int(n_layer * (step / count) ** chirp_gamma) - 1))
            for step in range(1, count + 1)
        }
    # NOTE(review): the final layer is always forced in, for both the override
    # and the computed placement. Indentation was lost in the reviewed view, so
    # confirm this statement belongs after the if/else.
    chosen.add(n_layer - 1)
    return tuple(sorted(chosen))
def norm(x): def norm(x):
@ -47,11 +78,58 @@ class Linear(nn.Linear):
Replaces autocast: master weights stay fp32 for optimizer precision, Replaces autocast: master weights stay fp32 for optimizer precision,
but matmuls run in the activation dtype (typically bf16 from embeddings).""" but matmuls run in the activation dtype (typically bf16 from embeddings)."""
def forward(self, x):
    """Apply the bias-free linear map to x, running the matmul in x's dtype.

    The stored weight is cast to x's dtype only when the two differ, so no
    cast is issued when they already match.
    """
    weight = self.weight if self.weight.dtype == x.dtype else self.weight.to(dtype=x.dtype)
    return F.linear(x, weight)
class EmbeddingLinear(nn.Module):
    """Lightweight bias-free linear layer for lm_head that applies its weight without a dtype cast."""

    def __init__(self, in_features, out_features, bias=False, device=None, dtype=None):
        super().__init__()
        # Only the bias-free form is supported.
        assert not bias
        self.in_features, self.out_features = in_features, out_features
        # Allocated uninitialized; values are expected to be set by the model's weight init.
        weight_shape = (out_features, in_features)
        self.weight = nn.Parameter(torch.empty(weight_shape, device=device, dtype=dtype))

    def forward(self, x):
        """Return x @ weight.T using the weight exactly as stored."""
        return F.linear(x, self.weight)
class BigramHashEmbedding(nn.Module):
    """Hash causal token pairs into a compact learned embedding.

    Each (previous token, current token) pair is hashed into one of
    ``bigram_vocab_size - 1`` buckets. The last bucket (index
    ``bigram_vocab_size - 1``) is reserved for positions whose previous token
    is unknown. The looked-up embedding is optionally projected to
    ``model_dim`` and multiplied by a learned scalar.
    """

    def __init__(self, bigram_vocab_size, bigram_dim, model_dim):
        """
        Args:
            bigram_vocab_size: number of embedding rows; must be >= 2 (at
                least one hash bucket plus the reserved "no previous token" row).
            bigram_dim: width of the hashed embedding table.
            model_dim: width of the output; a projection is added only when it
                differs from bigram_dim.

        Raises:
            ValueError: if bigram_vocab_size < 2, which would make the hash
                modulus zero.
        """
        super().__init__()
        if bigram_vocab_size < 2:
            raise ValueError(
                "bigram_vocab_size must be >= 2 (one hash bucket plus the reserved "
                f"no-context bucket), got {bigram_vocab_size}"
            )
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        self.proj = Linear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    @staticmethod
    def _hash_pairs(cur, prev, mod):
        """Mix int32 current/previous token ids and reduce them modulo `mod`."""
        # Tensor `%` follows the sign of the divisor, so the result is a valid
        # non-negative bucket index even if the int32 products wrap around.
        return torch.bitwise_xor(36313 * cur, 27191 * prev) % mod

    def bigram_hash(self, tokens, prev_tokens=None):
        """Return int64 bucket indices with the same shape as `tokens`.

        Args:
            tokens: integer token ids; indexed as (batch, time) here.
            prev_tokens: optional ids of the token preceding tokens[:, 0], one
                per batch row. When None, position 0 gets the reserved bucket.
        """
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        # Start with every position in the reserved bucket, then overwrite the
        # positions whose previous token is known.
        out = torch.full_like(t, mod)
        if prev_tokens is not None:
            prev = prev_tokens.to(torch.int32).view(t.size(0), 1)
            out[..., 0] = self._hash_pairs(t[..., 0], prev[..., 0], mod)
        if t.size(1) > 1:
            out[..., 1:] = self._hash_pairs(t[..., 1:], t[..., :-1], mod)
        return out.long()

    def forward(self, token_ids, prev_tokens=None):
        """Embed the hashed bigrams of `token_ids`, project if needed, and apply the learned scale."""
        h = self.embed(self.bigram_hash(token_ids, prev_tokens=prev_tokens))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)
def has_ve(layer_idx, n_layer, layer_set=()):
    """Return True if the GPT layer at layer_idx should have a Value Embedding.

    A non-empty layer_set decides by membership. Otherwise layers alternate,
    picking those whose parity matches the last layer's, so the last layer is
    always included.
    """
    if not layer_set:
        last_layer_parity = (n_layer - 1) % 2
        return layer_idx % 2 == last_layer_parity
    return layer_idx in layer_set
def apply_rotary_emb(x, cos, sin): def apply_rotary_emb(x, cos, sin):
@ -77,7 +155,13 @@ class CausalSelfAttention(nn.Module):
self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = Linear(self.n_embd, self.n_embd, bias=False) self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
self.ve_gate_channels = 12 self.ve_gate_channels = 12
self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None self.ve_gate = Linear(
self.ve_gate_channels,
self.n_kv_head,
bias=False,
) if has_ve(layer_idx, config.n_layer, layer_set=config.ve_layers) else None
self.attn_gate_channels = config.attn_gate_channels
self.attn_gate = Linear(self.attn_gate_channels, self.n_head, bias=False) if config.use_gated_attn else None
def forward(self, x, ve, cos_sin, window_size, kv_cache): def forward(self, x, ve, cos_sin, window_size, kv_cache):
B, T, C = x.size() B, T, C = x.size()
@ -91,7 +175,7 @@ class CausalSelfAttention(nn.Module):
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
if ve is not None: if ve is not None:
ve = ve.view(B, T, self.n_kv_head, self.head_dim) ve = ve.view(B, T, self.n_kv_head, self.head_dim)
gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 3) gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels]))
v = v + gate.unsqueeze(-1) * ve v = v + gate.unsqueeze(-1) * ve
# Apply Rotary Embeddings to queries and keys to get relative positional encoding # Apply Rotary Embeddings to queries and keys to get relative positional encoding
@ -120,6 +204,10 @@ class CausalSelfAttention(nn.Module):
if self.layer_idx == kv_cache.n_layers - 1: if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T) kv_cache.advance(T)
if self.attn_gate is not None:
gate = 2 * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_channels]))
y = y * gate.unsqueeze(-1).to(dtype=y.dtype)
# Re-assemble the heads and project back to residual stream # Re-assemble the heads and project back to residual stream
y = y.contiguous().view(B, T, -1) y = y.contiguous().view(B, T, -1)
y = self.c_proj(y) y = self.c_proj(y)
@ -160,11 +248,16 @@ class GPT(nn.Module):
""" """
super().__init__() super().__init__()
self.config = config self.config = config
if config.use_sparse_funnel and config.n_global <= 0:
config.n_global = default_sparse_n_global(config.n_layer)
if config.use_sparse_funnel and not config.ve_layers:
config.ve_layers = compute_sparse_global_layers(
config.n_layer, config.n_global, config.chirp_gamma, config.global_layer_override
)
# Compute per-layer window sizes for sliding window attention # Compute per-layer window sizes for sliding window attention
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
self.window_sizes = self._compute_window_sizes(config) self.window_sizes = self._compute_window_sizes(config)
# Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward(). # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
if padded_vocab_size != config.vocab_size: if padded_vocab_size != config.vocab_size:
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency") print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
@ -172,22 +265,31 @@ class GPT(nn.Module):
"wte": nn.Embedding(padded_vocab_size, config.n_embd), "wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
}) })
self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False) self.bigram = BigramHashEmbedding(config.bigram_vocab_size, config.bigram_dim, config.n_embd) if config.bigram_vocab_size > 0 else None
self.lm_head = EmbeddingLinear(config.n_embd, padded_vocab_size, bias=False)
# Per-layer learnable scalars (inspired by modded-nanogpt) # Per-layer learnable scalars (inspired by modded-nanogpt)
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral) # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled) # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
# Separate parameters so they can have different optimizer treatment # Separate parameters so they can have different optimizer treatment
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
# Smear: mix previous token's embedding into current token (cheap bigram-like info) # Optional smear: mix previous token's embedding into current token before the trunk
self.smear_gate = Linear(24, 1, bias=False) if config.use_smear:
self.smear_lambda = nn.Parameter(torch.zeros(1)) self.smear_gate = Linear(config.smear_channels, 1, bias=False)
self.smear_lambda = nn.Parameter(torch.zeros(1))
else:
self.smear_gate = None
self.smear_lambda = None
# Backout: subtract cached mid-layer residual before final norm to remove low-level features # Backout: subtract cached mid-layer residual before final norm to remove low-level features
self.backout_lambda = nn.Parameter(0.2 * torch.ones(1)) self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
# Value embeddings (ResFormer-style): alternating layers, last layer always included # Value embeddings (ResFormer-style): alternating layers, last layer always included
head_dim = config.n_embd // config.n_head head_dim = config.n_embd // config.n_head
kv_dim = config.n_kv_head * head_dim kv_dim = config.n_kv_head * head_dim
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)}) self.value_embeds = nn.ModuleDict({
str(i): nn.Embedding(padded_vocab_size, kv_dim)
for i in range(config.n_layer)
if has_ve(i, config.n_layer, layer_set=config.ve_layers)
})
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount. # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@ -217,31 +319,41 @@ class GPT(nn.Module):
# Embedding and unembedding # Embedding and unembedding
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8) torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001) torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
if self.bigram is not None:
torch.nn.init.zeros_(self.bigram.embed.weight)
if self.bigram.proj is not None:
bigram_s = 3**0.5 * self.bigram.embed.embedding_dim ** -0.5
torch.nn.init.uniform_(self.bigram.proj.weight, -bigram_s, bigram_s)
self.bigram.scale.data.fill_(0.05)
if self.smear_gate is not None:
torch.nn.init.zeros_(self.smear_gate.weight)
self.smear_lambda.data.zero_()
self.backout_lambda.data.fill_(0.2)
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal) # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
n_embd = self.config.n_embd n_embd = self.config.n_embd
s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
for block in self.transformer.h: for block in self.transformer.h:
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s) torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s) torch.nn.init.uniform_(block.attn.c_v.weight, -0.85 * s, 0.85 * s)
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero torch.nn.init.uniform_(block.attn.c_proj.weight, -0.008, 0.008)
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4)
torch.nn.init.zeros_(block.mlp.c_proj.weight) torch.nn.init.zeros_(block.mlp.c_proj.weight)
# Per-layer scalars # Per-layer scalars
# Per-layer resid init: stronger residual at early layers, weaker at deep layers import math
n_layer = self.config.n_layer n_layer = self.config.n_layer
resid_start, resid_end = 1.18, 1.06
resid_decay = math.log(resid_start / resid_end) / max(n_layer - 1, 1)
half_depth = max(1, n_layer // 2)
for i in range(n_layer): for i in range(n_layer):
self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1)) self.resid_lambdas.data[i] = resid_start * math.exp(-resid_decay * i)
# Decaying x0 init: earlier layers get more input embedding blending if i < half_depth:
for i in range(n_layer): frac = i / max(half_depth - 1, 1)
self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1)) self.x0_lambdas.data[i] = 0.24 * (1.0 - frac) + 0.08 * frac
else:
# Smear/backout scalars and smear gate must be explicitly initialized self.x0_lambdas.data[i] = 0.0
torch.nn.init.zeros_(self.smear_lambda)
torch.nn.init.constant_(self.backout_lambda, 0.2)
torch.nn.init.uniform_(self.smear_gate.weight, 0.0, 0.02)
# Value embeddings (init like c_v: uniform with same std) # Value embeddings (init like c_v: uniform with same std)
for ve in self.value_embeds.values(): for ve in self.value_embeds.values():
@ -251,6 +363,8 @@ class GPT(nn.Module):
for block in self.transformer.h: for block in self.transformer.h:
if block.attn.ve_gate is not None: if block.attn.ve_gate is not None:
torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02) torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)
if block.attn.attn_gate is not None:
torch.nn.init.uniform_(block.attn.attn_gate.weight, 0.0, 0.02)
# Rotary embeddings # Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head head_dim = self.config.n_embd // self.config.n_head
@ -262,11 +376,17 @@ class GPT(nn.Module):
# because GradScaler cannot unscale fp16 gradients. # because GradScaler cannot unscale fp16 gradients.
if COMPUTE_DTYPE != torch.float16: if COMPUTE_DTYPE != torch.float16:
self.transformer.wte.to(dtype=COMPUTE_DTYPE) self.transformer.wte.to(dtype=COMPUTE_DTYPE)
self.lm_head.to(dtype=COMPUTE_DTYPE)
for ve in self.value_embeds.values(): for ve in self.value_embeds.values():
ve.to(dtype=COMPUTE_DTYPE) ve.to(dtype=COMPUTE_DTYPE)
if self.bigram is not None:
self.bigram.embed.to(dtype=COMPUTE_DTYPE)
if self.bigram.proj is not None:
self.bigram.proj.to(dtype=COMPUTE_DTYPE)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None): def _precompute_rotary_embeddings(self, seq_len, head_dim, base=None, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently if base is None:
base = self.config.rope_base
# autodetect the device from model embeddings # autodetect the device from model embeddings
if device is None: if device is None:
device = self.transformer.wte.weight.device device = self.transformer.wte.weight.device
@ -285,31 +405,42 @@ class GPT(nn.Module):
def _compute_window_sizes(self, config):
    """
    Compute per-layer window sizes for sliding window attention.

    Returns a list with one (left, right) tuple per layer, for FA3's window_size parameter:
    - left: how many tokens before the current position to attend to
    - right: how many tokens after the current position to attend to (always 0, causal)

    Supports two modes:
    1. Sparse funnel (config.use_sparse_funnel): the layers returned by
       compute_sparse_global_layers get the full context (config.sequence_len),
       every other layer gets config.local_window.
    2. Pattern-based: config.window_pattern is tiled across layers
       (L=full context, S=short window); the final layer always gets full context.

    Side effect: sets self._global_layers to the sorted list of global layer
    indices (an empty list in pattern-based mode).
    """
    n_layer = config.n_layer
    full_window = (config.sequence_len, 0)

    if config.use_sparse_funnel:
        global_layers = set(compute_sparse_global_layers(
            n_layer, config.n_global, config.chirp_gamma, config.global_layer_override
        ))
        self._global_layers = sorted(global_layers)
        local_window = (config.local_window, 0)
        return [full_window if layer_idx in global_layers else local_window for layer_idx in range(n_layer)]

    # Pattern-based mode. Keep _global_layers a list here too, so the attribute
    # has the same type in both modes.
    self._global_layers = []
    pattern = config.window_pattern.upper()
    assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
    long_window = config.sequence_len
    # Quarter of the context, rounded up to a multiple of 128 (e.g. 2048 -> 512).
    short_window = -(-long_window // 4 // 128) * 128
    char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
    window_sizes = [char_to_window[pattern[layer_idx % len(pattern)]] for layer_idx in range(n_layer)]
    # Final layer always gets full context, regardless of the pattern.
    window_sizes[-1] = (long_window, 0)
    return window_sizes
def get_device(self): def get_device(self):
return self.transformer.wte.weight.device return self.transformer.wte.weight.device
@ -329,9 +460,14 @@ class GPT(nn.Module):
nparams = sum(p.numel() for p in self.parameters()) nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars # Exclude non-matmul params: embeddings and per-layer scalars
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
bigram_embed_numel = 0 if self.bigram is None else self.bigram.embed.weight.numel() + self.bigram.scale.numel()
smear_numel = 0
if self.smear_gate is not None:
smear_numel += self.smear_gate.weight.numel() + self.smear_lambda.numel()
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
bigram_embed_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel() +
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()) smear_numel + self.backout_lambda.numel())
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window # Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0 attn_flops = 0
@ -356,14 +492,18 @@ class GPT(nn.Module):
""" """
# Count each group separately (mirrors the grouping in setup_optimizers) # Count each group separately (mirrors the grouping in setup_optimizers)
wte = sum(p.numel() for p in self.transformer.wte.parameters()) wte = sum(p.numel() for p in self.transformer.wte.parameters())
bigram_hash = sum(p.numel() for p in self.bigram.parameters()) if self.bigram is not None else 0
value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.backout_lambda.numel()
total = wte + value_embeds + lm_head + transformer_matrices + scalars if self.smear_gate is not None:
scalars += self.smear_gate.weight.numel() + self.smear_lambda.numel()
total = wte + bigram_hash + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return { return {
'wte': wte, 'wte': wte,
'bigram_hash': bigram_hash,
'value_embeds': value_embeds, 'value_embeds': value_embeds,
'lm_head': lm_head, 'lm_head': lm_head,
'transformer_matrices': transformer_matrices, 'transformer_matrices': transformer_matrices,
@ -376,14 +516,28 @@ class GPT(nn.Module):
ddp, rank, local_rank, world_size = get_dist_info() ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into groups # Separate out all parameters into groups
matrix_params = list(self.transformer.h.parameters()) gated_attn_params = [block.attn.attn_gate.weight for block in self.transformer.h if block.attn.attn_gate is not None]
matrix_params = [p for p in self.transformer.h.parameters() if all(p is not gp for gp in gated_attn_params)]
value_embeds_params = list(self.value_embeds.parameters()) value_embeds_params = list(self.value_embeds.parameters())
embedding_params = list(self.transformer.wte.parameters()) embedding_params = list(self.transformer.wte.parameters())
bigram_embed_params = []
bigram_matrix_params = []
bigram_scalar_params = []
if self.bigram is not None:
bigram_embed_params.append(self.bigram.embed.weight)
if self.bigram.proj is not None:
bigram_matrix_params.append(self.bigram.proj.weight)
bigram_scalar_params.append(self.bigram.scale)
lm_head_params = list(self.lm_head.parameters()) lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas] resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas] x0_params = [self.x0_lambdas]
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda] smear_params = []
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) if self.smear_gate is not None:
smear_params.extend([self.smear_gate.weight, self.smear_lambda])
backout_params = [self.backout_lambda]
matrix_params.extend(bigram_matrix_params)
all_grouped = matrix_params + gated_attn_params + value_embeds_params + embedding_params + bigram_embed_params + bigram_scalar_params + lm_head_params + resid_params + x0_params + smear_params + backout_params
assert len(list(self.parameters())) == len(all_grouped)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5 dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -397,8 +551,16 @@ class GPT(nn.Module):
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01), dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05), dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
] ]
if bigram_embed_params:
param_groups.append(dict(kind='adamw', params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001))
if bigram_scalar_params:
param_groups.append(dict(kind='adamw', params=bigram_scalar_params, lr=0.1, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
if gated_attn_params:
param_groups.append(dict(kind='adamw', params=gated_attn_params, lr=0.15, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
if smear_params:
param_groups.append(dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
param_groups.append(dict(kind='adamw', params=backout_params, lr=0.15, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
# Muon groups (matrix params, grouped by shape for stacking) # Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}): for shape in sorted({p.shape for p in matrix_params}):
group_params = [p for p in matrix_params if p.shape == shape] group_params = [p for p in matrix_params if p.shape == shape]
@ -427,26 +589,33 @@ class GPT(nn.Module):
# Embed the tokens # Embed the tokens
x = self.transformer.wte(idx) # embed current token x = self.transformer.wte(idx) # embed current token
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path) x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
if self.bigram is not None:
prev_tokens = None if kv_cache is None else kv_cache.prev_token_ids
x = x + self.bigram(idx, prev_tokens=prev_tokens).to(x.dtype)
if kv_cache is not None:
kv_cache.prev_token_ids = idx[:, -1].clone()
x = norm(x) x = norm(x)
x_base = x
# Smear: mix previous token's embedding into current position (cheap bigram info) # Smear: inject a gated copy of the previous token embedding as cheap bigram context.
if kv_cache is None: if self.smear_gate is not None:
# Training / naive generate: full sequence available, use fast slice gate_channels = self.config.smear_channels
assert T > 1, "Training forward pass should have T > 1" if kv_cache is None:
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24])) assert T > 1, "Training forward pass should have T > 1"
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1) x = x_base.clone()
else: gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x_base[:, 1:, :gate_channels]))
# KV cache inference: read prev embedding from cache, store current for next step x[:, 1:] = x[:, 1:] + gate * x_base[:, :-1]
x_pre_smear = kv_cache.prev_embedding else:
kv_cache.prev_embedding = x[:, -1:, :] prev1 = kv_cache.prev_embedding
if T > 1: kv_cache.prev_embedding = x_base[:, -1:, :]
# Prefill: apply smear to positions 1+, same as training x = x_base.clone()
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1) if prev1 is not None:
elif x_pre_smear is not None: gate_first = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x_base[:, :1, :gate_channels]))
# Decode: single token, use cached prev embedding x[:, :1] = x[:, :1] + gate_first * prev1
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24])) if T > 1:
x = x + gate * x_pre_smear gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x_base[:, 1:, :gate_channels]))
x[:, 1:] = x[:, 1:] + gate * x_base[:, :-1]
# Forward the trunk of the Transformer # Forward the trunk of the Transformer
x0 = x # save initial normalized embedding for x0 residual x0 = x # save initial normalized embedding for x0 residual

View File

@ -69,8 +69,8 @@ python -m scripts.tok_eval
echo "Waiting for dataset download to complete..." echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID wait $DATASET_DOWNLOAD_PID
# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8) # d24 sparse recipe (baked into the branch defaults), using the recorded 5,318-step horizon
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --device-batch-size=16 --fp8 --run=$WANDB_RUN torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --num-iterations=5318 --device-batch-size=16 --fp8 --run=$WANDB_RUN
# evaluate the model: CORE metric, BPB on train/val, and draw samples # evaluate the model: CORE metric, BPB on train/val, and draw samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16

View File

@ -25,7 +25,7 @@ import wandb
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from nanochat.gpt import GPT, GPTConfig, Linear from nanochat.gpt import GPT, GPTConfig, Linear, default_sparse_n_global, compute_sparse_global_layers
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.tokenizer import get_tokenizer, get_token_bytes
@ -52,6 +52,20 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Sparse funnel architecture: global (full-context) vs. local-window attention layers
parser.add_argument(
    "--sparse-funnel",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="use sparse funnel architecture instead of window pattern",
)
parser.add_argument(
    "--n-global",
    type=int,
    default=0,
    help="number of global (full-context) layers, 0 = auto depth//4",
)
parser.add_argument(
    "--chirp-gamma",
    type=float,
    default=0.7,
    help="chirp exponent for global layer placement",
)
parser.add_argument(
    "--local-window",
    type=int,
    default=128,
    help="local attention window size",
)
parser.add_argument(
    "--global-layers",
    type=str,
    default="",
    help="explicit global layer indices, comma-separated (overrides chirp)",
)
parser.add_argument(
    "--rope-base",
    type=int,
    default=200000,
    help="RoPE base frequency",
)
# Smear: gated mixing of the previous token embedding ahead of the trunk
parser.add_argument(
    "--smear",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="restore cheap bigram-like embedding mixing before the transformer trunk",
)
parser.add_argument(
    "--smear-channels",
    type=int,
    default=24,
    help="number of embedding channels used to predict smear gate",
)
# Hashed bigram embeddings
parser.add_argument(
    "--bigram-vocab-size",
    type=int,
    default=4096,
    help="hashed bigram embedding buckets (0 disables)",
)
parser.add_argument(
    "--bigram-dim",
    type=int,
    default=128,
    help="bigram embedding dim before projection",
)
# Per-head attention output gates
parser.add_argument(
    "--gated-attn",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="learn lightweight per-head attention output gates",
)
parser.add_argument(
    "--attn-gate-channels",
    type=int,
    default=12,
    help="embedding channels used to predict attention gates",
)
# Value-residual layers (left empty here; a default may be filled in after parsing)
parser.add_argument(
    "--ve-layers",
    type=str,
    default="",
    help="explicit value-residual layer indices, comma-separated",
)
# Training horizon (only one used, in order of precedence) # Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
@ -78,6 +92,14 @@ parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints
# Output # Output
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args() args = parser.parse_args()
# Fill in sparse-funnel settings the user left unset. This runs before
# vars(args) is snapshotted for logging, so the resolved values are what get recorded.
if args.sparse_funnel:
    if args.n_global <= 0:
        # Any non-positive count means "auto": take the depth-based default.
        args.n_global = default_sparse_n_global(args.depth)
    if not args.ve_layers:
        # No explicit --ve-layers: use the indices returned by
        # compute_sparse_global_layers, stored in the same comma-separated
        # string form the CLI flag accepts.
        # NOTE(review): this does not consult --global-layers; confirm whether
        # the default should follow an explicit override when one is given.
        args.ve_layers = ",".join(
            map(str, compute_sparse_global_layers(args.depth, args.n_global, args.chirp_gamma))
        )
user_config = vars(args).copy() # for logging user_config = vars(args).copy() # for logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Compute init and wandb logging # Compute init and wandb logging
@ -133,10 +155,21 @@ def build_model_meta(depth):
base_dim = depth * args.aspect_ratio base_dim = depth * args.aspect_ratio
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
num_heads = model_dim // args.head_dim num_heads = model_dim // args.head_dim
# Parse explicit layer lists
global_layer_override = tuple(int(x) for x in args.global_layers.split(",") if x.strip()) if args.global_layers else ()
ve_layers_parsed = tuple(int(x) for x in args.ve_layers.split(",") if x.strip()) if args.ve_layers else ()
config = GPTConfig( config = GPTConfig(
sequence_len=args.max_seq_len, vocab_size=vocab_size, sequence_len=args.max_seq_len, vocab_size=vocab_size,
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
window_pattern=args.window_pattern, window_pattern=args.window_pattern,
use_sparse_funnel=args.sparse_funnel,
n_global=args.n_global, chirp_gamma=args.chirp_gamma,
local_window=args.local_window, global_layer_override=global_layer_override,
rope_base=args.rope_base,
use_smear=args.smear, smear_channels=args.smear_channels,
bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
use_gated_attn=args.gated_attn, attn_gate_channels=args.attn_gate_channels,
ve_layers=ve_layers_parsed,
) )
with torch.device("meta"): with torch.device("meta"):
model_meta = GPT(config) model_meta = GPT(config)