This commit is contained in:
Node 2026-03-25 13:21:35 -07:00 committed by GitHub
commit 31f1f0e04b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 114 additions and 7 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 130 KiB

View File

@ -0,0 +1,42 @@
# Attention Residuals local d12 results
This change adds a gated AttnRes path on top of the current nanochat residual path instead of replacing it.
Key model change:
- Baseline path stays `base = resid_lambdas[i] * x + x0_lambdas[i] * x0`
- AttnRes path uses the same `base`, then applies a zero-initialized correction `base + alpha * (depth - base)`, where `depth` is the attention-weighted mixture over earlier residual states
- `alpha` starts at `0`, so AttnRes is exactly equal to baseline at initialization
Local experiment setup:
- `depth=12`
- `aspect_ratio=32`
- `head_dim=64`
- `max_seq_len=256`
- `device_batch_size=16`
- `total_batch_size=4096`
- `device_type=mps`
- `window_pattern=L`
## 0-500 steps
![d12 AttnRes vs baseline, 0-500](attention_residuals_d12_0_500.png)
Checkpoint values:
- `100`: base `1.979283`, attnres `1.960904`
- `200`: base `1.855803`, attnres `1.844442`
- `300`: base `1.786364`, attnres `1.770760`
- `400`: base `1.735797`, attnres `1.720742`
- `500`: base `1.714498`, attnres `1.701136`
## 500-1000 steps
This figure is a checkpoint continuation from step `500` to step `1000`. It is not identical to a fresh `0-1000` run planned from the start, because the training horizon was extended at resume time.
![d12 AttnRes vs baseline, 500-1000](attention_residuals_d12_500_1000.png)
Checkpoint values:
- `600`: base `1.689065`, attnres `1.681095`
- `700`: base `1.652204`, attnres `1.648315`
- `800`: base `1.627924`, attnres `1.622582`
- `900`: base `1.609432`, attnres `1.602892`
- `1000`: base `1.602437`, attnres `1.597053`

View File

@ -38,6 +38,12 @@ def _patch_missing_keys(model_data, model_config):
if "x0_lambdas" not in model_data:
model_data["x0_lambdas"] = torch.zeros(n_layer)
log0(f"Patching missing x0_lambdas in model data to 0.0")
if model_config.residual_mode == "attnres_block" and "attnres_queries" not in model_data:
model_data["attnres_queries"] = torch.zeros(n_layer, model_config.n_embd)
log0(f"Patching missing attnres_queries in model data to 0.0")
if model_config.residual_mode == "attnres_block" and "attnres_alphas" not in model_data:
model_data["attnres_alphas"] = torch.zeros(n_layer)
log0(f"Patching missing attnres_alphas in model data to 0.0")
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:

View File

@ -37,11 +37,25 @@ class GPTConfig:
# Characters: L=long (full context), S=short (quarter context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "SSSL"
residual_mode: str = "baseline"
attnres_block_size: int = 0
def norm(x):
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
def attnres_mix(sources, query):
    """Attention Residuals over depth using a single learned pseudo-query.

    Mixes the list of residual states `sources` (each (b, t, d)) into a single
    (b, t, d) tensor via softmax attention along the depth axis, using one
    shared pseudo-query vector `query` of size d.
    """
    # Degenerate case: a single source needs no mixing.
    if len(sources) == 1:
        return sources[0]
    depth_stack = torch.stack(sources)  # (n, b, t, d)
    normed_keys = norm(depth_stack)
    # Normalize the pseudo-query and match the stacked dtype (e.g. bf16).
    q = F.rms_norm(query.to(dtype=depth_stack.dtype), (query.numel(),))
    # Scaled dot-product scores per depth position: (n, b, t).
    scale = depth_stack.size(-1) ** -0.5
    scores = torch.einsum("d, n b t d -> n b t", q, normed_keys) * scale
    probs = torch.softmax(scores, dim=0)
    # Weighted sum over the depth axis back to (b, t, d).
    return torch.einsum("n b t, n b t d -> b t d", probs, depth_stack)
