mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
Merge 9d62998d3a into c0dbf1f3ff
This commit is contained in:
commit
31f1f0e04b
BIN
dev/attention_residuals_d12_0_500.png
Normal file
BIN
dev/attention_residuals_d12_0_500.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 84 KiB |
BIN
dev/attention_residuals_d12_500_1000.png
Normal file
BIN
dev/attention_residuals_d12_500_1000.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 130 KiB |
42
dev/attention_residuals_local_results.md
Normal file
42
dev/attention_residuals_local_results.md
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# Attention Residuals local d12 results
|
||||
|
||||
This change adds a gated AttnRes path on top of the existing nanochat residual path, rather than replacing it.
|
||||
|
||||
Key model change:
|
||||
- Baseline path stays `base = resid_lambdas[i] * x + x0_lambdas[i] * x0`
|
||||
- AttnRes path uses the same `base`, then applies a zero-init correction `base + alpha * (depth - base)`
|
||||
- `alpha` starts at `0`, so AttnRes is exactly equal to baseline at initialization
|
||||
|
||||
Local experiment setup:
|
||||
- `depth=12`
|
||||
- `aspect_ratio=32`
|
||||
- `head_dim=64`
|
||||
- `max_seq_len=256`
|
||||
- `device_batch_size=16`
|
||||
- `total_batch_size=4096`
|
||||
- `device_type=mps`
|
||||
- `window_pattern=L`
|
||||
|
||||
## 0-500 steps
|
||||
|
||||

|
||||
|
||||
Checkpoint values:
|
||||
- `100`: base `1.979283`, attnres `1.960904`
|
||||
- `200`: base `1.855803`, attnres `1.844442`
|
||||
- `300`: base `1.786364`, attnres `1.770760`
|
||||
- `400`: base `1.735797`, attnres `1.720742`
|
||||
- `500`: base `1.714498`, attnres `1.701136`
|
||||
|
||||
## 500-1000 steps
|
||||
|
||||
This figure shows a checkpoint continuation from step `500` to step `1000`. It is not equivalent to a fresh `0-1000` run planned from the start, because the training horizon was extended at resume time.
|
||||
|
||||

|
||||
|
||||
Checkpoint values:
|
||||
- `600`: base `1.689065`, attnres `1.681095`
|
||||
- `700`: base `1.652204`, attnres `1.648315`
|
||||
- `800`: base `1.627924`, attnres `1.622582`
|
||||
- `900`: base `1.609432`, attnres `1.602892`
|
||||
- `1000`: base `1.602437`, attnres `1.597053`
|
||||
|
|
@ -38,6 +38,12 @@ def _patch_missing_keys(model_data, model_config):
|
|||
if "x0_lambdas" not in model_data:
|
||||
model_data["x0_lambdas"] = torch.zeros(n_layer)
|
||||
log0(f"Patching missing x0_lambdas in model data to 0.0")
|
||||
if model_config.residual_mode == "attnres_block" and "attnres_queries" not in model_data:
|
||||
model_data["attnres_queries"] = torch.zeros(n_layer, model_config.n_embd)
|
||||
log0(f"Patching missing attnres_queries in model data to 0.0")
|
||||
if model_config.residual_mode == "attnres_block" and "attnres_alphas" not in model_data:
|
||||
model_data["attnres_alphas"] = torch.zeros(n_layer)
|
||||
log0(f"Patching missing attnres_alphas in model data to 0.0")
|
||||
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
if rank == 0:
|
||||
|
|
|
|||
|
|
@ -37,11 +37,25 @@ class GPTConfig:
|
|||
# Characters: L=long (full context), S=short (quarter context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
residual_mode: str = "baseline"
|
||||
attnres_block_size: int = 0
|
||||
|
||||
|
||||
def norm(x):
|
||||
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
|
||||
|
||||
def attnres_mix(sources, query):
    """Attention Residuals over depth using a single learned pseudo-query.

    Softmax-mixes the stacked depth ``sources`` (each presumably of shape
    (b, t, d) — the einsum fixes the rank) with one shared query vector.
    With a single source this is a no-op.
    """
    if len(sources) == 1:
        # Nothing to mix over — return the lone source unchanged.
        return sources[0]
    depth_stack = torch.stack(sources, dim=0)  # (n, b, t, d)
    # RMS-normalize both sides of the dot product before scoring.
    normed_keys = norm(depth_stack)
    normed_query = F.rms_norm(query.to(dtype=depth_stack.dtype), (query.numel(),))
    # Scaled dot-product scores over the depth axis, one per (b, t) position.
    scale = depth_stack.size(-1) ** -0.5
    scores = torch.einsum("d, n b t d -> n b t", normed_query, normed_keys) * scale
    mix_weights = scores.softmax(dim=0)
    return torch.einsum("n b t, n b t d -> b t d", mix_weights, depth_stack)
