diff --git a/nanochat/engine.py b/nanochat/engine.py index 4724c8f..aa2e6a9 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -100,10 +100,13 @@ class KVCache: self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) # Current sequence length per batch element (FA3 needs int32) self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) + # Previous token's normalized embedding for smear (set by model forward pass) + self.prev_embedding = None def reset(self): """Reset cache to empty state.""" self.cache_seqlens.zero_() + self.prev_embedding = None def get_pos(self): """Get current position (assumes all batch elements at same position).""" @@ -129,6 +132,9 @@ class KVCache: self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :] self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :] self.cache_seqlens.fill_(other_pos) + # Copy smear state: expand batch=1 prev_embedding to num_samples + if other.prev_embedding is not None: + self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone() # ----------------------------------------------------------------------------- @torch.inference_mode() diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 5e99c73..0b822e4 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -34,7 +34,7 @@ class GPTConfig: n_kv_head: int = 6 # number of key/value heads (GQA) n_embd: int = 768 # Sliding window attention pattern string, tiled across layers. Final layer always L. 
- # Characters: L=long (full context), S=short (half context) + # Characters: L=long (full context), S=short (quarter context) # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long window_pattern: str = "SSSL" @@ -98,8 +98,8 @@ class CausalSelfAttention(nn.Module): cos, sin = cos_sin q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) q, k = norm(q), norm(k) # QK norm - q = q * 1.15 # sharper attention (split scale between Q and K), TODO think through better - k = k * 1.15 + q = q * 1.2 # sharper attention (split scale between Q and K), TODO think through better + k = k * 1.2 # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere) # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context @@ -179,6 +179,11 @@ class GPT(nn.Module): # Separate parameters so they can have different optimizer treatment self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() + # Smear: mix previous token's embedding into current token (cheap bigram-like info) + self.smear_gate = Linear(24, 1, bias=False) + self.smear_lambda = nn.Parameter(torch.zeros(1)) + # Backout: subtract cached mid-layer residual before final norm to remove low-level features + self.backout_lambda = nn.Parameter(0.2 * torch.ones(1)) # Value embeddings (ResFormer-style): alternating layers, last layer always included head_dim = config.n_embd // config.n_head kv_dim = config.n_kv_head * head_dim @@ -221,12 +226,17 @@ class GPT(nn.Module): torch.nn.init.uniform_(block.attn.c_k.weight, -s, s) torch.nn.init.uniform_(block.attn.c_v.weight, -s, s) torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero - torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.5, s * 0.5) # 0.5x init scale for c_fc + torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for 
c_fc torch.nn.init.zeros_(block.mlp.c_proj.weight) # Per-layer scalars - self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init - self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding + # Per-layer resid init: stronger residual at early layers, weaker at deep layers + n_layer = self.config.n_layer + for i in range(n_layer): + self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1)) + # Decaying x0 init: earlier layers get more input embedding blending + for i in range(n_layer): + self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1)) # Value embeddings (init like c_v: uniform with same std) for ve in self.value_embeds.values(): @@ -276,13 +286,13 @@ class GPT(nn.Module): - right: how many tokens after current position to attend to (0 for causal) Pattern string is tiled across layers. Final layer always gets L (full context). - Characters: L=long (full context), S=short (half context) + Characters: L=long (full context), S=short (quarter context) """ pattern = config.window_pattern.upper() assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L." 
# Map characters to window sizes long_window = config.sequence_len - short_window = -(-long_window // 3 // 128) * 128 # ceil to FA3 tile size (2048 -> 768) + short_window = -(-long_window // 4 // 128) * 128 # ceil to FA3 tile size (2048 -> 512) char_to_window = { "L": (long_window, 0), "S": (short_window, 0), @@ -315,7 +325,8 @@ class GPT(nn.Module): # Exclude non-matmul params: embeddings and per-layer scalars value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + - self.resid_lambdas.numel() + self.x0_lambdas.numel()) + self.resid_lambdas.numel() + self.x0_lambdas.numel() + + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -343,7 +354,7 @@ class GPT(nn.Module): value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters()) transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) - scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() total = wte + value_embeds + lm_head + transformer_matrices + scalars assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" return { @@ -366,7 +377,8 @@ class GPT(nn.Module): lm_head_params = list(self.lm_head.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda] 
+ assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -380,6 +392,7 @@ class GPT(nn.Module): dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01), dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05), dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 + dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0), ] # Muon groups (matrix params, grouped by shape for stacking) for shape in sorted({p.shape for p in matrix_params}): @@ -406,15 +419,44 @@ class GPT(nn.Module): T0 = 0 if kv_cache is None else kv_cache.get_pos() cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length - # Forward the trunk of the Transformer + # Embed the tokens x = self.transformer.wte(idx) # embed current token x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path) x = norm(x) + + # Smear: mix previous token's embedding into current position (cheap bigram info) + if kv_cache is None: + # Training / naive generate: full sequence available, use fast slice + assert T > 1, "Training forward pass should have T > 1" + gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24])) + x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1) + else: + # KV cache inference: read prev embedding from cache, store current for next step + x_pre_smear = kv_cache.prev_embedding + kv_cache.prev_embedding = x[:, -1:, :] + if T > 1: + # Prefill: apply smear to positions 1+, same 
as training + gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24])) + x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1) + elif x_pre_smear is not None: + # Decode: single token, use cached prev embedding + gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24])) + x = x + gate * x_pre_smear + + # Forward the trunk of the Transformer x0 = x # save initial normalized embedding for x0 residual + n_layer = self.config.n_layer + backout_layer = n_layer // 2 # cache at halfway point + x_backout = None for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) + if i == backout_layer: + x_backout = x + # Subtract mid-layer residual to remove low-level features before logit projection + if x_backout is not None: + x = x - self.backout_lambda.to(x.dtype) * x_backout x = norm(x) # Forward the lm_head (compute logits) diff --git a/runs/speedrun.sh b/runs/speedrun.sh index fa50694..48fcc68 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -69,8 +69,8 @@ python -m scripts.tok_eval echo "Waiting for dataset download to complete..." 
wait $DATASET_DOWNLOAD_PID -# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 9.5) -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=9.5 --device-batch-size=16 --fp8 --run=$WANDB_RUN +# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8) +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --device-batch-size=16 --fp8 --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 diff --git a/scripts/base_train.py b/scripts/base_train.py index cfbfe28..86aa770 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -367,11 +367,18 @@ def get_lr_multiplier(it): progress = (num_iterations - it) / warmdown_iters return progress * 1.0 + (1 - progress) * args.final_lr_frac -# Momentum scheduler for Muon optimizer (warms up to 0.97 over the first 400 steps) +# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.90 during LR warmdown) def get_muon_momentum(it): - frac = min(it / 400, 1) - momentum = (1 - frac) * 0.85 + frac * 0.97 - return momentum + warmdown_iters = round(args.warmdown_ratio * num_iterations) + warmdown_start = num_iterations - warmdown_iters + if it < 400: + frac = it / 400 + return (1 - frac) * 0.85 + frac * 0.97 + elif it >= warmdown_start: + progress = (it - warmdown_start) / warmdown_iters + return 0.97 * (1 - progress) + 0.90 * progress + else: + return 0.97 # Weight decay scheduler for Muon optimizer (cosine decay to zero over the course of training) def get_weight_decay(it):