From fbc1484e8c2582325e8daa1c1a5000f17aed69e7 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 11 Jan 2026 21:49:54 +0000 Subject: [PATCH] add alternating window size patterns for the GPT layers, following GPT-3. Experimented a bit and found the pattern SSSL to work well - 3 short, 1 long alternating. This is now the new default and the plots look quite a bit better on flops vs. bpb --- dev/LOG.md | 16 ++++++++ nanochat/checkpoint_manager.py | 7 ++++ nanochat/gpt.py | 70 ++++++++++++++++++++++++++++------ scripts/base_train.py | 3 +- 4 files changed, 83 insertions(+), 13 deletions(-) diff --git a/dev/LOG.md b/dev/LOG.md index f2322de..902c1e0 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,22 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-01-11: Sliding Window Attention + +Added configurable sliding window attention, inspired by GPT-3's alternating short/long pattern. + +**Pattern string configuration:** +- New `--window_pattern` CLI arg and `GPTConfig.window_pattern` field +- Pattern is tiled across layers (e.g., `SSSL` for 20 layers → `SSSLSSSLSSSLSSSLSSSL`) +- Final layer always forced to L (full context) regardless of pattern +- Short window = `sequence_len // 2` +- Long window = `sequence_len` (full context) +- All previous models so far have been simply `L` and checkpoint loading is modified accordingly to fill in this param for old models, see `_patch_missing_config_keys` + +Quick experiments showed `SSSL` (every 4th layer is long) works well - provides a good balance between compute savings and model quality. This is now the default. + +--- + ## 2026-01-11: Flash Attention 3 Integration Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference. diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 79ba998..cca6294 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -20,6 +20,12 @@ def log0(message): if int(os.environ.get('RANK', 0)) == 0: logger.info(message) +def _patch_missing_config_keys(model_config_kwargs): + """Add default values for new config keys missing in old checkpoints.""" + # Old models were trained with full context (no sliding window) + if "window_pattern" not in model_config_kwargs: + model_config_kwargs["window_pattern"] = "L" + def _patch_missing_keys(model_data, model_config): """Add default values for new parameters that may be missing in old checkpoints.""" n_layer = model_config.n_layer @@ -84,6 +90,7 @@ def build_model(checkpoint_dir, step, device, phase): # Hack: fix torch compile issue, which prepends all keys with _orig_mod. model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} model_config_kwargs = meta_data["model_config"] + _patch_missing_config_keys(model_config_kwargs) log0(f"Building model with config: {model_config_kwargs}") model_config = GPTConfig(**model_config_kwargs) _patch_missing_keys(model_data, model_config) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index f22ec07..81ccb0c 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -39,6 +39,10 @@ class GPTConfig: n_head: int = 6 # number of query heads n_kv_head: int = 6 # number of key/value heads (GQA) n_embd: int = 768 + # Sliding window attention pattern string, tiled across layers. Final layer always L. 
+ # Characters: L=long (full context), S=short (half context) + # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long + window_pattern: str = "L" def norm(x): @@ -69,7 +73,7 @@ class CausalSelfAttention(nn.Module): self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) - def forward(self, x, cos_sin, kv_cache): + def forward(self, x, cos_sin, window_size, kv_cache): B, T, C = x.size() # Project the input to get queries, keys, and values @@ -85,9 +89,10 @@ class CausalSelfAttention(nn.Module): # Attention with Flash Attention 3 # FA3 handles GQA automatically when n_kv_heads < n_heads + # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context if kv_cache is None: - # Training: simple causal attention - y = flash_attn.flash_attn_func(q, k, v, causal=True) + # Training: causal attention with optional sliding window + y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) else: # Inference: use flash_attn_with_kvcache which handles cache management k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx) @@ -96,6 +101,7 @@ class CausalSelfAttention(nn.Module): k=k, v=v, cache_seqlens=kv_cache.cache_seqlens, causal=True, + window_size=window_size, ) # Advance position after last layer processes if self.layer_idx == kv_cache.n_layers - 1: @@ -126,8 +132,8 @@ class Block(nn.Module): self.attn = CausalSelfAttention(config, layer_idx) self.mlp = MLP(config) - def forward(self, x, cos_sin, kv_cache): - x = x + self.attn(norm(x), cos_sin, kv_cache) + def forward(self, x, cos_sin, window_size, kv_cache): + x = x + self.attn(norm(x), cos_sin, window_size, kv_cache) x = x + self.mlp(norm(x)) return x @@ -141,11 +147,14 @@ class GPT(nn.Module): """ super().__init__() self.config = config - # For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see: + # Compute per-layer window sizes for sliding window attention + # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window + self.window_sizes = self._compute_window_sizes(config) + # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward(). # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to if padded_vocab_size != config.vocab_size: - print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}") + print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency") self.transformer = nn.ModuleDict({ "wte": nn.Embedding(padded_vocab_size, config.n_embd), "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), @@ -228,6 +237,35 @@ class GPT(nn.Module): cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting return cos, sin + def _compute_window_sizes(self, config): + """ + Compute per-layer window sizes for sliding window attention. + + Returns list of (left, right) tuples for FA3's window_size parameter: + - left: how many tokens before current position to attend to (-1 = unlimited) + - right: how many tokens after current position to attend to (0 for causal) + + Pattern string is tiled across layers. Final layer always gets L (full context). 
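+        e.g. window_pattern="SSL" with n_layer=5 resolves to S, S, L, S, L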
+ Characters: L=long (full context), S=short (half context) + """ + pattern = config.window_pattern.upper() + assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L." + # Map characters to window sizes + long_window = config.sequence_len + short_window = long_window // 2 + char_to_window = { + "L": (long_window, 0), + "S": (short_window, 0), + } + # Tile pattern across layers + window_sizes = [] + for layer_idx in range(config.n_layer): + char = pattern[layer_idx % len(pattern)] + window_sizes.append(char_to_window[char]) + # Final layer always gets full context + window_sizes[-1] = (long_window, 0) + return window_sizes + def get_device(self): return self.transformer.wte.weight.device @@ -236,16 +274,24 @@ class GPT(nn.Module): Return the estimated FLOPs per token for the model (forward + backward). Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6. Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4 - On top of that, the term 12 * l * h * q * t accounts for key @ query matmul flops inside attention. + On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention. + With sliding windows, effective_seq_len varies per layer (capped by window size). Ref: https://arxiv.org/abs/2204.02311 (PaLM paper). This is ~1% off from the exact formulas of Chinchilla paper, the difference is: - Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore) - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore) """ nparams = sum(p.numel() for p in self.parameters()) - nparams_embedding = self.transformer.wte.weight.numel() - l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len - num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t + # Exclude non-matmul params: embeddings and per-layer scalars + nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel() + h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len + # Sum attention FLOPs per layer, accounting for sliding window + attn_flops = 0 + for window_size in self.window_sizes: + window = window_size[0] # (left, right) tuple, we use left + effective_seq = t if window < 0 else min(window, t) + attn_flops += 12 * h * q * effective_seq + num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops return num_flops_per_token def num_scaling_params(self): @@ -311,7 +357,7 @@ class GPT(nn.Module): x0 = x # save initial normalized embedding for x0 residual for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - x = block(x, cos_sin, kv_cache) + x = block(x, cos_sin, self.window_sizes[i], kv_cache) x = norm(x) # Forward the lm_head (compute logits) diff --git a/scripts/base_train.py b/scripts/base_train.py index 3327451..9d8ac16 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -42,6 +42,7 @@ parser.add_argument("--depth", type=int, default=20, help="depth of the Transfor parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--max_seq_len", type=int, 
default=2048, help="max context length") +parser.add_argument("--window_pattern", type=str, default="L", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") # Training horizon (only one used, in order of precedence) parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") @@ -139,7 +140,7 @@ if args.depth != 12: # Initialize the Model # Create a new model with random weights -model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim) +model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern) with torch.device("meta"): # All tensors are created as meta tensors (they have shape/dtype but no data) model_config = GPTConfig(**model_config_kwargs)
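
Note (illustrative, not part of the patch): a minimal sketch of how a window pattern resolves into the
per-layer FA3 window_size tuples computed by GPT._compute_window_sizes, assuming n_layer=6,
sequence_len=2048, window_pattern="SSSL":

    # tile the pattern across layers, then force the final layer to full context
    pattern, n_layer, seq_len = "SSSL", 6, 2048
    long_w, short_w = seq_len, seq_len // 2
    char_to_window = {"L": (long_w, 0), "S": (short_w, 0)}
    windows = [char_to_window[pattern[i % len(pattern)]] for i in range(n_layer)]
    windows[-1] = (long_w, 0)  # final layer always gets full context
    print(windows)  # [(1024, 0), (1024, 0), (1024, 0), (2048, 0), (1024, 0), (2048, 0)]
    # estimate_flops then uses min(window, sequence_len) as the effective sequence length per layer,
    # so the S layers contribute roughly half the attention FLOPs of the L layers.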