diff --git a/dev/LOG.md b/dev/LOG.md
index 13fc08e..ee1e82e 100644
--- a/dev/LOG.md
+++ b/dev/LOG.md
@@ -4,6 +4,62 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
 
 ---
 
+## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)
+
+Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.
+
+### Changes Made
+
+**1. x0_lambdas (x0 residual connections)**
+- Save initial normalized embedding as `x0` after `norm(wte(idx))`
+- At each layer, blend x0 back in: `x = resid_lambdas[i] * x + x0_lambdas[i] * x0`
+- Zero-initialized, so disabled at start; model learns which layers benefit from the shortcut
+- Provides direct path from embedding to deep layers, helps preserve token information
+
+**2. resid_lambdas (residual stream scaling)**
+- Per-layer multiplicative scaling of the residual stream
+- Initialized to 1.0 (neutral, standard transformer behavior)
+- Allows model to learn to amplify/dampen residual at each layer
+
+**3. DistAdamW small parameter handling**
+- Added support for parameters with < 1024 elements (like the scalar lambdas)
+- Small params use `all_reduce` instead of `reduce_scatter`/`all_gather`
+- Fixes crash when param shape isn't divisible by world_size
+
+### Key Finding: Different LR Sensitivity
+
+The two scalar types need very different learning rates:
+- **x0_lambdas (additive)**: Can use normal LR (~0.5). Adding a fraction of x0 is forgiving.
+- **resid_lambdas (multiplicative)**: Needs ~100x smaller LR (~0.005). Multiplying the residual compounds through layers.
+
+Implementation: `resid_params` gets `scalar_lr * 0.01`, `x0_params` gets full `scalar_lr`.
+
+### Experiment Results
+
+Swept `--scalar_lr` (controlling x0_lambdas) at multiple depths:
+
+| Depth | Baseline (disabled) | Best scalar_lr | Best val_bpb | Δ bpb   |
+|-------|---------------------|----------------|--------------|---------|
+| d8    | 1.0885              | 0.20           | 1.0782       | -0.0103 |
+| d12   | 0.9770              | 0.60           | 0.9693       | -0.0077 |
+| d16   | 0.9059              | 0.20           | 0.9002       | -0.0057 |
+| d20   | 0.8565              | 0.10           | 0.8526       | -0.0039 |
+
+**Observations:**
+- Consistent improvement across all model sizes
+- Optimal LR varies by depth; default of 0.5 is reasonable, but 0.6 is better for d12
+- Adding resid_lambdas (with 0.01x LR) gives small additional improvement over x0 alone
+
+### Meta Device Footgun
+
+Important lesson: `__init__` runs in meta device context, so any tensor values set there are fake. Must initialize actual values in `init_weights()`. Added docstring warning to `__init__`.
+
+### Summary
+
+Added `--scalar_lr` (default 0.5) controlling learnable per-layer scalars. The formula `x = resid_lambdas[i] * x + x0_lambdas[i] * x0` gives the model control over residual scaling and direct shortcuts to the initial embedding. Solid improvement with essentially no compute overhead.
+
+---
+
 ## 2026-01-10: Muon Optimizer Upgrades & Cautious Weight Decay
 
 Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.
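The LR-sensitivity finding in the log entry above (additive `x0_lambdas` tolerate an LR of ~0.5, multiplicative `resid_lambdas` want ~0.005) is easy to sanity-check numerically. A standalone sketch, not part of the patch, with purely illustrative numbers:

```python
# Illustration only: why a uniform drift in the multiplicative resid_lambdas is far
# more dangerous than the same drift in the additive x0_lambdas.
n_layer = 20
eps = 0.05  # suppose every per-layer scalar moves by this much early in training

# resid_lambdas = 1 + eps at every layer: the residual stream is rescaled by
# (1 + eps) per layer, so its scale grows roughly exponentially with depth.
resid_scale = (1 + eps) ** n_layer
print(f"residual stream scale after {n_layer} layers: {resid_scale:.2f}x")  # ~2.65x

# x0_lambdas = eps at every layer: each layer adds eps * x0 to the stream, so the
# injected x0 contribution only grows linearly with depth.
x0_coeff = n_layer * eps
print(f"accumulated x0 coefficient: {x0_coeff:.2f}")  # 1.00
```

This is the intuition behind giving `resid_params` a 100x smaller learning rate in `setup_optimizers` (see the gpt.py diff below).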
diff --git a/nanochat/adamw.py b/nanochat/adamw.py
index 0b97ae2..48945b3 100644
--- a/nanochat/adamw.py
+++ b/nanochat/adamw.py
@@ -16,23 +16,31 @@ class DistAdamW(torch.optim.Optimizer):
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super().__init__(param_groups, defaults)
 
-    @torch.compile
     @torch.no_grad()
     def step(self):
         rank = dist.get_rank()
         world_size = dist.get_world_size()
-        reduce_scatter_futures: list[torch.Future] = []
-        all_reduce_futures: list[torch.Future] = []
+        reduce_futures: list[torch.Future] = []
+        gather_futures: list[torch.Future] = []
         grad_slices = []
+        is_small = []  # track which params are small (use all_reduce) vs large (use reduce_scatter)
+
         for group in self.param_groups:
             params: list[Tensor] = group["params"]
-            for base_i in range(len(params)):
-                assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
-                grad = params[base_i].grad
-                rank_size = grad.shape[0] // world_size
-                grad_slice = torch.empty_like(grad[:rank_size])
-                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
-                grad_slices.append(grad_slice)
+            for p in params:
+                grad = p.grad
+                # Small params: use all_reduce (no scatter/gather needed)
+                if p.numel() < 1024:
+                    is_small.append(True)
+                    reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                    grad_slices.append(grad)
+                else:
+                    is_small.append(False)
+                    assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
+                    rank_size = grad.shape[0] // world_size
+                    grad_slice = torch.empty_like(grad[:rank_size])
+                    reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                    grad_slices.append(grad_slice)
 
         idx = 0
         for group in self.param_groups:
@@ -40,14 +48,19 @@ class DistAdamW(torch.optim.Optimizer):
             eps = group['eps']
             wd = group['weight_decay']
             params = group['params']
-            for base in range(len(params)):
-                reduce_scatter_futures[idx].wait()
-                p = params[base]
-                rank_size = p.shape[0] // world_size
-                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+            for p in params:
+                reduce_futures[idx].wait()
+                g_slice = grad_slices[idx]
                 lr = group['lr'] * getattr(p, "lr_mul", 1.0)
                 state = self.state[p]
-                g_slice = grad_slices[idx]
+
+                # For small params, operate on full param; for large, operate on slice
+                if is_small[idx]:
+                    p_slice = p
+                else:
+                    rank_size = p.shape[0] // world_size
+                    p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+
                 # State init
                 if not state:
                     state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
@@ -72,6 +85,11 @@ class DistAdamW(torch.optim.Optimizer):
                 step_size = lr / bias1
                 update = exp_avg.div(denom).mul_(step_size)
                 p_slice.add_(other=update, alpha=-1.0)
+
+                # Only large params need all_gather
+                if not is_small[idx]:
+                    gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
                 idx += 1
-                all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
-        torch.futures.collect_all(all_reduce_futures).wait()
+
+        if gather_futures:
+            torch.futures.collect_all(gather_futures).wait()
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 2ffdc50..6f4556a 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -134,6 +134,11 @@ class Block(nn.Module):
 class GPT(nn.Module):
 
     def __init__(self, config, pad_vocab_size_to=64):
+        """
+        NOTE a major footgun: this __init__ function runs in meta device context (!!)
+        Therefore, any calculations inside here are shapes and dtypes only, no actual data.
+        => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
+        """
         super().__init__()
         self.config = config
         # For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
@@ -146,6 +151,12 @@
             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
         })
         self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
+        # Per-layer learnable scalars (inspired by modded-nanogpt)
+        # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
+        # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
+        # Separate parameters so they can have different optimizer treatment
+        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
+        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
         # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
         # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
         # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@@ -186,6 +197,11 @@
             torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
             torch.nn.init.zeros_(block.mlp.c_proj.weight)
 
+        # Per-layer scalars
+        with torch.no_grad():
+            self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
+            self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
+
         # Rotary embeddings
         head_dim = self.config.n_embd // self.config.n_head
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
@@ -244,21 +260,25 @@
         nparams = sum(p.numel() for p in self.parameters())
         return nparams
 
-    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95)):
+    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
         model_dim = self.config.n_embd
         ddp, rank, local_rank, world_size = get_dist_info()
-        # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
+        # Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas)
         matrix_params = list(self.transformer.h.parameters())
         embedding_params = list(self.transformer.wte.parameters())
         lm_head_params = list(self.lm_head.parameters())
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
-        # Create the AdamW optimizer for the embedding and lm_head
+        resid_params = [self.resid_lambdas]
+        x0_params = [self.x0_lambdas]
+        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
+        # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
         # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
         print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
         adam_groups = [
             dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
             dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
+            dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
+            dict(params=x0_params, lr=scalar_lr),
         ]
         adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
         AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
@@ -288,7 +308,9 @@
         # Forward the trunk of the Transformer
         x = self.transformer.wte(idx)
         x = norm(x)
-        for block in self.transformer.h:
+        x0 = x # save initial normalized embedding for x0 residual
+        for i, block in enumerate(self.transformer.h):
+            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
             x = block(x, cos_sin, kv_cache)
         x = norm(x)
 
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 84d44bf..3327451 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -53,6 +53,7 @@ parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning ra
 parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
 parser.add_argument("--weight_decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
 parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--scalar_lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
 parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
 parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
 parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
@@ -195,6 +196,7 @@ optimizers = model.setup_optimizers(
     matrix_lr=args.matrix_lr * batch_lr_scale,
     weight_decay=weight_decay_scaled,
     adam_betas=adam_betas,
+    scalar_lr=args.scalar_lr * batch_lr_scale,
 )
 adamw_optimizer, muon_optimizer = optimizers
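Not part of the patch, but a convenient follow-up once a run finishes: the learned scalars are just two length-`n_layer` vectors, so they are cheap to inspect. A minimal sketch, assuming a trained `GPT` instance is already loaded as `model` (the variable name is hypothetical):

```python
import torch

# Hypothetical inspection snippet: `model` is assumed to be a trained GPT instance
# carrying the resid_lambdas / x0_lambdas parameters introduced in gpt.py above.
with torch.no_grad():
    resid = model.resid_lambdas.float().cpu().tolist()
    x0 = model.x0_lambdas.float().cpu().tolist()
for i, (r, x) in enumerate(zip(resid, x0)):
    print(f"layer {i:2d}: resid_lambda = {r:+.3f}   x0_lambda = {x:+.3f}")
```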