This commit is contained in:
AaronXPeng 2026-03-17 11:27:55 -07:00 committed by GitHub
commit ff8942ac95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 343 additions and 11 deletions

View File

@ -26,6 +26,11 @@ def _patch_missing_config_keys(model_config_kwargs):
if "window_pattern" not in model_config_kwargs:
model_config_kwargs["window_pattern"] = "L"
log0(f"Patching missing window_pattern in model config to 'L'")
# AttnRes defaults to disabled
if "attn_res" not in model_config_kwargs:
model_config_kwargs["attn_res"] = False
if "attn_res_block_size" not in model_config_kwargs:
model_config_kwargs["attn_res_block_size"] = 4
def _patch_missing_keys(model_data, model_config):
"""Add default values for new parameters that may be missing in old checkpoints."""

View File

@ -10,6 +10,7 @@ Notable features:
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
- Flash Attention 3 integration
- Optional Block Attention Residuals (AttnRes): https://github.com/MoonshotAI/Attention-Residuals
"""
from functools import partial
@ -37,11 +38,23 @@ class GPTConfig:
# Characters: L=long (full context), S=short (quarter context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "SSSL"
# Attention Residuals (Block AttnRes): replaces resid_lambdas/x0_lambdas with learned
# depth-attention over block-level representations. See: https://github.com/MoonshotAI/Attention-Residuals
attn_res: bool = False
attn_res_block_size: int = 4 # layers per block (must be >= 2 and even)
def norm(x):
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
def block_attn_res(blocks, partial_block, w):
    """Depth-attention over block representations. w is a (D,) pseudo-query vector."""
    # Stack the completed blocks plus the in-progress one: (N+1, B, T, D)
    values = torch.stack(blocks + [partial_block])
    keys = norm(values)
    # Score each depth entry against the pseudo-query, softmax over depth, mix values.
    scores = torch.einsum('d, n b t d -> n b t', w, keys)
    weights = scores.softmax(0)
    return torch.einsum('n b t, n b t d -> b t d', weights, values)
class Linear(nn.Linear):
"""nn.Linear that casts weights to match input dtype in forward.
Replaces autocast: master weights stay fp32 for optimizer precision,
@ -160,6 +173,13 @@ class GPT(nn.Module):
"""
super().__init__()
self.config = config
# Validate AttnRes config
if config.attn_res:
assert config.attn_res_block_size >= 2 and config.attn_res_block_size % 2 == 0, \
f"attn_res_block_size must be >= 2 and even, got {config.attn_res_block_size}"
# Per-layer pseudo-query vectors for depth-attention (one before attn, one before mlp)
self.attn_res_proj = nn.Parameter(torch.zeros(config.n_layer, config.n_embd))
self.mlp_res_proj = nn.Parameter(torch.zeros(config.n_layer, config.n_embd))
# Compute per-layer window sizes for sliding window attention
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
self.window_sizes = self._compute_window_sizes(config)
@ -177,6 +197,7 @@ class GPT(nn.Module):
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
# Separate parameters so they can have different optimizer treatment
# When AttnRes is enabled, these are still created (for checkpoint compat) but unused in forward
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
# Smear: mix previous token's embedding into current token (cheap bigram-like info)
@ -247,6 +268,11 @@ class GPT(nn.Module):
if block.attn.ve_gate is not None:
torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)
# AttnRes pseudo-query vectors: small init
if self.config.attn_res:
torch.nn.init.normal_(self.attn_res_proj, mean=0.0, std=0.02)
torch.nn.init.normal_(self.mlp_res_proj, mean=0.0, std=0.02)
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
@ -324,9 +350,13 @@ class GPT(nn.Module):
nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
attn_res_numel = 0
if self.config.attn_res:
attn_res_numel = self.attn_res_proj.numel() + self.mlp_res_proj.numel()
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() +
attn_res_numel)
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
@ -354,7 +384,8 @@ class GPT(nn.Module):
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
attn_res = (self.attn_res_proj.numel() + self.mlp_res_proj.numel()) if self.config.attn_res else 0
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + attn_res
total = wte + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
@ -372,13 +403,14 @@ class GPT(nn.Module):
# Separate out all parameters into groups
matrix_params = list(self.transformer.h.parameters())
attn_res_params = [self.attn_res_proj, self.mlp_res_proj] if self.config.attn_res else []
value_embeds_params = list(self.value_embeds.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) + len(attn_res_params)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -393,7 +425,7 @@ class GPT(nn.Module):
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
]
] + ([dict(kind='adamw', params=attn_res_params, lr=0.02, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)] if attn_res_params else [])
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):
group_params = [p for p in matrix_params if p.shape == shape]
@ -444,16 +476,41 @@ class GPT(nn.Module):
x = x + gate * x_pre_smear
# Forward the trunk of the Transformer
x0 = x # save initial normalized embedding for x0 residual
n_layer = self.config.n_layer
backout_layer = n_layer // 2 # cache at halfway point
x_backout = None
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
if i == backout_layer:
x_backout = x
if not self.config.attn_res:
# Standard residual path with resid_lambdas / x0_lambdas
x0 = x # save initial normalized embedding for x0 residual
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
if i == backout_layer:
x_backout = x
else:
# AttnRes path: depth-attention over block-level representations
# replaces resid_lambdas/x0_lambdas with learned depth-attention
block_size = self.config.attn_res_block_size
ar_blocks = [] # completed block representations
partial = x
for i, block in enumerate(self.transformer.h):
ve = self.value_embeds[str(i)](idx).to(partial.dtype) if str(i) in self.value_embeds else None
# Depth-attention before attention sublayer
h = block_attn_res(ar_blocks, partial, self.attn_res_proj[i].to(partial.dtype))
# Block boundary: save partial, start fresh
if i > 0 and i % (block_size // 2) == 0:
ar_blocks = ar_blocks + [partial]
partial = None
attn_out = block.attn(norm(h), ve, cos_sin, self.window_sizes[i], kv_cache)
partial = partial + attn_out if partial is not None else attn_out
# Depth-attention before MLP sublayer
h = block_attn_res(ar_blocks, partial, self.mlp_res_proj[i].to(partial.dtype))
mlp_out = block.mlp(norm(h))
partial = partial + mlp_out
if i == backout_layer:
x_backout = partial
x = partial
# Subtract mid-layer residual to remove low-level features before logit projection
if x_backout is not None:
x = x - self.backout_lambda.to(x.dtype) * x_backout

View File

@ -52,6 +52,8 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
parser.add_argument("--attn-res", action="store_true", help="enable Attention Residuals (Block AttnRes)")
parser.add_argument("--attn-res-block-size", type=int, default=4, help="layers per AttnRes block (must be >= 2 and even)")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
@ -137,6 +139,7 @@ def build_model_meta(depth):
sequence_len=args.max_seq_len, vocab_size=vocab_size,
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
window_pattern=args.window_pattern,
attn_res=args.attn_res, attn_res_block_size=args.attn_res_block_size,
)
with torch.device("meta"):
model_meta = GPT(config)

267
tests/test_attn_res.py Normal file
View File

@ -0,0 +1,267 @@
"""
Test Attention Residuals (Block AttnRes) integration.
python -m pytest tests/test_attn_res.py -v
"""
import torch
# Force float32 compute for CPU test stability (avoids bf16 NaN in small models
# and dtype mismatches between model activations and KV cache on CUDA machines)
import nanochat.gpt as _gpt_mod
import nanochat.common as _common_mod
_gpt_mod.COMPUTE_DTYPE = torch.float32
_common_mod.COMPUTE_DTYPE = torch.float32
from nanochat.gpt import GPT, GPTConfig, block_attn_res
def _make_config(attn_res=True, **kwargs):
    """Build a small GPTConfig for tests; kwargs override the defaults."""
    cfg = dict(
        sequence_len=64, vocab_size=256, n_layer=8, n_head=4, n_kv_head=4, n_embd=64,
        window_pattern="L", attn_res=attn_res, attn_res_block_size=4,
    )
    cfg.update(kwargs)
    return GPTConfig(**cfg)
def _build_model(config):
    """Instantiate a GPT on the meta device, materialize on CPU, and init weights."""
    with torch.device("meta"):
        m = GPT(config)
    m.to_empty(device="cpu")
    m.init_weights()
    return m
# ---- Unit tests for building blocks ----
def test_block_attn_res_function():
    """block_attn_res returns a finite tensor with the per-token (B, T, D) shape."""
    batch, seq, dim = 2, 8, 16
    prior = [torch.randn(batch, seq, dim) for _ in range(3)]
    current = torch.randn(batch, seq, dim)
    query = torch.randn(dim)
    out = block_attn_res(prior, current, query)
    assert out.shape == (batch, seq, dim)
    assert torch.isfinite(out).all()
def test_block_attn_res_single_block():
    """With no prior blocks the depth-softmax is a singleton, so output == partial."""
    batch, seq, dim = 1, 4, 8
    current = torch.randn(batch, seq, dim)
    query = torch.randn(dim)
    out = block_attn_res([], current, query)
    # softmax over a single depth entry yields weight 1.0 -> identity on the partial block
    assert torch.allclose(out, current, atol=1e-6)
# ---- Model construction tests ----
def test_model_creates_attn_res_params():
    """attn_res=True exposes both per-layer pseudo-query tensors of shape (n_layer, n_embd)."""
    config = _make_config(attn_res=True)
    model = _build_model(config)
    expected_shape = (config.n_layer, config.n_embd)
    for name in ('attn_res_proj', 'mlp_res_proj'):
        assert hasattr(model, name)
        assert getattr(model, name).shape == expected_shape
def test_model_no_attn_res_params_when_disabled():
    """When attn_res=False, GPT does NOT have AttnRes parameters."""
    config = _make_config(attn_res=False)
    model = _build_model(config)
    # Both pseudo-query tensors are registered only under attn_res=True;
    # the original test forgot to check mlp_res_proj.
    assert not hasattr(model, 'attn_res_proj')
    assert not hasattr(model, 'mlp_res_proj')
def test_attn_res_param_count_overhead():
    """AttnRes adds exactly 2 * n_layer * n_embd params (one attn + one mlp vector per layer)."""
    config_base = _make_config(attn_res=False)
    config_ar = _make_config(attn_res=True)
    model_base = _build_model(config_base)
    model_ar = _build_model(config_ar)
    n_base = sum(p.numel() for p in model_base.parameters())
    n_ar = sum(p.numel() for p in model_ar.parameters())
    # Pin the exact delta instead of just "more params": attn_res_proj and
    # mlp_res_proj are each (n_layer, n_embd). For D=64, n_layer=8: 2*8*64 = 1024.
    assert n_ar - n_base == 2 * config_ar.n_layer * config_ar.n_embd
    overhead = (n_ar - n_base) / n_base
    assert overhead < 0.05, f"AttnRes overhead {overhead:.2%} is too high"
# ---- Forward pass tests ----
def test_forward_training_attn_res():
    """Training-mode forward with attn_res=True yields a finite positive scalar loss."""
    model = _build_model(_make_config(attn_res=True))
    batch, seq = 2, 32
    vocab = model.config.vocab_size
    tokens = torch.randint(0, vocab, (batch, seq))
    labels = torch.randint(0, vocab, (batch, seq))
    loss = model(tokens, targets=labels)
    assert loss.shape == ()
    assert torch.isfinite(loss)
    assert loss.item() > 0
def test_forward_inference_attn_res():
    """Inference-mode forward (no targets) returns finite per-token logits."""
    config = _make_config(attn_res=True)
    model = _build_model(config)
    model.eval()
    batch, seq = 1, 16
    tokens = torch.randint(0, config.vocab_size, (batch, seq))
    out = model(tokens)
    assert out.shape == (batch, seq, config.vocab_size)
    assert torch.isfinite(out).all()
def test_forward_standard_path_unchanged():
    """Standard path (attn_res=False) still works correctly."""
    config = _make_config(attn_res=False)
    model = _build_model(config)
    # Disable backout for small test model (pre-existing NaN with n_embd=64)
    model.backout_lambda.data.fill_(0.0)
    batch, seq = 2, 32
    tokens = torch.randint(0, config.vocab_size, (batch, seq))
    labels = torch.randint(0, config.vocab_size, (batch, seq))
    assert torch.isfinite(model(tokens, targets=labels))
def test_backward_attn_res():
    """Gradients flow through the AttnRes path into both pseudo-query parameters."""
    config = _make_config(attn_res=True)
    model = _build_model(config)
    B, T = 2, 16
    idx = torch.randint(0, config.vocab_size, (B, T))
    targets = torch.randint(0, config.vocab_size, (B, T))
    loss = model(idx, targets=targets)
    loss.backward()
    # The original comment promised "non-zero" gradients but only checked for
    # presence; actually assert real (finite, non-zero) gradient signal.
    for grad in (model.attn_res_proj.grad, model.mlp_res_proj.grad):
        assert grad is not None
        assert torch.isfinite(grad).all()
        assert grad.abs().sum() > 0
# ---- Config validation tests ----
def test_invalid_block_size_odd():
    """Odd block size must be rejected at model construction time."""
    config = _make_config(attn_res=True, attn_res_block_size=3)
    raised = False
    try:
        _build_model(config)
    except AssertionError:
        raised = True
    # BUG FIX: the previous version put `assert False` inside the try block,
    # where its AssertionError was swallowed by the except clause — the test
    # passed vacuously even if the validation was missing.
    assert raised, "Should have raised AssertionError for odd block size"
def test_invalid_block_size_one():
    """block_size=1 must be rejected at model construction time."""
    config = _make_config(attn_res=True, attn_res_block_size=1)
    raised = False
    try:
        _build_model(config)
    except AssertionError:
        raised = True
    # BUG FIX: `assert False` inside the try block was swallowed by the except
    # clause, so the test could never fail. Record the outcome and assert after.
    assert raised, "Should have raised AssertionError for block_size=1"
def test_valid_block_sizes():
    """Every even block size >= 2 builds a model that produces finite logits."""
    for bs in (2, 4, 6, 8):
        model = _build_model(_make_config(attn_res=True, attn_res_block_size=bs, n_layer=8))
        tokens = torch.randint(0, 256, (1, 16))
        logits = model(tokens)
        assert torch.isfinite(logits).all(), f"NaN/Inf with block_size={bs}"
# ---- Optimizer tests ----
def test_optimizer_setup_attn_res():
    """setup_optimizer succeeds when AttnRes params are present."""
    model = _build_model(_make_config(attn_res=True))
    opt = model.setup_optimizer()
    assert opt is not None
def test_optimizer_setup_standard():
    """setup_optimizer succeeds on the standard (non-AttnRes) path too."""
    model = _build_model(_make_config(attn_res=False))
    opt = model.setup_optimizer()
    assert opt is not None
# ---- FLOPs and param counting tests ----
def test_estimate_flops_attn_res():
    """estimate_flops returns a positive count when AttnRes is enabled."""
    model = _build_model(_make_config(attn_res=True))
    assert model.estimate_flops() > 0
def test_num_scaling_params_attn_res():
    """num_scaling_params stays consistent with the true parameter total under AttnRes."""
    model = _build_model(_make_config(attn_res=True))
    counts = model.num_scaling_params()
    true_total = sum(p.numel() for p in model.parameters())
    assert counts['total'] == true_total
# ---- KV cache inference tests ----
def test_kv_cache_inference_attn_res():
    """Prefill plus single-token decode both work with a KV cache under AttnRes."""
    from nanochat.engine import KVCache
    config = _make_config(attn_res=True, n_layer=4, n_embd=64, n_head=4, n_kv_head=4)
    model = _build_model(config)
    model.eval()
    batch, prompt_len = 1, 8
    prompt = torch.randint(0, config.vocab_size, (batch, prompt_len))
    cache = KVCache(
        batch_size=batch, num_heads=config.n_kv_head,
        seq_len=32, head_dim=config.n_embd // config.n_head,
        num_layers=config.n_layer, device="cpu", dtype=torch.float32,
    )
    # Prefill the cache with the whole prompt.
    prefill_logits = model(prompt, kv_cache=cache)
    assert prefill_logits.shape == (batch, prompt_len, config.vocab_size)
    # Decode one more token against the warm cache.
    step = torch.randint(0, config.vocab_size, (batch, 1))
    decode_logits = model(step, kv_cache=cache)
    assert decode_logits.shape == (batch, 1, config.vocab_size)
    assert torch.isfinite(decode_logits).all()
def test_generate_attn_res():
    """Full generation pipeline works with AttnRes."""
    from nanochat.engine import Engine
    # Import only ByteTokenizer: MockModel was imported but never used — the
    # comment below already notes that this test needs the real model.
    from tests.test_engine import ByteTokenizer
    # We need a real model for this test, not mock
    config = _make_config(attn_res=True, n_layer=4, n_embd=64, n_head=4, n_kv_head=4)
    model = _build_model(config)
    model.eval()
    tokenizer = ByteTokenizer()
    # Override model methods that Engine needs
    model.get_device = lambda: torch.device("cpu")
    engine = Engine(model, tokenizer)
    prompt = [0, 72, 101, 108, 108, 111]  # some tokens
    results, masks = engine.generate_batch(prompt, max_tokens=5, temperature=0.0, seed=42)
    assert len(results) == 1
    assert len(results[0]) >= len(prompt)  # at least prompt tokens
    assert len(results[0]) <= len(prompt) + 5  # at most prompt + max_tokens