class Linear(nn.Linear):
"""nn.Linear that casts weights to match input dtype in forward.
Replaces autocast: master weights stay fp32 for optimizer precision,
@ -160,6 +174,7 @@ class GPT(nn.Module):
"""
super().__init__()
self.config = config
assert config.residual_mode in {"baseline", "attnres_block"}, f"Invalid residual_mode: {config.residual_mode}"
# Compute per-layer window sizes for sliding window attention
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
self.window_sizes = self._compute_window_sizes(config)
@ -188,6 +203,12 @@ class GPT(nn.Module):
head_dim = config.n_embd // config.n_head
kv_dim = config.n_kv_head * head_dim
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
self.attnres_queries = None
self.attnres_alphas = None
if config.residual_mode == "attnres_block":
self.attnres_queries = nn.Parameter(torch.zeros(config.n_layer, config.n_embd))
self.attnres_alphas = nn.Parameter(torch.zeros(config.n_layer))
self.attnres_block_size = config.attnres_block_size if config.attnres_block_size > 0 else (config.n_layer + 7) // 8
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@ -241,6 +262,10 @@ class GPT(nn.Module):
# Value embeddings (init like c_v: uniform with same std)
for ve in self.value_embeds.values():
torch.nn.init.uniform_(ve.weight, -s, s)
if self.attnres_queries is not None:
torch.nn.init.zeros_(self.attnres_queries)
if self.attnres_alphas is not None:
torch.nn.init.zeros_(self.attnres_alphas)
# Gate weights init with small positive values so gates start slightly above neutral
for block in self.transformer.h:
@ -326,7 +351,9 @@ class GPT(nn.Module):
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() +
(0 if self.attnres_queries is None else self.attnres_queries.numel()) +
(0 if self.attnres_alphas is None else self.attnres_alphas.numel()))
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
@ -354,7 +381,9 @@ class GPT(nn.Module):
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
attnres = 0 if self.attnres_queries is None else self.attnres_queries.numel()
attnres_alphas = 0 if self.attnres_alphas is None else self.attnres_alphas.numel()
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + attnres + attnres_alphas
total = wte + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
@ -366,7 +395,7 @@ class GPT(nn.Module):
'total': total,
}
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5, attnres_query_lr_mult=1.0):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
@ -378,7 +407,9 @@ class GPT(nn.Module):
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
attnres_params = [] if self.attnres_queries is None else [self.attnres_queries]
attnres_alpha_params = [] if self.attnres_alphas is None else [self.attnres_alphas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) + len(attnres_params) + len(attnres_alpha_params)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -394,6 +425,16 @@ class GPT(nn.Module):
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
]
if attnres_alpha_params:
param_groups.append(dict(
kind='adamw', params=attnres_alpha_params, lr=scalar_lr,
betas=(0.9, 0.95), eps=1e-10, weight_decay=0.0,
))
if attnres_params:
param_groups.append(dict(
kind='adamw', params=attnres_params, lr=scalar_lr * 0.1 * attnres_query_lr_mult,
betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0,
))
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):
group_params = [p for p in matrix_params if p.shape == shape]
@ -448,10 +489,23 @@ class GPT(nn.Module):
n_layer = self.config.n_layer
backout_layer = n_layer // 2 # cache at halfway point
x_backout = None
use_attnres = self.config.residual_mode == "attnres_block"
completed_blocks = [x] if use_attnres else None
partial_block = x if use_attnres else None
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
base = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
if use_attnres:
block_start = (i % self.attnres_block_size) == 0
sources = completed_blocks + [base] if block_start else completed_blocks + [partial_block, base]
depth = attnres_mix(sources, self.attnres_queries[i])
alpha = torch.tanh(self.attnres_alphas[i]).to(dtype=x.dtype)
x = block(base + alpha * (depth - base), ve, cos_sin, self.window_sizes[i], kv_cache)
partial_block = x
if (i + 1) % self.attnres_block_size == 0 or i == n_layer - 1:
completed_blocks.append(partial_block)
else:
x = block(base, ve, cos_sin, self.window_sizes[i], kv_cache)
if i == backout_layer:
x_backout = x
# Subtract mid-layer residual to remove low-level features before logit projection

View File

@ -52,6 +52,9 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
parser.add_argument("--residual-mode", type=str, default="baseline", choices=["baseline", "attnres_block"], help="residual stream implementation")
parser.add_argument("--attnres-block-size", type=int, default=0, help="AttnRes block size in transformer blocks (0 = auto, roughly 8 blocks total)")
parser.add_argument("--attnres-query-lr-mult", type=float, default=1.0, help="LR multiplier for AttnRes pseudo-query vectors")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
@ -136,7 +139,8 @@ def build_model_meta(depth):
config = GPTConfig(
sequence_len=args.max_seq_len, vocab_size=vocab_size,
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
window_pattern=args.window_pattern,
window_pattern=args.window_pattern, residual_mode=args.residual_mode,
attnres_block_size=args.attnres_block_size,
)
with torch.device("meta"):
model_meta = GPT(config)
@ -309,6 +313,7 @@ optimizer = model.setup_optimizer(
unembedding_lr=args.unembedding_lr * batch_lr_scale,
embedding_lr=args.embedding_lr * batch_lr_scale,
scalar_lr=args.scalar_lr * batch_lr_scale,
attnres_query_lr_mult=args.attnres_query_lr_mult,
# Muon hyperparameters
matrix_lr=args.matrix_lr * batch_lr_scale,
weight_decay=weight_decay_scaled,