Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-21)
add alternating window size patterns for the GPT layers, following GPT-3. Experimented a bit and found the pattern SSSL (3 short, 1 long, alternating) to work well. This is now the new default, and the flops vs. bpb plots look quite a bit better.
This commit is contained in:
parent 2ff7d51252
commit fbc1484e8c

dev/LOG.md: 16 changed lines
dev/LOG.md

@@ -4,6 +4,22 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
 
 ---
 
+## 2026-01-11: Sliding Window Attention
+
+Added configurable sliding window attention, inspired by GPT-3's alternating short/long pattern.
+
+**Pattern string configuration:**
+- New `--window_pattern` CLI arg and `GPTConfig.window_pattern` field
+- Pattern is tiled across layers (e.g., `SSSL` for 20 layers → `SSSLSSSLSSSLSSSLSSSL`)
+- Final layer is always forced to `L` (full context), regardless of the pattern
+- Short window = `sequence_len // 2`
+- Long window = `sequence_len` (full context)
+- All previous models were trained with the plain `L` pattern; checkpoint loading was updated to fill in this field for old models, see `_patch_missing_config_keys`
+
+Quick experiments showed `SSSL` (every 4th layer is long) works well, providing a good balance between compute savings and model quality. This is now the default.
+
+---
+
 ## 2026-01-11: Flash Attention 3 Integration
 
 Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference.
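To make the tiling rule in the log entry above concrete, here is a minimal standalone sketch (not the repo's code, though it mirrors the `_compute_window_sizes` logic in the model hunks further down); `n_layer=20` and `sequence_len=2048` are just assumed example values.

```python
# Standalone sketch of the pattern-tiling rule described above. The real logic lives in
# GPT._compute_window_sizes (see the model hunks below); n_layer/sequence_len here are
# assumed example values.

def tile_window_pattern(pattern: str, n_layer: int, sequence_len: int):
    pattern = pattern.upper()
    assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}"
    long_window = sequence_len          # L = full context
    short_window = sequence_len // 2    # S = half context
    char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
    # Tile the pattern across the layers, e.g. "SSSL" over 20 layers -> SSSL repeated 5x
    windows = [char_to_window[pattern[i % len(pattern)]] for i in range(n_layer)]
    windows[-1] = (long_window, 0)      # final layer always gets full context
    return windows

print(tile_window_pattern("SSSL", n_layer=20, sequence_len=2048)[:4])
# [(1024, 0), (1024, 0), (1024, 0), (2048, 0)]
```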
@@ -20,6 +20,12 @@ def log0(message):
     if int(os.environ.get('RANK', 0)) == 0:
         logger.info(message)
 
+def _patch_missing_config_keys(model_config_kwargs):
+    """Add default values for new config keys missing in old checkpoints."""
+    # Old models were trained with full context (no sliding window)
+    if "window_pattern" not in model_config_kwargs:
+        model_config_kwargs["window_pattern"] = "L"
+
 def _patch_missing_keys(model_data, model_config):
     """Add default values for new parameters that may be missing in old checkpoints."""
     n_layer = model_config.n_layer

@@ -84,6 +90,7 @@ def build_model(checkpoint_dir, step, device, phase):
     # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
     model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
     model_config_kwargs = meta_data["model_config"]
+    _patch_missing_config_keys(model_config_kwargs)
     log0(f"Building model with config: {model_config_kwargs}")
     model_config = GPTConfig(**model_config_kwargs)
     _patch_missing_keys(model_data, model_config)
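A tiny illustration of what the backward-compatibility hunk above does when loading an old checkpoint. The function body is copied from the hunk; the config values are assumed placeholders, not taken from this diff.

```python
# An old checkpoint's config dict has no "window_pattern", so it gets defaulted to "L"
# (full context on every layer), and old models behave exactly as before.
# _patch_missing_config_keys is copied verbatim from the hunk above to keep this runnable.

def _patch_missing_config_keys(model_config_kwargs):
    """Add default values for new config keys missing in old checkpoints."""
    if "window_pattern" not in model_config_kwargs:
        model_config_kwargs["window_pattern"] = "L"

old_kwargs = dict(sequence_len=2048, vocab_size=50304, n_layer=12, n_head=6,
                  n_kv_head=6, n_embd=768)        # assumed example values
_patch_missing_config_keys(old_kwargs)
assert old_kwargs["window_pattern"] == "L"
```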
@@ -39,6 +39,10 @@ class GPTConfig:
     n_head: int = 6 # number of query heads
     n_kv_head: int = 6 # number of key/value heads (GQA)
     n_embd: int = 768
+    # Sliding window attention pattern string, tiled across layers. Final layer always L.
+    # Characters: L=long (full context), S=short (half context)
+    # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
+    window_pattern: str = "L"
 
 
 def norm(x):

@@ -69,7 +73,7 @@ class CausalSelfAttention(nn.Module):
         self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
 
-    def forward(self, x, cos_sin, kv_cache):
+    def forward(self, x, cos_sin, window_size, kv_cache):
         B, T, C = x.size()
 
         # Project the input to get queries, keys, and values
@@ -85,9 +89,10 @@ class CausalSelfAttention(nn.Module):
 
         # Attention with Flash Attention 3
         # FA3 handles GQA automatically when n_kv_heads < n_heads
+        # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
         if kv_cache is None:
-            # Training: simple causal attention
-            y = flash_attn.flash_attn_func(q, k, v, causal=True)
+            # Training: causal attention with optional sliding window
+            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
         else:
             # Inference: use flash_attn_with_kvcache which handles cache management
             k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)

@@ -96,6 +101,7 @@ class CausalSelfAttention(nn.Module):
                 k=k, v=v,
                 cache_seqlens=kv_cache.cache_seqlens,
                 causal=True,
+                window_size=window_size,
             )
             # Advance position after last layer processes
             if self.layer_idx == kv_cache.n_layers - 1:
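For readers less familiar with FA3's `window_size` argument, here is a plain-PyTorch sketch of the masking semantics that `(N, 0)` implies, written with `scaled_dot_product_attention` and an explicit boolean mask. This is an illustration only; the actual code path uses the FA3 kernels shown above.

```python
# Illustration of what window_size=(N, 0) means: each query attends causally, but only
# to keys within N positions behind it. (-1, 0) recovers plain causal attention.
import torch
import torch.nn.functional as F

def sliding_window_causal_mask(T, window_left, device=None):
    # True = allowed to attend. Query t may see keys in [t - window_left, t].
    i = torch.arange(T, device=device)[:, None]   # query positions
    j = torch.arange(T, device=device)[None, :]   # key positions
    causal = j <= i
    if window_left < 0:                           # (-1, 0) -> full causal context
        return causal
    return causal & (i - j <= window_left)

B, H, T, D = 1, 2, 8, 16
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

full = F.scaled_dot_product_attention(q, k, v, attn_mask=sliding_window_causal_mask(T, -1))
short = F.scaled_dot_product_attention(q, k, v, attn_mask=sliding_window_causal_mask(T, 4))
print(full.shape, short.shape)  # both (1, 2, 8, 16); only the attention pattern differs
```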
@@ -126,8 +132,8 @@ class Block(nn.Module):
         self.attn = CausalSelfAttention(config, layer_idx)
         self.mlp = MLP(config)
 
-    def forward(self, x, cos_sin, kv_cache):
-        x = x + self.attn(norm(x), cos_sin, kv_cache)
+    def forward(self, x, cos_sin, window_size, kv_cache):
+        x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
         x = x + self.mlp(norm(x))
         return x
 

@@ -141,11 +147,14 @@ class GPT(nn.Module):
         """
         super().__init__()
         self.config = config
-        # For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
+        # Compute per-layer window sizes for sliding window attention
+        # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
+        self.window_sizes = self._compute_window_sizes(config)
+        # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
         # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
         padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
         if padded_vocab_size != config.vocab_size:
-            print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}")
+            print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
         self.transformer = nn.ModuleDict({
             "wte": nn.Embedding(padded_vocab_size, config.n_embd),
             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
@@ -228,6 +237,35 @@ class GPT(nn.Module):
         cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
         return cos, sin
 
+    def _compute_window_sizes(self, config):
+        """
+        Compute per-layer window sizes for sliding window attention.
+
+        Returns list of (left, right) tuples for FA3's window_size parameter:
+        - left: how many tokens before current position to attend to (-1 = unlimited)
+        - right: how many tokens after current position to attend to (0 for causal)
+
+        Pattern string is tiled across layers. Final layer always gets L (full context).
+        Characters: L=long (full context), S=short (half context)
+        """
+        pattern = config.window_pattern.upper()
+        assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
+        # Map characters to window sizes
+        long_window = config.sequence_len
+        short_window = long_window // 2
+        char_to_window = {
+            "L": (long_window, 0),
+            "S": (short_window, 0),
+        }
+        # Tile pattern across layers
+        window_sizes = []
+        for layer_idx in range(config.n_layer):
+            char = pattern[layer_idx % len(pattern)]
+            window_sizes.append(char_to_window[char])
+        # Final layer always gets full context
+        window_sizes[-1] = (long_window, 0)
+        return window_sizes
+
     def get_device(self):
         return self.transformer.wte.weight.device
 
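For intuition, the per-layer `(left, right)` tuples that `_compute_window_sizes` produces work out as follows for a few pattern strings, hand-expanded under an assumed `sequence_len=2048` and `n_layer=8`:

```python
# Hand-expanded per-layer (left, right) tuples, assuming sequence_len=2048 and n_layer=8:
#   "L"    -> [(2048, 0)] * 8                                # all full context
#   "SL"   -> [(1024, 0), (2048, 0)] * 4                     # alternating short/long
#   "SSSL" -> [(1024, 0), (1024, 0), (1024, 0), (2048, 0)] * 2
#   "S"    -> [(1024, 0)] * 7 + [(2048, 0)]                  # final layer forced to long
```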
@@ -236,16 +274,24 @@ class GPT(nn.Module):
         Return the estimated FLOPs per token for the model (forward + backward).
         Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
         Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
-        On top of that, the term 12 * l * h * q * t accounts for key @ query matmul flops inside attention.
+        On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
+        With sliding windows, effective_seq_len varies per layer (capped by window size).
         Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
         This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
         - Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
         - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
         """
         nparams = sum(p.numel() for p in self.parameters())
-        nparams_embedding = self.transformer.wte.weight.numel()
-        l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
-        num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
+        # Exclude non-matmul params: embeddings and per-layer scalars
+        nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel()
+        h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
+        # Sum attention FLOPs per layer, accounting for sliding window
+        attn_flops = 0
+        for window_size in self.window_sizes:
+            window = window_size[0]  # (left, right) tuple, we use left
+            effective_seq = t if window < 0 else min(window, t)
+            attn_flops += 12 * h * q * effective_seq
+        num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
         return num_flops_per_token
 
     def num_scaling_params(self):
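To make the per-layer attention term concrete, here is a small numeric sketch under an assumed d20-style config (20 layers, 10 heads of dim 128, 2048 context, pattern `SSSL`); the `6 * (nparams - nparams_exclude)` term is left out since it depends on the full parameter count.

```python
# Hedged numeric sketch of the per-layer attention FLOPs term above; config values assumed.
n_layer, n_head, head_dim, seq_len = 20, 10, 128, 2048    # so n_embd = n_head * head_dim = 1280
windows = [1024, 1024, 1024, 2048] * 5                    # "SSSL" tiled over 20 layers
windows[-1] = 2048                                        # final layer forced long (already L here)

attn_flops = sum(12 * n_head * head_dim * min(w, seq_len) for w in windows)
full_ctx_flops = 12 * n_layer * n_head * head_dim * seq_len

print(attn_flops)                       # 393216000
print(full_ctx_flops)                   # 629145600
print(1 - attn_flops / full_ctx_flops)  # 0.375 -> ~37.5% fewer attention matmul FLOPs per token
```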
@@ -311,7 +357,7 @@ class GPT(nn.Module):
         x0 = x # save initial normalized embedding for x0 residual
         for i, block in enumerate(self.transformer.h):
             x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
-            x = block(x, cos_sin, kv_cache)
+            x = block(x, cos_sin, self.window_sizes[i], kv_cache)
         x = norm(x)
 
         # Forward the lm_head (compute logits)
@@ -42,6 +42,7 @@ parser.add_argument("--depth", type=int, default=20, help="depth of the Transfor
 parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
 parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
 parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
+parser.add_argument("--window_pattern", type=str, default="L", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
 # Training horizon (only one used, in order of precedence)
 parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
 parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")

@@ -139,7 +140,7 @@ if args.depth != 12:
 # Initialize the Model
 
 # Create a new model with random weights
-model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
+model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern)
 with torch.device("meta"):
     # All tensors are created as meta tensors (they have shape/dtype but no data)
     model_config = GPTConfig(**model_config_kwargs)
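End to end, the new flag simply flows from argparse into `GPTConfig` via `model_config_kwargs`, as the hunk above shows. A hypothetical, self-contained illustration; the config values here are assumed, not taken from this diff.

```python
# Hypothetical illustration of the CLI -> config plumbing above; sizes are assumed examples.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--window_pattern", type=str, default="L", help="sliding window pattern tiled across layers")
args = parser.parse_args(["--window_pattern", "SSSL"])  # simulate passing the flag

# Mirrors the model_config_kwargs construction in the hunk above (assumed depth-20 sizes)
model_config_kwargs = dict(sequence_len=2048, vocab_size=50304, n_layer=20, n_head=10,
                           n_kv_head=10, n_embd=1280, window_pattern=args.window_pattern)
print(model_config_kwargs["window_pattern"])  # SSSL
```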