diff --git a/nanochat/engine.py b/nanochat/engine.py
index aa2e6a98..b8538291 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -102,11 +102,14 @@ class KVCache:
         self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
         # Previous token's normalized embedding for smear (set by model forward pass)
         self.prev_embedding = None
+        # Previous token ids for bigram hash features during decoding
+        self.prev_token_ids = None

     def reset(self):
         """Reset cache to empty state."""
         self.cache_seqlens.zero_()
         self.prev_embedding = None
+        self.prev_token_ids = None

     def get_pos(self):
         """Get current position (assumes all batch elements at same position)."""
@@ -135,6 +138,8 @@ class KVCache:
         # Copy smear state: expand batch=1 prev_embedding to num_samples
         if other.prev_embedding is not None:
             self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()
+        if other.prev_token_ids is not None:
+            self.prev_token_ids = other.prev_token_ids.expand(self.batch_size).clone()

 # -----------------------------------------------------------------------------
 @torch.inference_mode()
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 07a1eae8..358821a3 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -37,6 +37,37 @@ class GPTConfig:
     # Characters: L=long (full context), S=short (quarter context)
     # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
     window_pattern: str = "SSSL"
+    # Sparse funnel attention: local layers use a short window, global layers use full context.
+    use_sparse_funnel: bool = True
+    n_global: int = 0  # number of full-context (global) layers, 0 = auto depth//4
+    chirp_gamma: float = 0.7  # chirp exponent for global placement (deeper-biased)
+    local_window: int = 128  # local attention window size
+    global_layer_override: tuple = ()  # explicit global layer indices (overrides chirp)
+    rope_base: int = 200000  # RoPE base frequency
+    use_smear: bool = True  # cheap bigram-like embedding mixing
+    smear_channels: int = 24  # number of early embedding channels used to predict smear gate
+    bigram_vocab_size: int = 4096  # hashed bigram embedding buckets (0 disables)
+    bigram_dim: int = 128  # bigram embedding dim before projection
+    use_gated_attn: bool = True  # learn a lightweight per-head gate on attention output
+    attn_gate_channels: int = 12  # channels used to predict attention gate
+    ve_layers: tuple = ()  # explicit layer indices for value residuals (empty = auto global layers under sparse funnel)
+
+
+def default_sparse_n_global(n_layer):
+    return max(1, n_layer // 4)
+
+
+def compute_sparse_global_layers(n_layer, n_global, chirp_gamma, global_layer_override=()):
+    if global_layer_override:
+        global_layers = set(global_layer_override)
+    else:
+        G = n_global if n_global > 0 else default_sparse_n_global(n_layer)
+        global_layers = set()
+        for i in range(1, G + 1):
+            idx = max(0, min(n_layer - 1, int(n_layer * (i / G) ** chirp_gamma) - 1))
+            global_layers.add(idx)
+    global_layers.add(n_layer - 1)
+    return tuple(sorted(global_layers))


 def norm(x):
@@ -47,11 +78,58 @@ class Linear(nn.Linear):
     Replaces autocast: master weights stay fp32 for optimizer precision,
     but matmuls run in the activation dtype (typically bf16 from embeddings)."""
     def forward(self, x):
-        return F.linear(x, self.weight.to(dtype=x.dtype))
+        w = self.weight
+        if w.dtype != x.dtype:
+            w = w.to(dtype=x.dtype)
+        return F.linear(x, w)


-def has_ve(layer_idx, n_layer):
-    """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
+class EmbeddingLinear(nn.Module):
+    """Lightweight linear layer for lm_head without redundant dtype casting."""
+    def __init__(self, in_features, out_features, bias=False, device=None, dtype=None):
+        super().__init__()
+        assert not bias
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
+    def forward(self, x):
+        return F.linear(x, self.weight)
+
+
+class BigramHashEmbedding(nn.Module):
+    """Hash causal token pairs into a compact learned embedding."""
+    def __init__(self, bigram_vocab_size, bigram_dim, model_dim):
+        super().__init__()
+        self.bigram_vocab_size = bigram_vocab_size
+        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
+        self.proj = Linear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
+        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))
+
+    def bigram_hash(self, tokens, prev_tokens=None):
+        t = tokens.to(torch.int32)
+        mod = self.bigram_vocab_size - 1
+        out = torch.full_like(t, mod)
+        if prev_tokens is None:
+            if t.size(1) > 1:
+                out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
+        else:
+            prev = prev_tokens.to(torch.int32).view(t.size(0), 1)
+            out[..., 0] = torch.bitwise_xor(36313 * t[..., 0], 27191 * prev[..., 0]) % mod
+            if t.size(1) > 1:
+                out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
+        return out.long()
+
+    def forward(self, token_ids, prev_tokens=None):
+        h = self.embed(self.bigram_hash(token_ids, prev_tokens=prev_tokens))
+        if self.proj is not None:
+            h = self.proj(h)
+        return h * self.scale.to(dtype=h.dtype)
+
+
+def has_ve(layer_idx, n_layer, layer_set=()):
+    """Returns True if GPT layer should have Value Embedding."""
+    if layer_set:
+        return layer_idx in layer_set
     return layer_idx % 2 == (n_layer - 1) % 2


 def apply_rotary_emb(x, cos, sin):
@@ -77,7 +155,13 @@ class CausalSelfAttention(nn.Module):
         self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
         self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
         self.ve_gate_channels = 12
-        self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
+        self.ve_gate = Linear(
+            self.ve_gate_channels,
+            self.n_kv_head,
+            bias=False,
+        ) if has_ve(layer_idx, config.n_layer, layer_set=config.ve_layers) else None
+        self.attn_gate_channels = config.attn_gate_channels
+        self.attn_gate = Linear(self.attn_gate_channels, self.n_head, bias=False) if config.use_gated_attn else None

     def forward(self, x, ve, cos_sin, window_size, kv_cache):
         B, T, C = x.size()
@@ -91,7 +175,7 @@
         # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
         if ve is not None:
             ve = ve.view(B, T, self.n_kv_head, self.head_dim)
-            gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 3)
+            gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels]))
             v = v + gate.unsqueeze(-1) * ve

         # Apply Rotary Embeddings to queries and keys to get relative positional encoding
@@ -120,6 +204,10 @@
         if self.layer_idx == kv_cache.n_layers - 1:
             kv_cache.advance(T)

+        if self.attn_gate is not None:
+            gate = 2 * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_channels]))
+            y = y * gate.unsqueeze(-1).to(dtype=y.dtype)
+
         # Re-assemble the heads and project back to residual stream
         y = y.contiguous().view(B, T, -1)
         y = self.c_proj(y)
@@ -160,11 +248,16 @@ class GPT(nn.Module):
         """
         super().__init__()
         self.config = config
+        if config.use_sparse_funnel and config.n_global <= 0:
+            config.n_global = default_sparse_n_global(config.n_layer)
+        if config.use_sparse_funnel and not config.ve_layers:
+            config.ve_layers = compute_sparse_global_layers(
+                config.n_layer, config.n_global, config.chirp_gamma, config.global_layer_override
+            )
         # Compute per-layer window sizes for sliding window attention
         # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
         self.window_sizes = self._compute_window_sizes(config)
         # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
-        # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
         padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
         if padded_vocab_size != config.vocab_size:
             print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
@@ -172,22 +265,31 @@
             "wte": nn.Embedding(padded_vocab_size, config.n_embd),
             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
         })
-        self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
+        self.bigram = BigramHashEmbedding(config.bigram_vocab_size, config.bigram_dim, config.n_embd) if config.bigram_vocab_size > 0 else None
+        self.lm_head = EmbeddingLinear(config.n_embd, padded_vocab_size, bias=False)
         # Per-layer learnable scalars (inspired by modded-nanogpt)
         # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
         # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
         # Separate parameters so they can have different optimizer treatment
         self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
         self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
-        # Smear: mix previous token's embedding into current token (cheap bigram-like info)
-        self.smear_gate = Linear(24, 1, bias=False)
-        self.smear_lambda = nn.Parameter(torch.zeros(1))
+        # Optional smear: mix previous token's embedding into current token before the trunk
+        if config.use_smear:
+            self.smear_gate = Linear(config.smear_channels, 1, bias=False)
+            self.smear_lambda = nn.Parameter(torch.zeros(1))
+        else:
+            self.smear_gate = None
+            self.smear_lambda = None
         # Backout: subtract cached mid-layer residual before final norm to remove low-level features
         self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
         # Value embeddings (ResFormer-style): alternating layers, last layer always included
         head_dim = config.n_embd // config.n_head
         kv_dim = config.n_kv_head * head_dim
-        self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
+        self.value_embeds = nn.ModuleDict({
+            str(i): nn.Embedding(padded_vocab_size, kv_dim)
+            for i in range(config.n_layer)
+            if has_ve(i, config.n_layer, layer_set=config.ve_layers)
+        })
         # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
         # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
         # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@@ -217,31 +319,41 @@ class GPT(nn.Module):
         # Embedding and unembedding
         torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
         torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
+        if self.bigram is not None:
+            torch.nn.init.zeros_(self.bigram.embed.weight)
+            if self.bigram.proj is not None:
+                bigram_s = 3**0.5 * self.bigram.embed.embedding_dim ** -0.5
+                torch.nn.init.uniform_(self.bigram.proj.weight, -bigram_s, bigram_s)
+            self.bigram.scale.data.fill_(0.05)
+        if self.smear_gate is not None:
+            torch.nn.init.zeros_(self.smear_gate.weight)
+            self.smear_lambda.data.zero_()
+        self.backout_lambda.data.fill_(0.2)

         # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
         n_embd = self.config.n_embd
         s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
         for block in self.transformer.h:
-            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
+            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
             torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
-            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
-            torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
-            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc
+            torch.nn.init.uniform_(block.attn.c_v.weight, -0.85 * s, 0.85 * s)
+            torch.nn.init.uniform_(block.attn.c_proj.weight, -0.008, 0.008)
+            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4)
             torch.nn.init.zeros_(block.mlp.c_proj.weight)

         # Per-layer scalars
-        # Per-layer resid init: stronger residual at early layers, weaker at deep layers
+        import math
         n_layer = self.config.n_layer
+        resid_start, resid_end = 1.18, 1.06
+        resid_decay = math.log(resid_start / resid_end) / max(n_layer - 1, 1)
+        half_depth = max(1, n_layer // 2)
         for i in range(n_layer):
-            self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1))
-        # Decaying x0 init: earlier layers get more input embedding blending
-        for i in range(n_layer):
-            self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1))
-
-        # Smear/backout scalars and smear gate must be explicitly initialized
-        torch.nn.init.zeros_(self.smear_lambda)
-        torch.nn.init.constant_(self.backout_lambda, 0.2)
-        torch.nn.init.uniform_(self.smear_gate.weight, 0.0, 0.02)
+            self.resid_lambdas.data[i] = resid_start * math.exp(-resid_decay * i)
+            if i < half_depth:
+                frac = i / max(half_depth - 1, 1)
+                self.x0_lambdas.data[i] = 0.24 * (1.0 - frac) + 0.08 * frac
+            else:
+                self.x0_lambdas.data[i] = 0.0

         # Value embeddings (init like c_v: uniform with same std)
         for ve in self.value_embeds.values():
@@ -251,6 +363,8 @@
         for block in self.transformer.h:
             if block.attn.ve_gate is not None:
                 torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)
+            if block.attn.attn_gate is not None:
+                torch.nn.init.uniform_(block.attn.attn_gate.weight, 0.0, 0.02)

         # Rotary embeddings
         head_dim = self.config.n_embd // self.config.n_head
@@ -262,11 +376,17 @@
         # because GradScaler cannot unscale fp16 gradients.
         if COMPUTE_DTYPE != torch.float16:
             self.transformer.wte.to(dtype=COMPUTE_DTYPE)
+            self.lm_head.to(dtype=COMPUTE_DTYPE)
             for ve in self.value_embeds.values():
                 ve.to(dtype=COMPUTE_DTYPE)
+            if self.bigram is not None:
+                self.bigram.embed.to(dtype=COMPUTE_DTYPE)
+                if self.bigram.proj is not None:
+                    self.bigram.proj.to(dtype=COMPUTE_DTYPE)

-    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
-        # TODO: bump base theta more? e.g. 100K is more common more recently
+    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=None, device=None):
+        if base is None:
+            base = self.config.rope_base
         # autodetect the device from model embeddings
         if device is None:
             device = self.transformer.wte.weight.device
@@ -285,31 +405,42 @@
     def _compute_window_sizes(self, config):
         """
         Compute per-layer window sizes for sliding window attention.
-
-        Returns list of (left, right) tuples for FA3's window_size parameter:
-        - left: how many tokens before current position to attend to (-1 = unlimited)
-        - right: how many tokens after current position to attend to (0 for causal)
-
-        Pattern string is tiled across layers. Final layer always gets L (full context).
-        Characters: L=long (full context), S=short (quarter context)
+        Supports two modes:
+        1. Pattern-based (original): window_pattern string tiled across layers
+        2. Sparse funnel: chirped global placement + fixed local window
         """
-        pattern = config.window_pattern.upper()
-        assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
-        # Map characters to window sizes
-        long_window = config.sequence_len
-        short_window = -(-long_window // 4 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
-        char_to_window = {
-            "L": (long_window, 0),
-            "S": (short_window, 0),
-        }
-        # Tile pattern across layers
-        window_sizes = []
-        for layer_idx in range(config.n_layer):
-            char = pattern[layer_idx % len(pattern)]
-            window_sizes.append(char_to_window[char])
-        # Final layer always gets full context
-        window_sizes[-1] = (long_window, 0)
-        return window_sizes
+        L = config.n_layer
+        full = config.sequence_len
+
+        if config.use_sparse_funnel:
+            # Sparse funnel: chirped global placement
+            global_layers = set(compute_sparse_global_layers(
+                L, config.n_global, config.chirp_gamma, config.global_layer_override
+            ))
+            self._global_layers = sorted(global_layers)
+
+            w = config.local_window
+            window_sizes = []
+            for layer_idx in range(L):
+                if layer_idx in global_layers:
+                    window_sizes.append((full, 0))
+                else:
+                    window_sizes.append((w, 0))
+            return window_sizes
+        else:
+            # Original pattern-based mode
+            self._global_layers = set()
+            pattern = config.window_pattern.upper()
+            assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
+            long_window = full
+            short_window = -(-long_window // 4 // 128) * 128
+            char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
+            window_sizes = []
+            for layer_idx in range(L):
+                char = pattern[layer_idx % len(pattern)]
+                window_sizes.append(char_to_window[char])
+            window_sizes[-1] = (long_window, 0)
+            return window_sizes

     def get_device(self):
         return self.transformer.wte.weight.device
@@ -329,9 +460,14 @@
         nparams = sum(p.numel() for p in self.parameters())
         # Exclude non-matmul params: embeddings and per-layer scalars
         value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
+        bigram_embed_numel = 0 if self.bigram is None else self.bigram.embed.weight.numel() + self.bigram.scale.numel()
+        smear_numel = 0
+        if self.smear_gate is not None:
+            smear_numel += self.smear_gate.weight.numel() + self.smear_lambda.numel()
         nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
+                           bigram_embed_numel +
                            self.resid_lambdas.numel() + self.x0_lambdas.numel() +
-                           self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
+                           smear_numel + self.backout_lambda.numel())
         h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
         # Sum attention FLOPs per layer, accounting for sliding window
         attn_flops = 0
@@ -356,14 +492,18 @@
         """
         # Count each group separately (mirrors the grouping in setup_optimizers)
         wte = sum(p.numel() for p in self.transformer.wte.parameters())
+        bigram_hash = sum(p.numel() for p in self.bigram.parameters()) if self.bigram is not None else 0
         value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
         lm_head = sum(p.numel() for p in self.lm_head.parameters())
         transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
-        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
-        total = wte + value_embeds + lm_head + transformer_matrices + scalars
+        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.backout_lambda.numel()
+        if self.smear_gate is not None:
+            scalars += self.smear_gate.weight.numel() + self.smear_lambda.numel()
+        total = wte + bigram_hash + value_embeds + lm_head + transformer_matrices + scalars
         assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
         return {
             'wte': wte,
+            'bigram_hash': bigram_hash,
             'value_embeds': value_embeds,
             'lm_head': lm_head,
             'transformer_matrices': transformer_matrices,
@@ -376,14 +516,28 @@
         ddp, rank, local_rank, world_size = get_dist_info()

         # Separate out all parameters into groups
-        matrix_params = list(self.transformer.h.parameters())
+        gated_attn_params = [block.attn.attn_gate.weight for block in self.transformer.h if block.attn.attn_gate is not None]
+        matrix_params = [p for p in self.transformer.h.parameters() if all(p is not gp for gp in gated_attn_params)]
         value_embeds_params = list(self.value_embeds.parameters())
         embedding_params = list(self.transformer.wte.parameters())
+        bigram_embed_params = []
+        bigram_matrix_params = []
+        bigram_scalar_params = []
+        if self.bigram is not None:
+            bigram_embed_params.append(self.bigram.embed.weight)
+            if self.bigram.proj is not None:
+                bigram_matrix_params.append(self.bigram.proj.weight)
+            bigram_scalar_params.append(self.bigram.scale)
         lm_head_params = list(self.lm_head.parameters())
         resid_params = [self.resid_lambdas]
         x0_params = [self.x0_lambdas]
-        smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
+        smear_params = []
+        if self.smear_gate is not None:
+            smear_params.extend([self.smear_gate.weight, self.smear_lambda])
+        backout_params = [self.backout_lambda]
+        matrix_params.extend(bigram_matrix_params)
+        all_grouped = matrix_params + gated_attn_params + value_embeds_params + embedding_params + bigram_embed_params + bigram_scalar_params + lm_head_params + resid_params + x0_params + smear_params + backout_params
+        assert len(list(self.parameters())) == len(all_grouped)

         # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
@@ -397,8 +551,16 @@
             dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
             dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
             dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
-            dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
         ]
+        if bigram_embed_params:
+            param_groups.append(dict(kind='adamw', params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001))
+        if bigram_scalar_params:
+            param_groups.append(dict(kind='adamw', params=bigram_scalar_params, lr=0.1, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
+        if gated_attn_params:
+            param_groups.append(dict(kind='adamw', params=gated_attn_params, lr=0.15, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
+        if smear_params:
+            param_groups.append(dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
+        param_groups.append(dict(kind='adamw', params=backout_params, lr=0.15, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
         # Muon groups (matrix params, grouped by shape for stacking)
         for shape in sorted({p.shape for p in matrix_params}):
             group_params = [p for p in matrix_params if p.shape == shape]
@@ -427,26 +589,33 @@
         # Embed the tokens
         x = self.transformer.wte(idx) # embed current token
         x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
+        if self.bigram is not None:
+            prev_tokens = None if kv_cache is None else kv_cache.prev_token_ids
+            x = x + self.bigram(idx, prev_tokens=prev_tokens).to(x.dtype)
+            if kv_cache is not None:
+                kv_cache.prev_token_ids = idx[:, -1].clone()
         x = norm(x)
+        x_base = x

-        # Smear: mix previous token's embedding into current position (cheap bigram info)
-        if kv_cache is None:
-            # Training / naive generate: full sequence available, use fast slice
-            assert T > 1, "Training forward pass should have T > 1"
-            gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
-            x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
-        else:
-            # KV cache inference: read prev embedding from cache, store current for next step
-            x_pre_smear = kv_cache.prev_embedding
-            kv_cache.prev_embedding = x[:, -1:, :]
-            if T > 1:
-                # Prefill: apply smear to positions 1+, same as training
-                gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
-                x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
-            elif x_pre_smear is not None:
-                # Decode: single token, use cached prev embedding
-                gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
-                x = x + gate * x_pre_smear
+        # Smear: inject a gated copy of the previous token embedding as cheap bigram context.
+        if self.smear_gate is not None:
+            gate_channels = self.config.smear_channels
+            if kv_cache is None:
+                assert T > 1, "Training forward pass should have T > 1"
+                x = x_base.clone()
+                gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x_base[:, 1:, :gate_channels]))
+                x[:, 1:] = x[:, 1:] + gate * x_base[:, :-1]
+            else:
+                prev1 = kv_cache.prev_embedding
+                kv_cache.prev_embedding = x_base[:, -1:, :]
+                x = x_base.clone()
+
+                if prev1 is not None:
+                    gate_first = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x_base[:, :1, :gate_channels]))
+                    x[:, :1] = x[:, :1] + gate_first * prev1
+                if T > 1:
+                    gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x_base[:, 1:, :gate_channels]))
+                    x[:, 1:] = x[:, 1:] + gate * x_base[:, :-1]

         # Forward the trunk of the Transformer
         x0 = x # save initial normalized embedding for x0 residual
diff --git a/runs/speedrun.sh b/runs/speedrun.sh
index 48fcc68a..306e1715 100644
--- a/runs/speedrun.sh
+++ b/runs/speedrun.sh
@@ -69,8 +69,8 @@ python -m scripts.tok_eval
 echo "Waiting for dataset download to complete..."
 wait $DATASET_DOWNLOAD_PID

-# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8)
-torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --device-batch-size=16 --fp8 --run=$WANDB_RUN
+# d24 sparse recipe (baked into the branch defaults), using the recorded 5,318-step horizon
+torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --num-iterations=5318 --device-batch-size=16 --fp8 --run=$WANDB_RUN

 # evaluate the model: CORE metric, BPB on train/val, and draw samples
 torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16
diff --git a/scripts/base_train.py b/scripts/base_train.py
index a161c477..9926dc9a 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -25,7 +25,7 @@ import wandb
 import torch
 import torch.distributed as dist

-from nanochat.gpt import GPT, GPTConfig, Linear
+from nanochat.gpt import GPT, GPTConfig, Linear, default_sparse_n_global, compute_sparse_global_layers
 from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
 from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
@@ -52,6 +52,20 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
 parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
 parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
 parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
+# Sparse funnel architecture
+parser.add_argument("--sparse-funnel", action=argparse.BooleanOptionalAction, default=True, help="use sparse funnel architecture instead of window pattern")
+parser.add_argument("--n-global", type=int, default=0, help="number of global (full-context) layers, 0 = auto depth//4")
+parser.add_argument("--chirp-gamma", type=float, default=0.7, help="chirp exponent for global layer placement")
+parser.add_argument("--local-window", type=int, default=128, help="local attention window size")
+parser.add_argument("--global-layers", type=str, default="", help="explicit global layer indices, comma-separated (overrides chirp)")
+parser.add_argument("--rope-base", type=int, default=200000, help="RoPE base frequency")
+parser.add_argument("--smear", action=argparse.BooleanOptionalAction, default=True, help="restore cheap bigram-like embedding mixing before the transformer trunk")
+parser.add_argument("--smear-channels", type=int, default=24, help="number of embedding channels used to predict smear gate")
+parser.add_argument("--bigram-vocab-size", type=int, default=4096, help="hashed bigram embedding buckets (0 disables)")
+parser.add_argument("--bigram-dim", type=int, default=128, help="bigram embedding dim before projection")
+parser.add_argument("--gated-attn", action=argparse.BooleanOptionalAction, default=True, help="learn lightweight per-head attention output gates")
+parser.add_argument("--attn-gate-channels", type=int, default=12, help="embedding channels used to predict attention gates")
+parser.add_argument("--ve-layers", type=str, default="", help="explicit value-residual layer indices, comma-separated")
 # Training horizon (only one used, in order of precedence)
 parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
 parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
@@ -78,6 +92,14 @@ parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints
 # Output
 parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
 args = parser.parse_args()
+
+if args.sparse_funnel and args.n_global <= 0:
+    args.n_global = default_sparse_n_global(args.depth)
+if args.sparse_funnel and not args.ve_layers:
+    args.ve_layers = ",".join(
+        str(x) for x in compute_sparse_global_layers(args.depth, args.n_global, args.chirp_gamma)
+    )
+
 user_config = vars(args).copy() # for logging
 # -----------------------------------------------------------------------------
 # Compute init and wandb logging
@@ -133,10 +155,21 @@ def build_model_meta(depth):
     base_dim = depth * args.aspect_ratio
     model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
     num_heads = model_dim // args.head_dim
+    # Parse explicit layer lists
+    global_layer_override = tuple(int(x) for x in args.global_layers.split(",") if x.strip()) if args.global_layers else ()
+    ve_layers_parsed = tuple(int(x) for x in args.ve_layers.split(",") if x.strip()) if args.ve_layers else ()
     config = GPTConfig(
         sequence_len=args.max_seq_len, vocab_size=vocab_size,
         n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
         window_pattern=args.window_pattern,
+        use_sparse_funnel=args.sparse_funnel,
+        n_global=args.n_global, chirp_gamma=args.chirp_gamma,
+        local_window=args.local_window, global_layer_override=global_layer_override,
+        rope_base=args.rope_base,
+        use_smear=args.smear, smear_channels=args.smear_channels,
+        bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
+        use_gated_attn=args.gated_attn, attn_gate_channels=args.attn_gate_channels,
+        ve_layers=ve_layers_parsed,
     )
     with torch.device("meta"):
         model_meta = GPT(config)
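
Illustrative sketch (not part of the patch): the chirped global-layer placement added in nanochat/gpt.py can be reproduced standalone. For the d24 speedrun defaults (n_global=0, i.e. depth//4 = 6 global layers, chirp_gamma=0.7), the helper lands full-context attention on layers 5, 10, 13, 17, 20 and 23, and every other layer falls back to the 128-token local window. The restatement below mirrors compute_sparse_global_layers, under the assumption that the final layer is always forced global.

    # Standalone restatement of the chirp placement (assumption: last layer is always global).
    def compute_sparse_global_layers(n_layer, n_global, chirp_gamma, global_layer_override=()):
        if global_layer_override:
            global_layers = set(global_layer_override)
        else:
            G = n_global if n_global > 0 else max(1, n_layer // 4)  # default_sparse_n_global
            global_layers = {
                max(0, min(n_layer - 1, int(n_layer * (i / G) ** chirp_gamma) - 1))
                for i in range(1, G + 1)
            }
        global_layers.add(n_layer - 1)  # final layer always gets full context
        return tuple(sorted(global_layers))

    # d24 speedrun defaults: n_global=0 -> 6 globals, chirp_gamma=0.7
    print(compute_sparse_global_layers(24, 0, 0.7))  # expected: (5, 10, 13, 17, 20, 23)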
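A second sketch, of the hashed-bigram feature path: BigramHashEmbedding buckets each (previous token, current token) pair into bigram_vocab_size - 1 hash buckets via a fixed XOR mix, reserving the last bucket for positions with no known predecessor, and KVCache.prev_token_ids carries the predecessor across single-token decode steps so decoding matches prefill. The helper below restates only the hashing (the learned embedding, projection and scale are omitted); the example token ids are made up.

    import torch

    def bigram_hash(tokens, bigram_vocab_size=4096, prev_tokens=None):
        # tokens: (B, T) token ids; prev_tokens: (B,) id of the token preceding position 0 (decode path)
        t = tokens.to(torch.int32)
        mod = bigram_vocab_size - 1
        out = torch.full_like(t, mod)  # reserved bucket when no predecessor is known
        if prev_tokens is not None:
            prev = prev_tokens.to(torch.int32).view(t.size(0))
            out[..., 0] = torch.bitwise_xor(36313 * t[..., 0], 27191 * prev) % mod
        if t.size(1) > 1:
            out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    # Prefill over a prompt, then a single decode step with the cached previous token:
    prompt = torch.tensor([[17, 42, 42, 99]])
    print(bigram_hash(prompt))                                          # position 0 -> reserved bucket 4095
    print(bigram_hash(torch.tensor([[7]]), prev_tokens=prompt[:, -1]))  # decode step hashes the pair (99, 7)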
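Finally, a small worked example of the meta-model sizing that build_model_meta performs for the speedrun's d24 invocation, using the script defaults visible in the diff (aspect_ratio=64, head_dim=128); the printed values are just arithmetic, not measured output.

    depth, aspect_ratio, head_dim = 24, 64, 128                       # --depth=24 with script defaults
    base_dim = depth * aspect_ratio                                   # 1536
    model_dim = ((base_dim + head_dim - 1) // head_dim) * head_dim    # rounds up to a head_dim multiple -> 1536
    num_heads = model_dim // head_dim                                 # 12
    dmodel_lr_scale = (model_dim / 768) ** -0.5                       # ~0.707, scales the AdamW learning rates
    print(model_dim, num_heads, round(dmodel_lr_scale, 3))            # 1536 12 0.707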