Add minimal sparse d24 recipe

Codex 2026-04-28 00:32:54 +00:00
parent 0aaca56805
commit 03b2dbe63c
4 changed files with 286 additions and 79 deletions

View File

@ -102,11 +102,14 @@ class KVCache:
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
# Previous token's normalized embedding for smear (set by model forward pass)
self.prev_embedding = None
# Previous token ids for bigram hash features during decoding
self.prev_token_ids = None
def reset(self):
"""Reset cache to empty state."""
self.cache_seqlens.zero_()
self.prev_embedding = None
self.prev_token_ids = None
def get_pos(self):
"""Get current position (assumes all batch elements at same position)."""
@ -135,6 +138,8 @@ class KVCache:
# Copy smear state: expand batch=1 prev_embedding to num_samples
if other.prev_embedding is not None:
self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()
if other.prev_token_ids is not None:
self.prev_token_ids = other.prev_token_ids.expand(self.batch_size).clone()
# -----------------------------------------------------------------------------
@torch.inference_mode()

View File

@ -37,6 +37,37 @@ class GPTConfig:
# Characters: L=long (full context), S=short (quarter context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "SSSL"
# Sparse funnel attention: local layers use a short window, global layers use full context.
use_sparse_funnel: bool = True
n_global: int = 0 # number of full-context (global) layers, 0 = auto depth//4
chirp_gamma: float = 0.7 # chirp exponent for global placement (deeper-biased)
local_window: int = 128 # local attention window size
global_layer_override: tuple = () # explicit global layer indices (overrides chirp)
rope_base: int = 200000 # RoPE base frequency
use_smear: bool = True # cheap bigram-like embedding mixing
smear_channels: int = 24 # number of early embedding channels used to predict smear gate
bigram_vocab_size: int = 4096 # hashed bigram embedding buckets (0 disables)
bigram_dim: int = 128 # bigram embedding dim before projection
use_gated_attn: bool = True # learn a lightweight per-head gate on attention output
attn_gate_channels: int = 12 # channels used to predict attention gate
ve_layers: tuple = () # explicit layer indices for value residuals (empty = auto global layers under sparse funnel)
def default_sparse_n_global(n_layer):
return max(1, n_layer // 4)
def compute_sparse_global_layers(n_layer, n_global, chirp_gamma, global_layer_override=()):
if global_layer_override:
global_layers = set(global_layer_override)
else:
G = n_global if n_global > 0 else default_sparse_n_global(n_layer)
global_layers = set()
for i in range(1, G + 1):
idx = max(0, min(n_layer - 1, int(n_layer * (i / G) ** chirp_gamma) - 1))
global_layers.add(idx)
global_layers.add(n_layer - 1)
return tuple(sorted(global_layers))
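# Worked example (illustrative, not part of this diff): with the d24 defaults
# (n_global=0 -> auto 24 // 4 = 6, chirp_gamma=0.7, no override):
#   >>> compute_sparse_global_layers(24, 0, 0.7)
#   (5, 10, 13, 17, 20, 23)
# The (i / G) ** chirp_gamma schedule biases the full-context layers toward the
# deeper half of the stack, and the final layer is always made global.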
def norm(x):
@ -47,11 +78,58 @@ class Linear(nn.Linear):
Replaces autocast: master weights stay fp32 for optimizer precision,
but matmuls run in the activation dtype (typically bf16 from embeddings)."""
def forward(self, x):
return F.linear(x, self.weight.to(dtype=x.dtype))
w = self.weight
if w.dtype != x.dtype:
w = w.to(dtype=x.dtype)
return F.linear(x, w)
def has_ve(layer_idx, n_layer):
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
class EmbeddingLinear(nn.Module):
"""Lightweight linear layer for lm_head without redundant dtype casting."""
def __init__(self, in_features, out_features, bias=False, device=None, dtype=None):
super().__init__()
assert not bias
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
def forward(self, x):
return F.linear(x, self.weight)
class BigramHashEmbedding(nn.Module):
"""Hash causal token pairs into a compact learned embedding."""
def __init__(self, bigram_vocab_size, bigram_dim, model_dim):
super().__init__()
self.bigram_vocab_size = bigram_vocab_size
self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
self.proj = Linear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))
def bigram_hash(self, tokens, prev_tokens=None):
t = tokens.to(torch.int32)
mod = self.bigram_vocab_size - 1
out = torch.full_like(t, mod)
if prev_tokens is None:
if t.size(1) > 1:
out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
else:
prev = prev_tokens.to(torch.int32).view(t.size(0), 1)
out[..., 0] = torch.bitwise_xor(36313 * t[..., 0], 27191 * prev[..., 0]) % mod
if t.size(1) > 1:
out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
return out.long()
def forward(self, token_ids, prev_tokens=None):
h = self.embed(self.bigram_hash(token_ids, prev_tokens=prev_tokens))
if self.proj is not None:
h = self.proj(h)
return h * self.scale.to(dtype=h.dtype)
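# Shape sketch (illustrative): for token_ids of shape (B, T), bigram_hash maps each
# (previous token, current token) pair to an index in [0, bigram_vocab_size - 1),
# reserving the last bucket as a "no previous token" sentinel for position 0.
# embed then yields (B, T, bigram_dim), proj (when bigram_dim != model_dim) maps to
# (B, T, model_dim), and the result is damped by the learned scale (init 0.05)
# before being added to the token embeddings in GPT.forward.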
def has_ve(layer_idx, n_layer, layer_set=()):
"""Returns True if GPT layer should have Value Embedding."""
if layer_set:
return layer_idx in layer_set
return layer_idx % 2 == (n_layer - 1) % 2
def apply_rotary_emb(x, cos, sin):
@ -77,7 +155,13 @@ class CausalSelfAttention(nn.Module):
self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
self.ve_gate_channels = 12
self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
self.ve_gate = Linear(
self.ve_gate_channels,
self.n_kv_head,
bias=False,
) if has_ve(layer_idx, config.n_layer, layer_set=config.ve_layers) else None
self.attn_gate_channels = config.attn_gate_channels
self.attn_gate = Linear(self.attn_gate_channels, self.n_head, bias=False) if config.use_gated_attn else None
def forward(self, x, ve, cos_sin, window_size, kv_cache):
B, T, C = x.size()
@ -91,7 +175,7 @@ class CausalSelfAttention(nn.Module):
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
if ve is not None:
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 3)
gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels]))
v = v + gate.unsqueeze(-1) * ve
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
@ -120,6 +204,10 @@ class CausalSelfAttention(nn.Module):
if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T)
if self.attn_gate is not None:
gate = 2 * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_channels]))
y = y * gate.unsqueeze(-1).to(dtype=y.dtype)
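# (illustrative note) the gate lies in (0, 2) per head and is broadcast over
# head_dim, so values near 1 pass the attention output through unchanged while
# values near 0 let a head switch itself off for that token.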
# Re-assemble the heads and project back to residual stream
y = y.contiguous().view(B, T, -1)
y = self.c_proj(y)
@ -160,11 +248,16 @@ class GPT(nn.Module):
"""
super().__init__()
self.config = config
if config.use_sparse_funnel and config.n_global <= 0:
config.n_global = default_sparse_n_global(config.n_layer)
if config.use_sparse_funnel and not config.ve_layers:
config.ve_layers = compute_sparse_global_layers(
config.n_layer, config.n_global, config.chirp_gamma, config.global_layer_override
)
# Compute per-layer window sizes for sliding window attention
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
self.window_sizes = self._compute_window_sizes(config)
# Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
if padded_vocab_size != config.vocab_size:
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
@ -172,22 +265,31 @@ class GPT(nn.Module):
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
self.bigram = BigramHashEmbedding(config.bigram_vocab_size, config.bigram_dim, config.n_embd) if config.bigram_vocab_size > 0 else None
self.lm_head = EmbeddingLinear(config.n_embd, padded_vocab_size, bias=False)
# Per-layer learnable scalars (inspired by modded-nanogpt)
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
# Separate parameters so they can have different optimizer treatment
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
# Smear: mix previous token's embedding into current token (cheap bigram-like info)
self.smear_gate = Linear(24, 1, bias=False)
self.smear_lambda = nn.Parameter(torch.zeros(1))
# Optional smear: mix previous token's embedding into current token before the trunk
if config.use_smear:
self.smear_gate = Linear(config.smear_channels, 1, bias=False)
self.smear_lambda = nn.Parameter(torch.zeros(1))
else:
self.smear_gate = None
self.smear_lambda = None
# Backout: subtract cached mid-layer residual before final norm to remove low-level features
self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
# Value embeddings (ResFormer-style): alternating layers by default, or the explicit ve_layers set (the sparse-funnel global layers)
head_dim = config.n_embd // config.n_head
kv_dim = config.n_kv_head * head_dim
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
self.value_embeds = nn.ModuleDict({
str(i): nn.Embedding(padded_vocab_size, kv_dim)
for i in range(config.n_layer)
if has_ve(i, config.n_layer, layer_set=config.ve_layers)
})
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@ -217,31 +319,41 @@ class GPT(nn.Module):
# Embedding and unembedding
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
if self.bigram is not None:
torch.nn.init.zeros_(self.bigram.embed.weight)
if self.bigram.proj is not None:
bigram_s = 3**0.5 * self.bigram.embed.embedding_dim ** -0.5
torch.nn.init.uniform_(self.bigram.proj.weight, -bigram_s, bigram_s)
self.bigram.scale.data.fill_(0.05)
if self.smear_gate is not None:
torch.nn.init.zeros_(self.smear_gate.weight)
self.smear_lambda.data.zero_()
self.backout_lambda.data.fill_(0.2)
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
n_embd = self.config.n_embd
s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
for block in self.transformer.h:
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc
torch.nn.init.uniform_(block.attn.c_v.weight, -0.85 * s, 0.85 * s)
torch.nn.init.uniform_(block.attn.c_proj.weight, -0.008, 0.008)
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4)
torch.nn.init.zeros_(block.mlp.c_proj.weight)
# Per-layer scalars
# Per-layer resid init: stronger residual at early layers, weaker at deep layers
import math
n_layer = self.config.n_layer
resid_start, resid_end = 1.18, 1.06
resid_decay = math.log(resid_start / resid_end) / max(n_layer - 1, 1)
half_depth = max(1, n_layer // 2)
for i in range(n_layer):
self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1))
# Decaying x0 init: earlier layers get more input embedding blending
for i in range(n_layer):
self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1))
# Smear/backout scalars and smear gate must be explicitly initialized
torch.nn.init.zeros_(self.smear_lambda)
torch.nn.init.constant_(self.backout_lambda, 0.2)
torch.nn.init.uniform_(self.smear_gate.weight, 0.0, 0.02)
self.resid_lambdas.data[i] = resid_start * math.exp(-resid_decay * i)
if i < half_depth:
frac = i / max(half_depth - 1, 1)
self.x0_lambdas.data[i] = 0.24 * (1.0 - frac) + 0.08 * frac
else:
self.x0_lambdas.data[i] = 0.0
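# Worked example (illustrative): for n_layer=24 the schedule above decays
# resid_lambdas geometrically from 1.18 (layer 0) to 1.06 (layer 23), while
# x0_lambdas ramp linearly from 0.24 down to 0.08 over the first 12 layers
# and are zero for the deeper half of the stack.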
# Value embeddings (init like c_v: uniform with same std)
for ve in self.value_embeds.values():
@ -251,6 +363,8 @@ class GPT(nn.Module):
for block in self.transformer.h:
if block.attn.ve_gate is not None:
torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)
if block.attn.attn_gate is not None:
torch.nn.init.uniform_(block.attn.attn_gate.weight, 0.0, 0.02)
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
@ -262,11 +376,17 @@ class GPT(nn.Module):
# because GradScaler cannot unscale fp16 gradients.
if COMPUTE_DTYPE != torch.float16:
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
self.lm_head.to(dtype=COMPUTE_DTYPE)
for ve in self.value_embeds.values():
ve.to(dtype=COMPUTE_DTYPE)
if self.bigram is not None:
self.bigram.embed.to(dtype=COMPUTE_DTYPE)
if self.bigram.proj is not None:
self.bigram.proj.to(dtype=COMPUTE_DTYPE)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=None, device=None):
if base is None:
base = self.config.rope_base
# autodetect the device from model embeddings
if device is None:
device = self.transformer.wte.weight.device
@ -285,31 +405,42 @@ class GPT(nn.Module):
def _compute_window_sizes(self, config):
"""
Compute per-layer window sizes for sliding window attention.
Returns list of (left, right) tuples for FA3's window_size parameter:
- left: how many tokens before current position to attend to (-1 = unlimited)
- right: how many tokens after current position to attend to (0 for causal)
Pattern string is tiled across layers. Final layer always gets L (full context).
Characters: L=long (full context), S=short (quarter context)
Supports two modes:
1. Pattern-based (original): window_pattern string tiled across layers
2. Sparse funnel: chirped global placement + fixed local window
"""
pattern = config.window_pattern.upper()
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
# Map characters to window sizes
long_window = config.sequence_len
short_window = -(-long_window // 4 // 128) * 128 # ceil to FA3 tile size (2048 -> 512)
char_to_window = {
"L": (long_window, 0),
"S": (short_window, 0),
}
# Tile pattern across layers
window_sizes = []
for layer_idx in range(config.n_layer):
char = pattern[layer_idx % len(pattern)]
window_sizes.append(char_to_window[char])
# Final layer always gets full context
window_sizes[-1] = (long_window, 0)
return window_sizes
L = config.n_layer
full = config.sequence_len
if config.use_sparse_funnel:
# Sparse funnel: chirped global placement
global_layers = set(compute_sparse_global_layers(
L, config.n_global, config.chirp_gamma, config.global_layer_override
))
self._global_layers = sorted(global_layers)
w = config.local_window
window_sizes = []
for layer_idx in range(L):
if layer_idx in global_layers:
window_sizes.append((full, 0))
else:
window_sizes.append((w, 0))
return window_sizes
else:
# Original pattern-based mode
self._global_layers = set()
pattern = config.window_pattern.upper()
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
long_window = full
short_window = -(-long_window // 4 // 128) * 128
char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
window_sizes = []
for layer_idx in range(L):
char = pattern[layer_idx % len(pattern)]
window_sizes.append(char_to_window[char])
window_sizes[-1] = (long_window, 0)
return window_sizes
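# Worked example (illustrative): with the d24 defaults (sequence_len=2048,
# local_window=128, global layers 5, 10, 13, 17, 20, 23) this returns (128, 0)
# for the 18 local layers and (2048, 0) for the 6 global layers, so only a
# quarter of the stack attends over the full context.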
def get_device(self):
return self.transformer.wte.weight.device
@ -329,9 +460,14 @@ class GPT(nn.Module):
nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
bigram_embed_numel = 0 if self.bigram is None else self.bigram.embed.weight.numel() + self.bigram.scale.numel()
smear_numel = 0
if self.smear_gate is not None:
smear_numel += self.smear_gate.weight.numel() + self.smear_lambda.numel()
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
bigram_embed_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
smear_numel + self.backout_lambda.numel())
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
@ -356,14 +492,18 @@ class GPT(nn.Module):
"""
# Count each group separately (mirrors the grouping in setup_optimizers)
wte = sum(p.numel() for p in self.transformer.wte.parameters())
bigram_hash = sum(p.numel() for p in self.bigram.parameters()) if self.bigram is not None else 0
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
total = wte + value_embeds + lm_head + transformer_matrices + scalars
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.backout_lambda.numel()
if self.smear_gate is not None:
scalars += self.smear_gate.weight.numel() + self.smear_lambda.numel()
total = wte + bigram_hash + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
'wte': wte,
'bigram_hash': bigram_hash,
'value_embeds': value_embeds,
'lm_head': lm_head,
'transformer_matrices': transformer_matrices,
@ -376,14 +516,28 @@ class GPT(nn.Module):
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into groups
matrix_params = list(self.transformer.h.parameters())
gated_attn_params = [block.attn.attn_gate.weight for block in self.transformer.h if block.attn.attn_gate is not None]
matrix_params = [p for p in self.transformer.h.parameters() if all(p is not gp for gp in gated_attn_params)]
value_embeds_params = list(self.value_embeds.parameters())
embedding_params = list(self.transformer.wte.parameters())
bigram_embed_params = []
bigram_matrix_params = []
bigram_scalar_params = []
if self.bigram is not None:
bigram_embed_params.append(self.bigram.embed.weight)
if self.bigram.proj is not None:
bigram_matrix_params.append(self.bigram.proj.weight)
bigram_scalar_params.append(self.bigram.scale)
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
smear_params = []
if self.smear_gate is not None:
smear_params.extend([self.smear_gate.weight, self.smear_lambda])
backout_params = [self.backout_lambda]
matrix_params.extend(bigram_matrix_params)
all_grouped = matrix_params + gated_attn_params + value_embeds_params + embedding_params + bigram_embed_params + bigram_scalar_params + lm_head_params + resid_params + x0_params + smear_params + backout_params
assert len(list(self.parameters())) == len(all_grouped)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -397,8 +551,16 @@ class GPT(nn.Module):
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
]
if bigram_embed_params:
param_groups.append(dict(kind='adamw', params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001))
if bigram_scalar_params:
param_groups.append(dict(kind='adamw', params=bigram_scalar_params, lr=0.1, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
if gated_attn_params:
param_groups.append(dict(kind='adamw', params=gated_attn_params, lr=0.15, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
if smear_params:
param_groups.append(dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
param_groups.append(dict(kind='adamw', params=backout_params, lr=0.15, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):
group_params = [p for p in matrix_params if p.shape == shape]
@ -427,26 +589,33 @@ class GPT(nn.Module):
# Embed the tokens
x = self.transformer.wte(idx) # embed current token
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
if self.bigram is not None:
prev_tokens = None if kv_cache is None else kv_cache.prev_token_ids
x = x + self.bigram(idx, prev_tokens=prev_tokens).to(x.dtype)
if kv_cache is not None:
kv_cache.prev_token_ids = idx[:, -1].clone()
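# (illustrative note) prev_token_ids keeps decoding consistent with training:
# without it, a single decoded token would hash into the sentinel bucket
# instead of against the token emitted on the previous step.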
x = norm(x)
x_base = x
# Smear: mix previous token's embedding into current position (cheap bigram info)
if kv_cache is None:
# Training / naive generate: full sequence available, use fast slice
assert T > 1, "Training forward pass should have T > 1"
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
else:
# KV cache inference: read prev embedding from cache, store current for next step
x_pre_smear = kv_cache.prev_embedding
kv_cache.prev_embedding = x[:, -1:, :]
if T > 1:
# Prefill: apply smear to positions 1+, same as training
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
elif x_pre_smear is not None:
# Decode: single token, use cached prev embedding
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
x = x + gate * x_pre_smear
# Smear: inject a gated copy of the previous token embedding as cheap bigram context.
if self.smear_gate is not None:
gate_channels = self.config.smear_channels
if kv_cache is None:
assert T > 1, "Training forward pass should have T > 1"
x = x_base.clone()
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x_base[:, 1:, :gate_channels]))
x[:, 1:] = x[:, 1:] + gate * x_base[:, :-1]
else:
prev1 = kv_cache.prev_embedding
kv_cache.prev_embedding = x_base[:, -1:, :]
x = x_base.clone()
if prev1 is not None:
gate_first = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x_base[:, :1, :gate_channels]))
x[:, :1] = x[:, :1] + gate_first * prev1
if T > 1:
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x_base[:, 1:, :gate_channels]))
x[:, 1:] = x[:, 1:] + gate * x_base[:, :-1]
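# (illustrative note) during single-token decode only the prev1 branch fires:
# the cached normalized embedding of the previous token is mixed into the new
# token with the same sigmoid gate used at training time.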
# Forward the trunk of the Transformer
x0 = x # save initial normalized embedding for x0 residual

View File

@ -69,8 +69,8 @@ python -m scripts.tok_eval
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --device-batch-size=16 --fp8 --run=$WANDB_RUN
# d24 sparse recipe (settings baked into this branch's defaults), trained for the recorded 5,318-step horizon
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --num-iterations=5318 --device-batch-size=16 --fp8 --run=$WANDB_RUN
# evaluate the model: CORE metric, BPB on train/val, and draw samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16

View File

@ -25,7 +25,7 @@ import wandb
import torch
import torch.distributed as dist
from nanochat.gpt import GPT, GPTConfig, Linear
from nanochat.gpt import GPT, GPTConfig, Linear, default_sparse_n_global, compute_sparse_global_layers
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
from nanochat.tokenizer import get_tokenizer, get_token_bytes
@ -52,6 +52,20 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Sparse funnel architecture
parser.add_argument("--sparse-funnel", action=argparse.BooleanOptionalAction, default=True, help="use sparse funnel architecture instead of window pattern")
parser.add_argument("--n-global", type=int, default=0, help="number of global (full-context) layers, 0 = auto depth//4")
parser.add_argument("--chirp-gamma", type=float, default=0.7, help="chirp exponent for global layer placement")
parser.add_argument("--local-window", type=int, default=128, help="local attention window size")
parser.add_argument("--global-layers", type=str, default="", help="explicit global layer indices, comma-separated (overrides chirp)")
parser.add_argument("--rope-base", type=int, default=200000, help="RoPE base frequency")
parser.add_argument("--smear", action=argparse.BooleanOptionalAction, default=True, help="restore cheap bigram-like embedding mixing before the transformer trunk")
parser.add_argument("--smear-channels", type=int, default=24, help="number of embedding channels used to predict smear gate")
parser.add_argument("--bigram-vocab-size", type=int, default=4096, help="hashed bigram embedding buckets (0 disables)")
parser.add_argument("--bigram-dim", type=int, default=128, help="bigram embedding dim before projection")
parser.add_argument("--gated-attn", action=argparse.BooleanOptionalAction, default=True, help="learn lightweight per-head attention output gates")
parser.add_argument("--attn-gate-channels", type=int, default=12, help="embedding channels used to predict attention gates")
parser.add_argument("--ve-layers", type=str, default="", help="explicit value-residual layer indices, comma-separated")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
@ -78,6 +92,14 @@ parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints
# Output
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
if args.sparse_funnel and args.n_global <= 0:
args.n_global = default_sparse_n_global(args.depth)
if args.sparse_funnel and not args.ve_layers:
args.ve_layers = ",".join(
str(x) for x in compute_sparse_global_layers(args.depth, args.n_global, args.chirp_gamma)
)
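# Illustrative example: --depth=24 with the defaults above auto-fills
# args.n_global = 6 and args.ve_layers = "5,10,13,17,20,23", matching the
# global layers the sparse funnel computes inside the model.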
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
# Compute init and wandb logging
@ -133,10 +155,21 @@ def build_model_meta(depth):
base_dim = depth * args.aspect_ratio
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
num_heads = model_dim // args.head_dim
# Parse explicit layer lists
global_layer_override = tuple(int(x) for x in args.global_layers.split(",") if x.strip()) if args.global_layers else ()
ve_layers_parsed = tuple(int(x) for x in args.ve_layers.split(",") if x.strip()) if args.ve_layers else ()
config = GPTConfig(
sequence_len=args.max_seq_len, vocab_size=vocab_size,
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
window_pattern=args.window_pattern,
use_sparse_funnel=args.sparse_funnel,
n_global=args.n_global, chirp_gamma=args.chirp_gamma,
local_window=args.local_window, global_layer_override=global_layer_override,
rope_base=args.rope_base,
use_smear=args.smear, smear_channels=args.smear_channels,
bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
use_gated_attn=args.gated_attn, attn_gate_channels=args.attn_gate_channels,
ve_layers=ve_layers_parsed,
)
with torch.device("meta"):
model_meta = GPT(config)