|
||||
|
||||
class Linear(nn.Linear):
|
||||
"""nn.Linear that casts weights to match input dtype in forward.
|
||||
Replaces autocast: master weights stay fp32 for optimizer precision,
|
||||
|
|
@ -160,6 +174,7 @@ class GPT(nn.Module):
|
|||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert config.residual_mode in {"baseline", "attnres_block"}, f"Invalid residual_mode: {config.residual_mode}"
|
||||
# Compute per-layer window sizes for sliding window attention
|
||||
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
|
||||
self.window_sizes = self._compute_window_sizes(config)
|
||||
|
|
@ -188,6 +203,12 @@ class GPT(nn.Module):
|
|||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
|
||||
self.attnres_queries = None
|
||||
self.attnres_alphas = None
|
||||
if config.residual_mode == "attnres_block":
|
||||
self.attnres_queries = nn.Parameter(torch.zeros(config.n_layer, config.n_embd))
|
||||
self.attnres_alphas = nn.Parameter(torch.zeros(config.n_layer))
|
||||
self.attnres_block_size = config.attnres_block_size if config.attnres_block_size > 0 else (config.n_layer + 7) // 8
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||
|
|
@ -241,6 +262,10 @@ class GPT(nn.Module):
|
|||
# Value embeddings (init like c_v: uniform with same std)
|
||||
for ve in self.value_embeds.values():
|
||||
torch.nn.init.uniform_(ve.weight, -s, s)
|
||||
if self.attnres_queries is not None:
|
||||
torch.nn.init.zeros_(self.attnres_queries)
|
||||
if self.attnres_alphas is not None:
|
||||
torch.nn.init.zeros_(self.attnres_alphas)
|
||||
|
||||
# Gate weights init with small positive values so gates start slightly above neutral
|
||||
for block in self.transformer.h:
|
||||
|
|
@ -326,7 +351,9 @@ class GPT(nn.Module):
|
|||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
|
||||
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
|
||||
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() +
|
||||
(0 if self.attnres_queries is None else self.attnres_queries.numel()) +
|
||||
(0 if self.attnres_alphas is None else self.attnres_alphas.numel()))
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
|
|
@ -354,7 +381,9 @@ class GPT(nn.Module):
|
|||
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
||||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
|
||||
attnres = 0 if self.attnres_queries is None else self.attnres_queries.numel()
|
||||
attnres_alphas = 0 if self.attnres_alphas is None else self.attnres_alphas.numel()
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + attnres + attnres_alphas
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
return {
|
||||
|
|
@ -366,7 +395,7 @@ class GPT(nn.Module):
|
|||
'total': total,
|
||||
}
|
||||
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5, attnres_query_lr_mult=1.0):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
|
||||
|
|
@ -378,7 +407,9 @@ class GPT(nn.Module):
|
|||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
|
||||
attnres_params = [] if self.attnres_queries is None else [self.attnres_queries]
|
||||
attnres_alpha_params = [] if self.attnres_alphas is None else [self.attnres_alphas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) + len(attnres_params) + len(attnres_alpha_params)
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
|
|
@ -394,6 +425,16 @@ class GPT(nn.Module):
|
|||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
|
||||
]
|
||||
if attnres_alpha_params:
|
||||
param_groups.append(dict(
|
||||
kind='adamw', params=attnres_alpha_params, lr=scalar_lr,
|
||||
betas=(0.9, 0.95), eps=1e-10, weight_decay=0.0,
|
||||
))
|
||||
if attnres_params:
|
||||
param_groups.append(dict(
|
||||
kind='adamw', params=attnres_params, lr=scalar_lr * 0.1 * attnres_query_lr_mult,
|
||||
betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0,
|
||||
))
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
group_params = [p for p in matrix_params if p.shape == shape]
|
||||
|
|
@ -448,10 +489,23 @@ class GPT(nn.Module):
|
|||
n_layer = self.config.n_layer
|
||||
backout_layer = n_layer // 2 # cache at halfway point
|
||||
x_backout = None
|
||||
use_attnres = self.config.residual_mode == "attnres_block"
|
||||
completed_blocks = [x] if use_attnres else None
|
||||
partial_block = x if use_attnres else None
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
base = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
if use_attnres:
|
||||
block_start = (i % self.attnres_block_size) == 0
|
||||
sources = completed_blocks + [base] if block_start else completed_blocks + [partial_block, base]
|
||||
depth = attnres_mix(sources, self.attnres_queries[i])
|
||||
alpha = torch.tanh(self.attnres_alphas[i]).to(dtype=x.dtype)
|
||||
x = block(base + alpha * (depth - base), ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
partial_block = x
|
||||
if (i + 1) % self.attnres_block_size == 0 or i == n_layer - 1:
|
||||
completed_blocks.append(partial_block)
|
||||
else:
|
||||
x = block(base, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
if i == backout_layer:
|
||||
x_backout = x
|
||||
# Subtract mid-layer residual to remove low-level features before logit projection
|
||||
|
|
|
|||
|
|
@ -52,6 +52,9 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
|
|||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
parser.add_argument("--residual-mode", type=str, default="baseline", choices=["baseline", "attnres_block"], help="residual stream implementation")
|
||||
parser.add_argument("--attnres-block-size", type=int, default=0, help="AttnRes block size in transformer blocks (0 = auto, roughly 8 blocks total)")
|
||||
parser.add_argument("--attnres-query-lr-mult", type=float, default=1.0, help="LR multiplier for AttnRes pseudo-query vectors")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
|
|
@ -136,7 +139,8 @@ def build_model_meta(depth):
|
|||
config = GPTConfig(
|
||||
sequence_len=args.max_seq_len, vocab_size=vocab_size,
|
||||
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
||||
window_pattern=args.window_pattern,
|
||||
window_pattern=args.window_pattern, residual_mode=args.residual_mode,
|
||||
attnres_block_size=args.attnres_block_size,
|
||||
)
|
||||
with torch.device("meta"):
|
||||
model_meta = GPT(config)
|
||||
|
|
@ -309,6 +313,7 @@ optimizer = model.setup_optimizer(
|
|||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
attnres_query_lr_mult=args.attnres_query_lr_mult,
|
||||
# Muon hyperparameters
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
weight_decay=weight_decay_scaled,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user