diff --git a/dev/attention_residuals_d12_0_500.png b/dev/attention_residuals_d12_0_500.png new file mode 100644 index 0000000..84b3271 Binary files /dev/null and b/dev/attention_residuals_d12_0_500.png differ diff --git a/dev/attention_residuals_d12_500_1000.png b/dev/attention_residuals_d12_500_1000.png new file mode 100644 index 0000000..9d5510f Binary files /dev/null and b/dev/attention_residuals_d12_500_1000.png differ diff --git a/dev/attention_residuals_local_results.md b/dev/attention_residuals_local_results.md new file mode 100644 index 0000000..4ccf69f --- /dev/null +++ b/dev/attention_residuals_local_results.md @@ -0,0 +1,42 @@ +# Attention Residuals local d12 results + +This change adds a gated AttnRes path on top of the current nanochat residual path instead of replacing it. + +Key model change: +- Baseline path stays `base = resid_lambdas[i] * x + x0_lambdas[i] * x0` +- AttnRes path uses the same `base`, then applies a zero-init correction `base + alpha * (depth - base)` +- `alpha` starts at `0`, so AttnRes is exactly equal to baseline at initialization + +Local experiment setup: +- `depth=12` +- `aspect_ratio=32` +- `head_dim=64` +- `max_seq_len=256` +- `device_batch_size=16` +- `total_batch_size=4096` +- `device_type=mps` +- `window_pattern=L` + +## 0-500 steps + +![d12 AttnRes vs baseline, 0-500](attention_residuals_d12_0_500.png) + +Checkpoint values: +- `100`: base `1.979283`, attnres `1.960904` +- `200`: base `1.855803`, attnres `1.844442` +- `300`: base `1.786364`, attnres `1.770760` +- `400`: base `1.735797`, attnres `1.720742` +- `500`: base `1.714498`, attnres `1.701136` + +## 500-1000 steps + +This figure is a checkpoint continuation from step `500` to step `1000`. It is not identical to a fresh `0-1000` run planned from the start, because the training horizon was extended at resume time. + +![d12 AttnRes vs baseline, 500-1000](attention_residuals_d12_500_1000.png) + +Checkpoint values: +- `600`: base `1.689065`, attnres `1.681095` +- `700`: base `1.652204`, attnres `1.648315` +- `800`: base `1.627924`, attnres `1.622582` +- `900`: base `1.609432`, attnres `1.602892` +- `1000`: base `1.602437`, attnres `1.597053` diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f71524e..3c79e32 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -38,6 +38,12 @@ def _patch_missing_keys(model_data, model_config): if "x0_lambdas" not in model_data: model_data["x0_lambdas"] = torch.zeros(n_layer) log0(f"Patching missing x0_lambdas in model data to 0.0") + if model_config.residual_mode == "attnres_block" and "attnres_queries" not in model_data: + model_data["attnres_queries"] = torch.zeros(n_layer, model_config.n_embd) + log0(f"Patching missing attnres_queries in model data to 0.0") + if model_config.residual_mode == "attnres_block" and "attnres_alphas" not in model_data: + model_data["attnres_alphas"] = torch.zeros(n_layer) + log0(f"Patching missing attnres_alphas in model data to 0.0") def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0): if rank == 0: diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 0b822e4..0d299a5 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -37,11 +37,25 @@ class GPTConfig: # Characters: L=long (full context), S=short (quarter context) # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long window_pattern: str = "SSSL" + residual_mode: str = "baseline" + attnres_block_size: int = 0 def norm(x): return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok +def attnres_mix(sources, query): + """Attention Residuals over depth using a single learned pseudo-query.""" + if len(sources) == 1: + return sources[0] + stacked = torch.stack(sources, dim=0) + keys = norm(stacked) + query = F.rms_norm(query.to(dtype=stacked.dtype), (query.numel(),)) + logits = torch.einsum("d, n b t d -> n b t", query, keys) + logits = logits * (stacked.size(-1) ** -0.5) + weights = logits.softmax(0) + return torch.einsum("n b t, n b t d -> b t d", weights, stacked) + class Linear(nn.Linear): """nn.Linear that casts weights to match input dtype in forward. Replaces autocast: master weights stay fp32 for optimizer precision, @@ -160,6 +174,7 @@ class GPT(nn.Module): """ super().__init__() self.config = config + assert config.residual_mode in {"baseline", "attnres_block"}, f"Invalid residual_mode: {config.residual_mode}" # Compute per-layer window sizes for sliding window attention # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window self.window_sizes = self._compute_window_sizes(config) @@ -188,6 +203,12 @@ class GPT(nn.Module): head_dim = config.n_embd // config.n_head kv_dim = config.n_kv_head * head_dim self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)}) + self.attnres_queries = None + self.attnres_alphas = None + if config.residual_mode == "attnres_block": + self.attnres_queries = nn.Parameter(torch.zeros(config.n_layer, config.n_embd)) + self.attnres_alphas = nn.Parameter(torch.zeros(config.n_layer)) + self.attnres_block_size = config.attnres_block_size if config.attnres_block_size > 0 else (config.n_layer + 7) // 8 # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them by 10X, but assert fail if we ever reach that amount. @@ -241,6 +262,10 @@ class GPT(nn.Module): # Value embeddings (init like c_v: uniform with same std) for ve in self.value_embeds.values(): torch.nn.init.uniform_(ve.weight, -s, s) + if self.attnres_queries is not None: + torch.nn.init.zeros_(self.attnres_queries) + if self.attnres_alphas is not None: + torch.nn.init.zeros_(self.attnres_alphas) # Gate weights init with small positive values so gates start slightly above neutral for block in self.transformer.h: @@ -326,7 +351,9 @@ class GPT(nn.Module): value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + self.resid_lambdas.numel() + self.x0_lambdas.numel() + - self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()) + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + + (0 if self.attnres_queries is None else self.attnres_queries.numel()) + + (0 if self.attnres_alphas is None else self.attnres_alphas.numel())) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -354,7 +381,9 @@ class GPT(nn.Module): value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters()) transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) - scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + attnres = 0 if self.attnres_queries is None else self.attnres_queries.numel() + attnres_alphas = 0 if self.attnres_alphas is None else self.attnres_alphas.numel() + scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + attnres + attnres_alphas total = wte + value_embeds + lm_head + transformer_matrices + scalars assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" return { @@ -366,7 +395,7 @@ class GPT(nn.Module): 'total': total, } - def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5): + def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5, attnres_query_lr_mult=1.0): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() @@ -378,7 +407,9 @@ class GPT(nn.Module): resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) + attnres_params = [] if self.attnres_queries is None else [self.attnres_queries] + attnres_alpha_params = [] if self.attnres_alphas is None else [self.attnres_alphas] + assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) + len(attnres_params) + len(attnres_alpha_params) # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -394,6 +425,16 @@ class GPT(nn.Module): dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0), ] + if attnres_alpha_params: + param_groups.append(dict( + kind='adamw', params=attnres_alpha_params, lr=scalar_lr, + betas=(0.9, 0.95), eps=1e-10, weight_decay=0.0, + )) + if attnres_params: + param_groups.append(dict( + kind='adamw', params=attnres_params, lr=scalar_lr * 0.1 * attnres_query_lr_mult, + betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0, + )) # Muon groups (matrix params, grouped by shape for stacking) for shape in sorted({p.shape for p in matrix_params}): group_params = [p for p in matrix_params if p.shape == shape] @@ -448,10 +489,23 @@ class GPT(nn.Module): n_layer = self.config.n_layer backout_layer = n_layer // 2 # cache at halfway point x_backout = None + use_attnres = self.config.residual_mode == "attnres_block" + completed_blocks = [x] if use_attnres else None + partial_block = x if use_attnres else None for i, block in enumerate(self.transformer.h): - x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None - x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) + base = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + if use_attnres: + block_start = (i % self.attnres_block_size) == 0 + sources = completed_blocks + [base] if block_start else completed_blocks + [partial_block, base] + depth = attnres_mix(sources, self.attnres_queries[i]) + alpha = torch.tanh(self.attnres_alphas[i]).to(dtype=x.dtype) + x = block(base + alpha * (depth - base), ve, cos_sin, self.window_sizes[i], kv_cache) + partial_block = x + if (i + 1) % self.attnres_block_size == 0 or i == n_layer - 1: + completed_blocks.append(partial_block) + else: + x = block(base, ve, cos_sin, self.window_sizes[i], kv_cache) if i == backout_layer: x_backout = x # Subtract mid-layer residual to remove low-level features before logit projection diff --git a/scripts/base_train.py b/scripts/base_train.py index c7683c9..40c4e24 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -52,6 +52,9 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") +parser.add_argument("--residual-mode", type=str, default="baseline", choices=["baseline", "attnres_block"], help="residual stream implementation") +parser.add_argument("--attnres-block-size", type=int, default=0, help="AttnRes block size in transformer blocks (0 = auto, roughly 8 blocks total)") +parser.add_argument("--attnres-query-lr-mult", type=float, default=1.0, help="LR multiplier for AttnRes pseudo-query vectors") # Training horizon (only one used, in order of precedence) parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") @@ -136,7 +139,8 @@ def build_model_meta(depth): config = GPTConfig( sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, - window_pattern=args.window_pattern, + window_pattern=args.window_pattern, residual_mode=args.residual_mode, + attnres_block_size=args.attnres_block_size, ) with torch.device("meta"): model_meta = GPT(config) @@ -309,6 +313,7 @@ optimizer = model.setup_optimizer( unembedding_lr=args.unembedding_lr * batch_lr_scale, embedding_lr=args.embedding_lr * batch_lr_scale, scalar_lr=args.scalar_lr * batch_lr_scale, + attnres_query_lr_mult=args.attnres_query_lr_mult, # Muon hyperparameters matrix_lr=args.matrix_lr * batch_lr_scale, weight_decay=weight_decay_scaled,