From 9f151898539d0cd33a3238cee21c0479a2c410c1 Mon Sep 17 00:00:00 2001 From: AaronXPeng <126114394+AaronXPeng@users.noreply.github.com> Date: Tue, 17 Mar 2026 00:57:16 -0400 Subject: [PATCH] Add optional Block Attention Residuals (AttnRes) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Block AttnRes from MoonshotAI (https://github.com/MoonshotAI/Attention-Residuals) as an optional alternative to resid_lambdas/x0_lambdas for residual connections. When enabled via --attn-res, replaces standard residual scaling with learned depth-attention over block-level representations. Layers are partitioned into blocks; at each sublayer, a softmax-weighted combination of all completed blocks plus the current partial block determines the input to attention/MLP. Design follows nanochat conventions: - Two nn.Parameter(n_layer, D) pseudo-query vectors on GPT (like resid_lambdas) - Uses existing parameterless norm() for key normalization (no learnable RMSNorm) - Block class unchanged — all AttnRes logic lives in GPT.forward - Minimal 6-line block_attn_res() core function Changes: - nanochat/gpt.py: block_attn_res(), AttnRes path in GPT.forward, config/init/optimizer - nanochat/checkpoint_manager.py: backward-compat config patching - scripts/base_train.py: --attn-res and --attn-res-block-size CLI args - tests/test_attn_res.py: 18 tests covering unit/forward/backward/optimizer/inference GPU results (depth=4, 20 steps, RTX 6000 Ada): Standard: val_bpb 3.21 → 2.80, ~840K tok/sec AttnRes: val_bpb 3.21 → 2.61, ~780K tok/sec Co-Authored-By: Claude Opus 4.6 --- nanochat/checkpoint_manager.py | 5 + nanochat/gpt.py | 79 ++++++++-- scripts/base_train.py | 3 + tests/test_attn_res.py | 267 +++++++++++++++++++++++++++++++++ 4 files changed, 343 insertions(+), 11 deletions(-) create mode 100644 tests/test_attn_res.py diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f71524e..d2bee40 100644 --- 
a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -26,6 +26,11 @@ def _patch_missing_config_keys(model_config_kwargs): if "window_pattern" not in model_config_kwargs: model_config_kwargs["window_pattern"] = "L" log0(f"Patching missing window_pattern in model config to 'L'") + # AttnRes defaults to disabled + if "attn_res" not in model_config_kwargs: + model_config_kwargs["attn_res"] = False + if "attn_res_block_size" not in model_config_kwargs: + model_config_kwargs["attn_res_block_size"] = 4 def _patch_missing_keys(model_data, model_config): """Add default values for new parameters that may be missing in old checkpoints.""" diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 0b822e4..b746ac2 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -10,6 +10,7 @@ Notable features: - no bias in linear layers - Group-Query Attention (GQA) support for more efficient inference - Flash Attention 3 integration +- Optional Block Attention Residuals (AttnRes): https://github.com/MoonshotAI/Attention-Residuals """ from functools import partial @@ -37,11 +38,23 @@ class GPTConfig: # Characters: L=long (full context), S=short (quarter context) # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long window_pattern: str = "SSSL" + # Attention Residuals (Block AttnRes): replaces resid_lambdas/x0_lambdas with learned + # depth-attention over block-level representations. See: https://github.com/MoonshotAI/Attention-Residuals + attn_res: bool = False + attn_res_block_size: int = 4 # sublayers per block (attn and MLP each count as one, so a block spans attn_res_block_size // 2 transformer layers); must be >= 2 and even def norm(x): return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok +def block_attn_res(blocks, partial_block, w): + """Depth-attention over block representations. 
w is a (D,) pseudo-query vector.""" + V = torch.stack(blocks + [partial_block]) # (N+1, B, T, D) + K = norm(V) + logits = torch.einsum('d, n b t d -> n b t', w, K) + h = torch.einsum('n b t, n b t d -> b t d', logits.softmax(0), V) + return h + class Linear(nn.Linear): """nn.Linear that casts weights to match input dtype in forward. Replaces autocast: master weights stay fp32 for optimizer precision, @@ -160,6 +173,13 @@ class GPT(nn.Module): """ super().__init__() self.config = config + # Validate AttnRes config + if config.attn_res: + assert config.attn_res_block_size >= 2 and config.attn_res_block_size % 2 == 0, \ + f"attn_res_block_size must be >= 2 and even, got {config.attn_res_block_size}" + # Per-layer pseudo-query vectors for depth-attention (one before attn, one before mlp) + self.attn_res_proj = nn.Parameter(torch.zeros(config.n_layer, config.n_embd)) + self.mlp_res_proj = nn.Parameter(torch.zeros(config.n_layer, config.n_embd)) # Compute per-layer window sizes for sliding window attention # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window self.window_sizes = self._compute_window_sizes(config) @@ -177,6 +197,7 @@ class GPT(nn.Module): # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral) # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled) # Separate parameters so they can have different optimizer treatment + # When AttnRes is enabled, these are still created (for checkpoint compat) but unused in forward self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() # Smear: mix previous token's embedding into current token (cheap bigram-like info) @@ -247,6 +268,11 @@ class GPT(nn.Module): if block.attn.ve_gate is not None: torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02) + # AttnRes pseudo-query vectors: small 
init + if self.config.attn_res: + torch.nn.init.normal_(self.attn_res_proj, mean=0.0, std=0.02) + torch.nn.init.normal_(self.mlp_res_proj, mean=0.0, std=0.02) + # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) @@ -324,9 +350,13 @@ class GPT(nn.Module): nparams = sum(p.numel() for p in self.parameters()) # Exclude non-matmul params: embeddings and per-layer scalars value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) + attn_res_numel = 0 + if self.config.attn_res: + attn_res_numel = self.attn_res_proj.numel() + self.mlp_res_proj.numel() nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + self.resid_lambdas.numel() + self.x0_lambdas.numel() + - self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()) + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + + attn_res_numel) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -354,7 +384,8 @@ class GPT(nn.Module): value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters()) transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) - scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + attn_res = (self.attn_res_proj.numel() + self.mlp_res_proj.numel()) if self.config.attn_res else 0 + scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + attn_res total = wte + value_embeds + lm_head + transformer_matrices + scalars assert total == sum(p.numel() for p in self.parameters()), "Parameter 
count mismatch" return { @@ -372,13 +403,14 @@ class GPT(nn.Module): # Separate out all parameters into groups matrix_params = list(self.transformer.h.parameters()) + attn_res_params = [self.attn_res_proj, self.mlp_res_proj] if self.config.attn_res else [] value_embeds_params = list(self.value_embeds.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) + assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) + len(attn_res_params) # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -393,7 +425,7 @@ class GPT(nn.Module): dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05), dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0), - ] + ] + ([dict(kind='adamw', params=attn_res_params, lr=0.02, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)] if attn_res_params else []) # Muon groups (matrix params, grouped by shape for stacking) for shape in sorted({p.shape for p in matrix_params}): group_params = [p for p in matrix_params if p.shape == shape] @@ -444,16 +476,41 @@ class GPT(nn.Module): x = x + gate * x_pre_smear # Forward the trunk of the Transformer - x0 = x # save initial normalized embedding for x0 residual n_layer = self.config.n_layer 
backout_layer = n_layer // 2 # cache at halfway point x_backout = None - for i, block in enumerate(self.transformer.h): - x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None - x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) - if i == backout_layer: - x_backout = x + if not self.config.attn_res: + # Standard residual path with resid_lambdas / x0_lambdas + x0 = x # save initial normalized embedding for x0 residual + for i, block in enumerate(self.transformer.h): + x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None + x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) + if i == backout_layer: + x_backout = x + else: + # AttnRes path: depth-attention over block-level representations + # replaces resid_lambdas/x0_lambdas with learned depth-attention + block_size = self.config.attn_res_block_size + ar_blocks = [] # completed block representations + partial = x + for i, block in enumerate(self.transformer.h): + ve = self.value_embeds[str(i)](idx).to(partial.dtype) if str(i) in self.value_embeds else None + # Depth-attention before attention sublayer + h = block_attn_res(ar_blocks, partial, self.attn_res_proj[i].to(partial.dtype)) + # Block boundary: save partial, start fresh + if i > 0 and i % (block_size // 2) == 0: + ar_blocks = ar_blocks + [partial] + partial = None + attn_out = block.attn(norm(h), ve, cos_sin, self.window_sizes[i], kv_cache) + partial = partial + attn_out if partial is not None else attn_out + # Depth-attention before MLP sublayer + h = block_attn_res(ar_blocks, partial, self.mlp_res_proj[i].to(partial.dtype)) + mlp_out = block.mlp(norm(h)) + partial = partial + mlp_out + if i == backout_layer: + x_backout = partial + x = partial # Subtract mid-layer residual to remove low-level features before logit projection if x_backout is not None: x = x - 
self.backout_lambda.to(x.dtype) * x_backout diff --git a/scripts/base_train.py b/scripts/base_train.py index 86aa770..dab0153 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -52,6 +52,8 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") +parser.add_argument("--attn-res", action="store_true", help="enable Attention Residuals (Block AttnRes)") +parser.add_argument("--attn-res-block-size", type=int, default=4, help="sublayers per AttnRes block; attn and MLP each count as one, so a block spans N/2 transformer layers (must be >= 2 and even)") # Training horizon (only one used, in order of precedence) parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") @@ -137,6 +139,7 @@ def build_model_meta(depth): sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, window_pattern=args.window_pattern, + attn_res=args.attn_res, attn_res_block_size=args.attn_res_block_size, ) with torch.device("meta"): model_meta = GPT(config) diff --git a/tests/test_attn_res.py b/tests/test_attn_res.py new file mode 100644 index 0000000..7406b73 --- /dev/null +++ b/tests/test_attn_res.py @@ -0,0 +1,267 @@ +""" +Test Attention Residuals (Block AttnRes) integration. 
+ +python -m pytest tests/test_attn_res.py -v +""" + +import torch + +# Force float32 compute for CPU test stability (avoids bf16 NaN in small models +# and dtype mismatches between model activations and KV cache on CUDA machines) +import nanochat.gpt as _gpt_mod +import nanochat.common as _common_mod +_gpt_mod.COMPUTE_DTYPE = torch.float32 +_common_mod.COMPUTE_DTYPE = torch.float32 + +from nanochat.gpt import GPT, GPTConfig, block_attn_res + + +def _make_config(attn_res=True, **kwargs): + """Create a small test config.""" + defaults = dict( + sequence_len=64, vocab_size=256, n_layer=8, n_head=4, n_kv_head=4, n_embd=64, + window_pattern="L", attn_res=attn_res, attn_res_block_size=4, + ) + defaults.update(kwargs) + return GPTConfig(**defaults) + + +def _build_model(config): + """Build a model on meta device, move to CPU, init weights.""" + with torch.device("meta"): + model = GPT(config) + model.to_empty(device="cpu") + model.init_weights() + return model + + +# ---- Unit tests for building blocks ---- + +def test_block_attn_res_function(): + """block_attn_res produces correct shape and valid output.""" + B, T, D = 2, 8, 16 + blocks = [torch.randn(B, T, D) for _ in range(3)] + partial = torch.randn(B, T, D) + w = torch.randn(D) + + h = block_attn_res(blocks, partial, w) + assert h.shape == (B, T, D) + assert torch.isfinite(h).all() + + +def test_block_attn_res_single_block(): + """With no prior blocks, output should equal the partial block itself.""" + B, T, D = 1, 4, 8 + partial = torch.randn(B, T, D) + w = torch.randn(D) + + h = block_attn_res([], partial, w) + # With only one element, softmax over dim 0 gives 1.0, so h == partial + assert torch.allclose(h, partial, atol=1e-6) + + +# ---- Model construction tests ---- + +def test_model_creates_attn_res_params(): + """When attn_res=True, GPT has attn_res_proj and mlp_res_proj parameters.""" + config = _make_config(attn_res=True) + model = _build_model(config) + assert hasattr(model, 'attn_res_proj') + assert 
hasattr(model, 'mlp_res_proj') + assert model.attn_res_proj.shape == (config.n_layer, config.n_embd) + assert model.mlp_res_proj.shape == (config.n_layer, config.n_embd) + + +def test_model_no_attn_res_params_when_disabled(): + """When attn_res=False, GPT does NOT have AttnRes parameters.""" + config = _make_config(attn_res=False) + model = _build_model(config) + assert not hasattr(model, 'attn_res_proj') + + +def test_attn_res_param_count_overhead(): + """AttnRes adds minimal parameter overhead.""" + config_base = _make_config(attn_res=False) + config_ar = _make_config(attn_res=True) + model_base = _build_model(config_base) + model_ar = _build_model(config_ar) + n_base = sum(p.numel() for p in model_base.parameters()) + n_ar = sum(p.numel() for p in model_ar.parameters()) + overhead = (n_ar - n_base) / n_base + # AttnRes adds 2 * n_layer * D params. For D=64, n_layer=8: 2*8*64 = 1024. + assert overhead < 0.05, f"AttnRes overhead {overhead:.2%} is too high" + assert n_ar > n_base, "AttnRes model should have more params" + + +# ---- Forward pass tests ---- + +def test_forward_training_attn_res(): + """Forward pass with attn_res=True produces valid loss.""" + config = _make_config(attn_res=True) + model = _build_model(config) + B, T = 2, 32 + idx = torch.randint(0, config.vocab_size, (B, T)) + targets = torch.randint(0, config.vocab_size, (B, T)) + loss = model(idx, targets=targets) + assert loss.shape == () + assert torch.isfinite(loss) + assert loss.item() > 0 + + +def test_forward_inference_attn_res(): + """Forward pass without targets returns logits.""" + config = _make_config(attn_res=True) + model = _build_model(config) + model.eval() + B, T = 1, 16 + idx = torch.randint(0, config.vocab_size, (B, T)) + logits = model(idx) + assert logits.shape == (B, T, config.vocab_size) + assert torch.isfinite(logits).all() + + +def test_forward_standard_path_unchanged(): + """Standard path (attn_res=False) still works correctly.""" + config = _make_config(attn_res=False) + 
model = _build_model(config) + # Disable backout for small test model (pre-existing NaN with n_embd=64) + model.backout_lambda.data.fill_(0.0) + B, T = 2, 32 + idx = torch.randint(0, config.vocab_size, (B, T)) + targets = torch.randint(0, config.vocab_size, (B, T)) + loss = model(idx, targets=targets) + assert torch.isfinite(loss) + + +def test_backward_attn_res(): + """Gradients flow through AttnRes path.""" + config = _make_config(attn_res=True) + model = _build_model(config) + B, T = 2, 16 + idx = torch.randint(0, config.vocab_size, (B, T)) + targets = torch.randint(0, config.vocab_size, (B, T)) + loss = model(idx, targets=targets) + loss.backward() + # Check that both AttnRes projections received a gradient (populated, not None) + assert model.attn_res_proj.grad is not None + assert model.mlp_res_proj.grad is not None + + +# ---- Config validation tests ---- + +def test_invalid_block_size_odd(): + """Odd block size must be rejected; the failure is raised from else: so except cannot swallow it.""" + try: + _build_model(_make_config(attn_res=True, attn_res_block_size=3)) + except AssertionError: + pass # expected: GPT.__init__ rejects the config + else: + raise AssertionError("Should have raised AssertionError") + + +def test_invalid_block_size_one(): + """block_size=1 must be rejected; the failure is raised from else: so except cannot swallow it.""" + try: + _build_model(_make_config(attn_res=True, attn_res_block_size=1)) + except AssertionError: + pass # expected: GPT.__init__ rejects the config + else: + raise AssertionError("Should have raised AssertionError") + + +def test_valid_block_sizes(): + """Various valid block sizes should work.""" + for bs in [2, 4, 6, 8]: + config = _make_config(attn_res=True, attn_res_block_size=bs, n_layer=8) + model = _build_model(config) + idx = torch.randint(0, 256, (1, 16)) + logits = model(idx) + assert torch.isfinite(logits).all(), f"NaN/Inf with block_size={bs}" + + +# ---- Optimizer tests ---- + +def test_optimizer_setup_attn_res(): + """Optimizer setup accounts for all AttnRes params.""" + config = _make_config(attn_res=True) + model = _build_model(config) + optimizer = model.setup_optimizer() + assert optimizer is not None + + +def 
test_optimizer_setup_standard(): + """Optimizer setup works for standard path too.""" + config = _make_config(attn_res=False) + model = _build_model(config) + optimizer = model.setup_optimizer() + assert optimizer is not None + + +# ---- FLOPs and param counting tests ---- + +def test_estimate_flops_attn_res(): + """FLOPs estimation works with AttnRes.""" + config = _make_config(attn_res=True) + model = _build_model(config) + flops = model.estimate_flops() + assert flops > 0 + + +def test_num_scaling_params_attn_res(): + """Parameter count is consistent with AttnRes.""" + config = _make_config(attn_res=True) + model = _build_model(config) + counts = model.num_scaling_params() + assert counts['total'] == sum(p.numel() for p in model.parameters()) + + +# ---- KV cache inference tests ---- + +def test_kv_cache_inference_attn_res(): + """AttnRes works with KV cache for inference.""" + from nanochat.engine import KVCache + config = _make_config(attn_res=True, n_layer=4, n_embd=64, n_head=4, n_kv_head=4) + model = _build_model(config) + model.eval() + + B, T = 1, 8 + idx = torch.randint(0, config.vocab_size, (B, T)) + + # Prefill + kv_cache = KVCache( + batch_size=B, num_heads=config.n_kv_head, + seq_len=32, head_dim=config.n_embd // config.n_head, + num_layers=config.n_layer, device="cpu", dtype=torch.float32, + ) + logits_prefill = model(idx, kv_cache=kv_cache) + assert logits_prefill.shape == (B, T, config.vocab_size) + + # Decode one token + next_idx = torch.randint(0, config.vocab_size, (B, 1)) + logits_decode = model(next_idx, kv_cache=kv_cache) + assert logits_decode.shape == (B, 1, config.vocab_size) + assert torch.isfinite(logits_decode).all() + + +def test_generate_attn_res(): + """Full generation pipeline works with AttnRes.""" + from nanochat.engine import Engine + from tests.test_engine import MockModel, ByteTokenizer + + # We need a real model for this test, not mock + config = _make_config(attn_res=True, n_layer=4, n_embd=64, n_head=4, n_kv_head=4) + model 
= _build_model(config) + model.eval() + + tokenizer = ByteTokenizer() + + # Override model methods that Engine needs + model.get_device = lambda: torch.device("cpu") + + engine = Engine(model, tokenizer) + prompt = [0, 72, 101, 108, 108, 111] # some tokens + + results, masks = engine.generate_batch(prompt, max_tokens=5, temperature=0.0, seed=42) + assert len(results) == 1 + assert len(results[0]) >= len(prompt) # at least prompt tokens + assert len(results[0]) <= len(prompt) + 5 # at most prompt + max_tokens