mirror of https://github.com/karpathy/nanochat.git, synced 2026-01-22 19:34:17 +00:00
Add learnable lambdas that gate the residual connection and a skip connection to the input embeddings, solid bump to val_bpb
parent 2c4473dd1b
commit aa530cdad5

dev/LOG.md · 56 lines changed
@@ -4,6 +4,62 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026

---

## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)

Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.

### Changes Made

**1. x0_lambdas (x0 residual connections)**

- Save initial normalized embedding as `x0` after `norm(wte(idx))`
- At each layer, blend x0 back in: `x = resid_lambdas[i] * x + x0_lambdas[i] * x0`
- Zero-initialized, so disabled at start; model learns which layers benefit from the shortcut
- Provides direct path from embedding to deep layers, helps preserve token information

**2. resid_lambdas (residual stream scaling)**

- Per-layer multiplicative scaling of the residual stream
- Initialized to 1.0 (neutral, standard transformer behavior)
- Allows the model to learn to amplify/dampen the residual at each layer (see the sketch below)
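
For reference, a minimal self-contained sketch of how the two scalars enter the forward pass. The `wte`/`resid_lambdas`/`x0_lambdas` names follow the diff below; `TinyTrunk`, the linear stand-in blocks, and the hand-rolled `norm` are purely illustrative, not the actual model code:

```python
import torch
import torch.nn as nn

def norm(x):
    # stand-in RMS norm
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)

class TinyTrunk(nn.Module):
    def __init__(self, n_layer, n_embd, vocab_size):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.h = nn.ModuleList([nn.Linear(n_embd, n_embd) for _ in range(n_layer)])  # stand-in blocks
        self.resid_lambdas = nn.Parameter(torch.ones(n_layer))   # 1.0 = neutral residual scaling
        self.x0_lambdas = nn.Parameter(torch.zeros(n_layer))     # 0.0 = x0 shortcut disabled at init

    def forward(self, idx):
        x = norm(self.wte(idx))
        x0 = x  # save the initial normalized embedding
        for i, block in enumerate(self.h):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0  # learned blend
            x = x + block(norm(x))  # stand-in for the real attention/MLP block
        return norm(x)
```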

**3. DistAdamW small parameter handling**

- Added support for parameters with < 1024 elements (like the scalar lambdas)
- Small params use `all_reduce` instead of `reduce_scatter`/`all_gather`
- Fixes a crash when the param shape isn't divisible by world_size (see the sketch below)
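
A compressed sketch of the resulting sync decision. The actual `DistAdamW.step` diff is further down this page; `sync_grad` and `SMALL_THRESHOLD` are names invented here for illustration, and the real code issues these collectives asynchronously:

```python
import torch
import torch.distributed as dist

SMALL_THRESHOLD = 1024  # params below this many elements skip the scatter/gather path

def sync_grad(p: torch.Tensor, world_size: int):
    """Return (grad_shard, is_small): small params are averaged in full via all_reduce,
    large params are averaged and sharded across ranks via reduce_scatter."""
    grad = p.grad
    if p.numel() < SMALL_THRESHOLD:
        # every rank keeps the full averaged gradient; no divisibility requirement
        dist.all_reduce(grad, op=dist.ReduceOp.AVG)
        return grad, True
    # large params must split evenly along dim 0 so each rank updates its own slice
    assert p.shape[0] % world_size == 0
    shard = torch.empty_like(grad[: grad.shape[0] // world_size])
    dist.reduce_scatter_tensor(shard, grad, op=dist.ReduceOp.AVG)
    return shard, False
```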

### Key Finding: Different LR Sensitivity

The two scalar types need very different learning rates:

- **x0_lambdas (additive)**: Can use normal LR (~0.5). Adding a fraction of x0 is forgiving.
- **resid_lambdas (multiplicative)**: Needs ~100x smaller LR (~0.005). Multiplying the residual compounds through layers: a uniform scale of 1+ε on the residual grows roughly like (1+ε)^L across L layers.

Implementation: `resid_params` gets `scalar_lr * 0.01`, `x0_params` gets full `scalar_lr`.
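
In terms of AdamW parameter groups, that amounts to roughly the following. This is a sketch only: `Scalars` is a stand-in for the model exposing just the two scalar parameters, while the LR split and the betas/eps/weight-decay values are the ones used in the diff below:

```python
import torch
import torch.nn as nn

# stand-in model exposing only the two per-layer scalar parameters (illustrative)
class Scalars(nn.Module):
    def __init__(self, n_layer=12):
        super().__init__()
        self.resid_lambdas = nn.Parameter(torch.ones(n_layer))
        self.x0_lambdas = nn.Parameter(torch.zeros(n_layer))

model = Scalars()
scalar_lr = 0.5  # the --scalar_lr default
adam_groups = [
    # multiplicative residual scalars: 100x smaller LR, their effect compounds through the stream
    dict(params=[model.resid_lambdas], lr=scalar_lr * 0.01),
    # additive x0 scalars: the full scalar_lr is fine
    dict(params=[model.x0_lambdas], lr=scalar_lr),
]
optimizer = torch.optim.AdamW(adam_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
```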

### Experiment Results

Swept `--scalar_lr` (controlling x0_lambdas) at multiple depths:

| Depth | Baseline (disabled) | Best scalar_lr | Best val_bpb | Δ bpb |
|-------|---------------------|----------------|--------------|-------|
| d8 | 1.0885 | 0.20 | 1.0782 | -0.0103 |
| d12 | 0.9770 | 0.60 | 0.9693 | -0.0077 |
| d16 | 0.9059 | 0.20 | 0.9002 | -0.0057 |
| d20 | 0.8565 | 0.10 | 0.8526 | -0.0039 |

**Observations:**

- Consistent improvement across all model sizes
- Optimal LR varies by depth; default of 0.5 is reasonable, but 0.6 is better for d12
- Adding resid_lambdas (with 0.01x LR) gives small additional improvement over x0 alone

### Meta Device Footgun

Important lesson: `__init__` runs in meta device context, so any tensor values set there are fake. Must initialize actual values in `init_weights()`. Added docstring warning to `__init__`.
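
A minimal repro of the footgun, assuming a recent PyTorch where `torch.device("meta")` works as a context manager and the module is later materialized with `to_empty()`. The `Toy` module is illustrative, not the actual model:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, n_layer=4):
        super().__init__()
        # under a meta device context this only records shape/dtype; the "ones" are fake
        self.resid_lambdas = nn.Parameter(torch.ones(n_layer))

    def init_weights(self):
        # real values must be written after the module is materialized
        with torch.no_grad():
            self.resid_lambdas.fill_(1.0)

with torch.device("meta"):
    m = Toy()
print(m.resid_lambdas.is_meta)  # True: there is no data behind this tensor yet

m = m.to_empty(device="cpu")    # allocate real (uninitialized) storage
m.init_weights()                # now the 1.0 init actually happens
print(m.resid_lambdas)          # tensor of ones, as intended
```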

### Summary

Added `--scalar_lr` (default 0.5) controlling learnable per-layer scalars. The formula `x = resid_lambdas[i] * x + x0_lambdas[i] * x0` gives the model control over residual scaling and direct shortcuts to the initial embedding. Solid improvement with essentially no compute overhead.

---

## 2026-01-10: Muon Optimizer Upgrades & Cautious Weight Decay

Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.

@@ -16,23 +16,31 @@ class DistAdamW(torch.optim.Optimizer):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(param_groups, defaults)

    @torch.compile
    @torch.no_grad()
    def step(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        reduce_scatter_futures: list[torch.Future] = []
        all_reduce_futures: list[torch.Future] = []
        reduce_futures: list[torch.Future] = []
        gather_futures: list[torch.Future] = []
        grad_slices = []
        is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)

        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            for base_i in range(len(params)):
                assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
                grad = params[base_i].grad
                rank_size = grad.shape[0] // world_size
                grad_slice = torch.empty_like(grad[:rank_size])
                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                grad_slices.append(grad_slice)
            for p in params:
                grad = p.grad
                # Small params: use all_reduce (no scatter/gather needed)
                if p.numel() < 1024:
                    is_small.append(True)
                    reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                    grad_slices.append(grad)
                else:
                    is_small.append(False)
                    assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
                    rank_size = grad.shape[0] // world_size
                    grad_slice = torch.empty_like(grad[:rank_size])
                    reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                    grad_slices.append(grad_slice)

        idx = 0
        for group in self.param_groups:

@@ -40,14 +48,19 @@ class DistAdamW(torch.optim.Optimizer):
            eps = group['eps']
            wd = group['weight_decay']
            params = group['params']
            for base in range(len(params)):
                reduce_scatter_futures[idx].wait()
                p = params[base]
                rank_size = p.shape[0] // world_size
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
            for p in params:
                reduce_futures[idx].wait()
                g_slice = grad_slices[idx]
                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
                state = self.state[p]
                g_slice = grad_slices[idx]

                # For small params, operate on full param; for large, operate on slice
                if is_small[idx]:
                    p_slice = p
                else:
                    rank_size = p.shape[0] // world_size
                    p_slice = p[rank * rank_size:(rank + 1) * rank_size]

                # State init
                if not state:
                    state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)

@@ -72,6 +85,11 @@ class DistAdamW(torch.optim.Optimizer):
                step_size = lr / bias1
                update = exp_avg.div(denom).mul_(step_size)
                p_slice.add_(other=update, alpha=-1.0)

                # Only large params need all_gather
                if not is_small[idx]:
                    gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
                idx += 1
                all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
        torch.futures.collect_all(all_reduce_futures).wait()

        if gather_futures:
            torch.futures.collect_all(gather_futures).wait()

@@ -134,6 +134,11 @@ class Block(nn.Module):

class GPT(nn.Module):
    def __init__(self, config, pad_vocab_size_to=64):
        """
        NOTE a major footgun: this __init__ function runs in meta device context (!!)
        Therefore, any calculations inside here are shapes and dtypes only, no actual data.
        => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
        """
        super().__init__()
        self.config = config
        # For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:

@@ -146,6 +151,12 @@ class GPT(nn.Module):
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
        # Per-layer learnable scalars (inspired by modded-nanogpt)
        # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
        # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
        # Separate parameters so they can have different optimizer treatment
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
        # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
        # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
        # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.

@@ -186,6 +197,11 @@ class GPT(nn.Module):
            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
            torch.nn.init.zeros_(block.mlp.c_proj.weight)

        # Per-layer scalars
        with torch.no_grad():
            self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
            self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init

        # Rotary embeddings
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)

@@ -244,21 +260,25 @@ class GPT(nn.Module):
        nparams = sum(p.numel() for p in self.parameters())
        return nparams

    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95)):
    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
        model_dim = self.config.n_embd
        ddp, rank, local_rank, world_size = get_dist_info()
        # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
        # Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas)
        matrix_params = list(self.transformer.h.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
        # Create the AdamW optimizer for the embedding and lm_head
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
        # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
        # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
        adam_groups = [
            dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
            dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
            dict(params=x0_params, lr=scalar_lr),
        ]
        adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
        AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)

@@ -288,7 +308,9 @@ class GPT(nn.Module):
        # Forward the trunk of the Transformer
        x = self.transformer.wte(idx)
        x = norm(x)
        for block in self.transformer.h:
        x0 = x # save initial normalized embedding for x0 residual
        for i, block in enumerate(self.transformer.h):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            x = block(x, cos_sin, kv_cache)
        x = norm(x)

@@ -53,6 +53,7 @@ parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning ra
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight_decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar_lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")

@@ -195,6 +196,7 @@ optimizers = model.setup_optimizers(
    matrix_lr=args.matrix_lr * batch_lr_scale,
    weight_decay=weight_decay_scaled,
    adam_betas=adam_betas,
    scalar_lr=args.scalar_lr * batch_lr_scale,
)
adamw_optimizer, muon_optimizer = optimizers