Merge branch 'master' into master

This commit is contained in:
兰天游 2026-01-12 10:49:55 +08:00 committed by GitHub
commit 8d89db3195
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 586 additions and 198 deletions

View File

@ -4,6 +4,170 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-01-11: Sliding Window Attention
Added configurable sliding window attention, inspired by GPT-3's alternating short/long pattern.
**Pattern string configuration:**
- New `--window_pattern` CLI arg and `GPTConfig.window_pattern` field
- Pattern is tiled across layers (e.g., `SSSL` for 20 layers → `SSSLSSSLSSSLSSSLSSSL`)
- Final layer always forced to L (full context) regardless of pattern
- Short window = `sequence_len // 2`
- Long window = `sequence_len` (full context)
- All previous models so far have been simply `L` and checkpoint loading is modified accordingly to fill in this param for old models, see `_patch_missing_config_keys`
Quick experiments showed `SSSL` (every 4th layer is long) works well - provides a good balance between compute savings and model quality. This is now the default.
---
## 2026-01-11: Flash Attention 3 Integration
Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference.
### Changes Made
**1. FA3 via `kernels` package**
- Official FA3 is "beta" and requires building from source (painful)
- Using `kernels` package from HuggingFace Hub: `get_kernel('varunneal/flash-attention-3')`
- Loads pre-built wheels, works out of the box on H100
**2. Simplified attention code**
- FA3 uses `(B, T, H, D)` layout matching our projection output directly - no transpose needed
- Training: `flash_attn.flash_attn_func(q, k, v, causal=True)`
- Inference: `flash_attn.flash_attn_with_kvcache()` handles all cache cases in one call
- Removed 3 separate FA2 code paths (training, single-token, chunk inference)
- GQA handled automatically when n_kv_heads < n_heads
**3. Rewrote KVCache for FA3**
- Old format: `(num_layers, 2, B, H, T, D)` combined tensor
- New format: separate `k_cache` and `v_cache` of shape `(num_layers, B, T, H, D)`
- FA3 updates cache in-place during `flash_attn_with_kvcache`
- Position tracked via `cache_seqlens` tensor (int32, per batch element)
- Simpler API: `get_layer_cache()`, `advance()`, `reset()`, `prefill()`
### Results
- **~9% improvement in tok/sec** during training out of the box
- Benchmarks showed FA3 is 2x faster than FA2 at realistic training sizes (batch=32, seq=2048)
- FA3 supports sliding window via `window_size=(left, 0)`, which is huge and expected to give further improvements. This is ready to tune but keeping full context for now.
---
## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)
Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.
### Changes Made
**1. x0_lambdas (x0 residual connections)**
- Save initial normalized embedding as `x0` after `norm(wte(idx))`
- At each layer, blend x0 back in: `x = resid_lambdas[i] * x + x0_lambdas[i] * x0`
- Zero-initialized, so disabled at start; model learns which layers benefit from the shortcut
- Provides direct path from embedding to deep layers, helps preserve token information
**2. resid_lambdas (residual stream scaling)**
- Per-layer multiplicative scaling of the residual stream
- Initialized to 1.0 (neutral, standard transformer behavior)
- Allows model to learn to amplify/dampen residual at each layer
**3. DistAdamW small parameter handling**
- Added support for parameters with < 1024 elements (like the scalar lambdas)
- Small params use `all_reduce` instead of `reduce_scatter`/`all_gather`
- Fixes crash when param shape isn't divisible by world_size
### Key Finding: Different LR Sensitivity
The two scalar types need very different learning rates:
- **x0_lambdas (additive)**: Can use normal LR (~0.5). Adding a fraction of x0 is forgiving.
- **resid_lambdas (multiplicative)**: Needs ~100x smaller LR (~0.005). Multiplying the residual compounds through layers.
Implementation: `resid_params` gets `scalar_lr * 0.01`, `x0_params` gets full `scalar_lr`.
### Experiment Results
Swept `--scalar_lr` (controlling x0_lambdas) at multiple depths:
| Depth | Baseline (disabled) | Best scalar_lr | Best val_bpb | Δ bpb |
|-------|---------------------|----------------|--------------|-------|
| d8 | 1.0885 | 0.20 | 1.0782 | -0.0103 |
| d12 | 0.9770 | 0.60 | 0.9693 | -0.0077 |
| d16 | 0.9059 | 0.20 | 0.9002 | -0.0057 |
| d20 | 0.8565 | 0.10 | 0.8526 | -0.0039 |
**Observations:**
- Consistent improvement across all model sizes
- Optimal LR varies by depth; default of 0.5 is reasonable, but 0.6 is better for d12
- Adding resid_lambdas (with 0.01x LR) gives small additional improvement over x0 alone
### Meta Device Footgun
Important lesson: `__init__` runs in meta device context, so any tensor values set there are fake. Must initialize actual values in `init_weights()`. Added docstring warning to `__init__`.
### Summary
Added `--scalar_lr` (default 0.5) controlling learnable per-layer scalars. The formula `x = resid_lambdas[i] * x + x0_lambdas[i] * x0` gives the model control over residual scaling and direct shortcuts to the initial embedding. Solid improvement with essentially no compute overhead.
---
## 2026-01-10: Muon Optimizer Upgrades & Cautious Weight Decay
Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.
### Changes Made
**1. Polar Express Orthogonalization**
- Replaced Newton-Schulz iteration with "Polar Express Sign Method" from [arxiv.org/pdf/2505.16932](https://arxiv.org/pdf/2505.16932)
- Uses 5 different coefficient tuples (one per iteration) instead of fixed coefficients
- Both methods kept in code for easy comparison (`zeropower_via_polar_express` vs `zeropower_via_newtonschulz5`)
- **Result:** No dramatic/noticeable difference in training, but keeping the new Polar Express as default.
**2. Variance Reduction (NorMuon-style)**
- Added low-rank variance estimator similar to Adafactor ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491))
- Maintains `second_momentum_buffer` with shape `[rows, 1]` or `[1, cols]` (whichever is smaller)
- Normalizes updates based on running per-row/col variance estimate (beta2=0.95)
- Memory overhead: ~1/max(rows, cols) per param, negligible
- **Result:** Led to a very small improvement, kept and enabled by default.
**3. Cautious Weight Decay**
- Only decays weights where `update * weight >= 0` (same sign) from [arxiv.org/abs/2411.16085](https://arxiv.org/abs/2411.16085)
- Standard WD always pulls toward zero; cautious WD skips decay when gradient is pushing weight away from zero
- **Implementation note:** Had to inline the logic rather than use a separate `@torch.compile` function. Passing changing float values (like `weight_decay` during scheduling) as function arguments triggers recompilation. Reading from `group["weight_decay"]` inside the step avoids this.
- **Result:** Solid improvements, especially the cautious version was better than standard wd.
- Now defaults to ON for Muon via the `weight_decay` param. AdamW still has no weight decay and is hardcoded to 0 weight decay, might try to re-tune this later.
**4. Weight decay schedule**
- Added a linear schedule to weight decay that is default on from 1.0 to 0.0 (i.e. start with max weight decay in the beginning of training, them ramp to 0 by the end). Worked better than a static setting in experiments. (modded-nanogpt has the same schedule but it is imlpemented in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule).
### Weight Decay Scaling Experiments
Swept weight decay values at d8, d12, d16, d20 to find optimal values and scaling law.
**Optimal Values Found:**
| Depth | Width (channels) | Optimal WD |
|-------|------------------|------------|
| d8 | 512 | ~0.40 |
| d12 | 768 | ~0.22 |
| d16 | 1024 | ~0.10 |
| d20 | 1280 | ~0.08 |
**Scaling Law:**
- Fit power law: `WD = k / channels^α` in log-log space
- Found α ≈ 1.97 (approximately 2), meaning WD ∝ 1/width²
**Practical Formula:**
```
WD_target = WD_reference × (d_reference / d_target)²
```
Example: If d12 optimal is 0.22, then d20 optimal ≈ 0.22 × (12/20)² ≈ 0.08
**Reference:** Moonlight paper uses fixed WD=0.1 for their 15B MoE model. Our experiments indicated a scaling law where the optimal WD changed with depth, so we go along with the empirical scaling law.
### Summary
Muon was changed to use Polar Express, added Adafactor-style variance reduction, and cautious weight decay with schedule that ramps linearly to zero. All of these changes follow modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat with the exception of the Polar Express change which was in the noise. This is default on and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The kwarg `--weight_decay` is therefore changing as of this change. It used to configure AdamW via standard weight decay and now it becomes exclusively used in Muon (AdamW is hardcoded to 0.0), and it is scaled based on depth.
---
## 2026-01-08: exp_grad_clip - Gradient Clipping
**Hypothesis:** Gradient clipping may be unnecessary overhead. Tested L2 norm clipping at various thresholds (0.25, 0.5, 1.0, 2.0) and elementwise clipping.
@ -18,6 +182,4 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
**Observartion:** modded-nanogpt does not appear to clip either right now.
**Recommendation:** Disable by default (`--grad_clip=0.0`). The code naturally produces well-behaved gradients.
---
**Summary:** Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This improves a bit of MFU because we don't have to calculate and sync grad norms.

View File

@ -16,23 +16,31 @@ class DistAdamW(torch.optim.Optimizer):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(param_groups, defaults)
@torch.compile
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
reduce_scatter_futures: list[torch.Future] = []
all_reduce_futures: list[torch.Future] = []
reduce_futures: list[torch.Future] = []
gather_futures: list[torch.Future] = []
grad_slices = []
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
for group in self.param_groups:
params: list[Tensor] = group["params"]
for base_i in range(len(params)):
assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
grad = params[base_i].grad
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad_slice)
for p in params:
grad = p.grad
# Small params: use all_reduce (no scatter/gather needed)
if p.numel() < 1024:
is_small.append(True)
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad)
else:
is_small.append(False)
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad_slice)
idx = 0
for group in self.param_groups:
@ -40,14 +48,19 @@ class DistAdamW(torch.optim.Optimizer):
eps = group['eps']
wd = group['weight_decay']
params = group['params']
for base in range(len(params)):
reduce_scatter_futures[idx].wait()
p = params[base]
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
for p in params:
reduce_futures[idx].wait()
g_slice = grad_slices[idx]
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
state = self.state[p]
g_slice = grad_slices[idx]
# For small params, operate on full param; for large, operate on slice
if is_small[idx]:
p_slice = p
else:
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
# State init
if not state:
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
@ -72,6 +85,11 @@ class DistAdamW(torch.optim.Optimizer):
step_size = lr / bias1
update = exp_avg.div(denom).mul_(step_size)
p_slice.add_(other=update, alpha=-1.0)
# Only large params need all_gather
if not is_small[idx]:
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
idx += 1
all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
torch.futures.collect_all(all_reduce_futures).wait()
if gather_futures:
torch.futures.collect_all(gather_futures).wait()

View File

@ -20,6 +20,22 @@ def log0(message):
if int(os.environ.get('RANK', 0)) == 0:
logger.info(message)
def _patch_missing_config_keys(model_config_kwargs):
"""Add default values for new config keys missing in old checkpoints."""
# Old models were trained with full context (no sliding window)
if "window_pattern" not in model_config_kwargs:
model_config_kwargs["window_pattern"] = "L"
def _patch_missing_keys(model_data, model_config):
"""Add default values for new parameters that may be missing in old checkpoints."""
n_layer = model_config.n_layer
# resid_lambdas defaults to 1.0 (identity scaling)
if "resid_lambdas" not in model_data:
model_data["resid_lambdas"] = torch.ones(n_layer)
# x0_lambdas defaults to 0.0 (disabled)
if "x0_lambdas" not in model_data:
model_data["x0_lambdas"] = torch.zeros(n_layer)
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
@ -74,8 +90,10 @@ def build_model(checkpoint_dir, step, device, phase):
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
model_config_kwargs = meta_data["model_config"]
_patch_missing_config_keys(model_config_kwargs)
log0(f"Building model with config: {model_config_kwargs}")
model_config = GPTConfig(**model_config_kwargs)
_patch_missing_keys(model_data, model_config)
with torch.device("meta"):
model = GPT(model_config)
# Load the model state

View File

@ -82,83 +82,54 @@ def use_calculator(expr):
# -----------------------------------------------------------------------------
class KVCache:
"""
Works hand-in-hand with the GPT model to maintain the KV cache.
Note that the .pos advances automatically after the last layer of the Transformer inserts.
KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
Key differences from FA2-style cache:
- Tensors are (B, T, H, D) not (B, H, T, D)
- FA3 updates the cache in-place during flash_attn_with_kvcache
- Position tracked per batch element via cache_seqlens tensor
"""
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
self.kv_cache = None
self.pos = 0 # current position in time in the cache
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
self.batch_size = batch_size
self.max_seq_len = seq_len
self.n_layers = num_layers
self.n_heads = num_heads
self.head_dim = head_dim
# Pre-allocate cache tensors: (n_layers, B, T, H, D)
self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
# Current sequence length per batch element (FA3 needs int32)
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
def reset(self):
self.pos = 0
"""Reset cache to empty state."""
self.cache_seqlens.zero_()
def get_pos(self):
return self.pos
"""Get current position (assumes all batch elements at same position)."""
return self.cache_seqlens[0].item()
def get_layer_cache(self, layer_idx):
"""Return (k_cache, v_cache) views for a specific layer."""
return self.k_cache[layer_idx], self.v_cache[layer_idx]
def advance(self, num_tokens):
"""Advance the cache position by num_tokens."""
self.cache_seqlens += num_tokens
def prefill(self, other):
"""
Prefill given another KV cache. Optionally expand along batch dim.
This is used when we do batch 1 prefill and then want to generate
multiple samples in parallel from there.
Copy cached KV from another cache into this one.
Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
"""
# 1) validate the shapes
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
# Extract dimensions explicitly
self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
# Validate dimensions
assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
# Batch size can be expanded (other can be 1, self can be larger)
assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
# Sequence length: self must be longer than other
assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
# 2) initialize the cache
dtype, device = other.kv_cache.dtype, other.kv_cache.device
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
# 3) copy the data over
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
# 4) update the pos
self.pos = other.pos
def insert_kv(self, layer_idx, k, v):
# Lazy initialize the cache here because we need to know the dtype/device
if self.kv_cache is None:
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
# Insert new keys/values to the cache and return the full cache so far
B, H, T_add, D = k.size()
t0, t1 = self.pos, self.pos + T_add
# Dynamically grow the cache if needed
if t1 > self.kv_cache.size(4):
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
additional_shape = list(self.kv_cache.shape)
additional_shape[4] = t_needed - self.kv_cache.size(4)
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
self.kv_shape = self.kv_cache.shape
# Insert k, v into the cache
self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
# Return the full cached keys/values up to current position (as a view)
key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
# Increment pos after the last layer of the Transformer processes
if layer_idx == self.kv_cache.size(0) - 1:
self.pos = t1
return key_view, value_view
assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
assert self.max_seq_len >= other.max_seq_len
other_pos = other.get_pos()
self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
self.cache_seqlens.fill_(other_pos)
# -----------------------------------------------------------------------------
@torch.inference_mode()
@ -219,6 +190,7 @@ class Engine:
kv_cache_prefill = KVCache(
batch_size=1,
seq_len=len(tokens),
device=device,
**kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
@ -230,6 +202,7 @@ class Engine:
kv_cache_decode = KVCache(
batch_size=num_samples,
seq_len=kv_length_hint,
device=device,
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)

View File

@ -9,10 +9,9 @@ Notable features:
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
- Canon layers
- Flash Attention 3 integration
"""
import math
from functools import partial
from dataclasses import dataclass
@ -24,6 +23,14 @@ from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
from kernels import get_kernel
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
@dataclass
class GPTConfig:
sequence_len: int = 1024
@ -32,6 +39,10 @@ class GPTConfig:
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768
# Sliding window attention pattern string, tiled across layers. Final layer always L.
# Characters: L=long (full context), S=short (half context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "L"
def norm(x):
@ -62,48 +73,42 @@ class CausalSelfAttention(nn.Module):
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache):
def forward(self, x, cos_sin, window_size, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
# Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
# During inference but with a single query in this forward pass:
# The query has to attend to all the keys/values in the cache
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Attention with Flash Attention 3
# FA3 handles GQA automatically when n_kv_heads < n_heads
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# Inference: use flash_attn_with_kvcache which handles cache management
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
y = flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache,
k=k, v=v,
cache_seqlens=kv_cache.cache_seqlens,
causal=True,
window_size=window_size,
)
# Advance position after last layer processes
if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T)
# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1)
# Re-assemble the heads and project back to residual stream
y = y.contiguous().view(B, T, -1)
y = self.c_proj(y)
return y
@ -186,18 +191,32 @@ class Block(nn.Module):
class GPT(nn.Module):
def __init__(self, config, pad_vocab_size_to=64):
"""
NOTE a major footgun: this __init__ function runs in meta device context (!!)
Therefore, any calculations inside here are shapes and dtypes only, no actual data.
=> We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
"""
super().__init__()
self.config = config
# For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
# Compute per-layer window sizes for sliding window attention
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
self.window_sizes = self._compute_window_sizes(config)
# Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
if padded_vocab_size != config.vocab_size:
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}")
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
# Per-layer learnable scalars (inspired by modded-nanogpt)
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
# Separate parameters so they can have different optimizer treatment
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@ -243,6 +262,11 @@ class GPT(nn.Module):
torch.nn.init.uniform_(block.canon_a.conv.weight, -s, s)
torch.nn.init.uniform_(block.canon_c.conv.weight, -s, s)
# Per-layer scalars
with torch.no_grad():
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
@ -269,6 +293,35 @@ class GPT(nn.Module):
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
def _compute_window_sizes(self, config):
"""
Compute per-layer window sizes for sliding window attention.
Returns list of (left, right) tuples for FA3's window_size parameter:
- left: how many tokens before current position to attend to (-1 = unlimited)
- right: how many tokens after current position to attend to (0 for causal)
Pattern string is tiled across layers. Final layer always gets L (full context).
Characters: L=long (full context), S=short (half context)
"""
pattern = config.window_pattern.upper()
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
# Map characters to window sizes
long_window = config.sequence_len
short_window = long_window // 2
char_to_window = {
"L": (long_window, 0),
"S": (short_window, 0),
}
# Tile pattern across layers
window_sizes = []
for layer_idx in range(config.n_layer):
char = pattern[layer_idx % len(pattern)]
window_sizes.append(char_to_window[char])
# Final layer always gets full context
window_sizes[-1] = (long_window, 0)
return window_sizes
def get_device(self):
return self.transformer.wte.weight.device
@ -277,16 +330,24 @@ class GPT(nn.Module):
Return the estimated FLOPs per token for the model (forward + backward).
Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
On top of that, the term 12 * l * h * q * t accounts for key @ query matmul flops inside attention.
On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
With sliding windows, effective_seq_len varies per layer (capped by window size).
Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
- Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
"""
nparams = sum(p.numel() for p in self.parameters())
nparams_embedding = self.transformer.wte.weight.numel()
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
# Exclude non-matmul params: embeddings and per-layer scalars
nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel()
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
for window_size in self.window_sizes:
window = window_size[0] # (left, right) tuple, we use left
effective_seq = t if window < 0 else min(window, t)
attn_flops += 12 * h * q * effective_seq
num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
return num_flops_per_token
def num_scaling_params(self):
@ -301,27 +362,31 @@ class GPT(nn.Module):
nparams = sum(p.numel() for p in self.parameters())
return nparams
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95)):
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
# Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas)
matrix_params = list(self.transformer.h.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
# Create the AdamW optimizer for the embedding and lm_head
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
dict(params=x0_params, lr=scalar_lr),
]
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=weight_decay)
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
# Create the Muon optimizer for the linear layers
muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
MuonFactory = DistMuon if ddp else Muon
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
# Combine them the two optimizers into one list
@ -345,8 +410,10 @@ class GPT(nn.Module):
# Forward the trunk of the Transformer
x = self.transformer.wte(idx)
x = norm(x)
for block in self.transformer.h:
x = block(x, cos_sin, kv_cache)
x0 = x # save initial normalized embedding for x0 residual
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
x = block(x, cos_sin, self.window_sizes[i], kv_cache)
x = norm(x)
# Forward the lm_head (compute logits)

View File

@ -1,11 +1,50 @@
"""
Muon optimizer from Keller et al.
Also a lot of borrowing of ideas from modded-nanogpt.
Muon optimizer adapted (simplified) from modded-nanogpt.
https://github.com/KellerJordan/modded-nanogpt
"""
import torch
from torch import Tensor
import torch.distributed as dist
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
polar_express_coeffs = [
(8.156554524902461, -22.48329292557795, 15.878769915207462),
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
@torch.compile
def zeropower_via_polar_express(G: Tensor, steps: int = 5) -> Tensor:
"""
Polar Express Sign Method for orthogonalization.
https://arxiv.org/pdf/2505.16932
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
Alternative to Newton-Schulz iteration with potentially better convergence properties.
"""
assert G.ndim >= 2
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
# Ensure spectral norm is at most 1 (with 2% safety factor)
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
# Perform the iterations (cap at available coefficients)
for a, b, c in polar_express_coeffs[:min(steps, len(polar_express_coeffs))]:
A = X @ X.mT
B = b * A + c * (A @ A)
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
@ -35,6 +74,40 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
X = X.mT
return X
@torch.compile
def apply_variance_reduction(v: Tensor, second_momentum_buffer: Tensor, beta2: float) -> Tensor:
"""
NorMuon-style variance reduction, similar to Adafactor's low-rank variance estimator.
https://arxiv.org/pdf/2510.05491
Normalizes updates based on a running estimate of per-row (or per-column) variance.
The reduction dimension is determined by the shape of second_momentum_buffer.
"""
# Determine reduction dimension from buffer shape
red_dim = -1 if second_momentum_buffer.size(-1) == 1 else -2
# Compute per-row/col mean of squared values
v_mean = v.float().square().mean(dim=red_dim, keepdim=True)
red_dim_size = v.size(red_dim)
# Compute current norm
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
v_norm = v_norm_sq.sqrt()
# Update second momentum buffer (EMA of variance)
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
# Compute scaling factor from second momentum
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
# Final scale preserves overall norm while adjusting per-row/col
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
return v.mul(final_scale.to(v.dtype))
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
@ -56,9 +129,11 @@ class Muon(torch.optim.Optimizer):
momentum: The momentum used by the internal SGD.
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iteration steps to use.
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
"""
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, beta2=0.95, weight_decay=0.0):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
params: list[Tensor] = [*params]
param_groups = []
for size in {p.numel() for p in params}:
@ -79,13 +154,29 @@ class Muon(torch.optim.Optimizer):
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
# Variance reduction (NorMuon-style)
if group["beta2"] is not None:
if "second_momentum_buffer" not in state:
# Buffer shape determines reduction dim: reduce along larger dimension
if p.size(-2) >= p.size(-1):
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
else:
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
# Parameter update with cautious weight decay
effective_lr = group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5
wd = group["weight_decay"]
if wd != 0:
mask = (g * p) >= 0
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
else:
p.sub_(effective_lr * g)
class DistMuon(torch.optim.Optimizer):
"""
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via NewtonSchulz,
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Polar Express,
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
- reduce_scatter(AVG) for gradient averaging
- all_gather to replicate updated weights
@ -102,11 +193,13 @@ class DistMuon(torch.optim.Optimizer):
lr: learning rate
momentum: momentum coefficient in [0,1)
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
ns_steps: number of NewtonSchulz iterations for the orthogonalization
ns_steps: number of Newton-Schulz iterations for the orthogonalization
beta2: decay rate for second moment (variance) estimate. Set to None to disable.
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
"""
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
nesterov: bool = True, ns_steps: int = 5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
nesterov: bool = True, ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
params = list(params)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
rank = dist.get_rank()
@ -173,9 +266,24 @@ class DistMuon(torch.optim.Optimizer):
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1.0 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
p.add_(g, alpha=-group["lr"] * scale)
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
# Variance reduction (NorMuon-style)
if group["beta2"] is not None:
if "second_momentum_buffer" not in state:
# Buffer shape determines reduction dim: reduce along larger dimension
if p.size(-2) >= p.size(-1):
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
else:
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
# Parameter update with cautious weight decay
effective_lr = group["lr"] * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
wd = group["weight_decay"]
if wd != 0:
mask = (g * p) >= 0
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
else:
p.sub_(effective_lr * g)
# Replicate updated parameters to all ranks
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
ag_output = params[base_i:base_i + world_size]

View File

@ -8,6 +8,7 @@ dependencies = [
"datasets>=4.0.0",
"fastapi>=0.117.1",
"ipykernel>=7.1.0",
"kernels>=0.11.7",
"matplotlib>=3.10.8",
"psutil>=7.1.0",
"python-dotenv>=1.2.1",

View File

@ -42,6 +42,7 @@ parser.add_argument("--depth", type=int, default=20, help="depth of the Transfor
parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
parser.add_argument("--window_pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
@ -51,8 +52,9 @@ parser.add_argument("--device_batch_size", type=int, default=32, help="per-devic
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--weight_decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar_lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
@ -129,11 +131,16 @@ if batch_ratio != 1.0:
batch_lr_scale = batch_ratio ** 0.5
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
if args.depth != 12:
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
# -----------------------------------------------------------------------------
# Initialize the Model
# Create a new model with random weights
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern)
with torch.device("meta"):
# All tensors are created as meta tensors (they have shape/dtype but no data)
model_config = GPTConfig(**model_config_kwargs)
@ -188,8 +195,9 @@ optimizers = model.setup_optimizers(
unembedding_lr=args.unembedding_lr * batch_lr_scale,
embedding_lr=args.embedding_lr * batch_lr_scale,
matrix_lr=args.matrix_lr * batch_lr_scale,
weight_decay=args.weight_decay,
weight_decay=weight_decay_scaled,
adam_betas=adam_betas,
scalar_lr=args.scalar_lr * batch_lr_scale,
)
adamw_optimizer, muon_optimizer = optimizers
@ -227,6 +235,10 @@ def get_muon_momentum(it):
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
# Weight decay scheduler for Muon optimizer (linear to zero over the course of training)
def get_weight_decay(it):
return weight_decay_scaled * (1 - it / num_iterations)
# -----------------------------------------------------------------------------
# Loop state (variables updated by the training loop)
@ -257,7 +269,7 @@ while True:
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
wandb_run.log({
@ -351,8 +363,10 @@ while True:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
muon_weight_decay = get_weight_decay(step)
for group in muon_optimizer.param_groups:
group["momentum"] = muon_momentum
group["weight_decay"] = muon_weight_decay
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
@ -402,7 +416,7 @@ while True:
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
if val_bpb is not None:
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
print0(f"Minimum validation bpb: {min_val_bpb:.6f}")
# Log to report
from nanochat.report import get_report

View File

@ -39,13 +39,9 @@ class MockModel:
def forward(self, ids, kv_cache=None):
"""Return uniform logits so sampling is spread across vocab."""
B, T = ids.shape
# Simulate what a real transformer does: insert k,v into the cache for each layer
# With FA3, flash_attn_with_kvcache updates cache in-place and we advance position
if kv_cache is not None:
head_dim = self.config.n_embd // self.config.n_head
for layer_idx in range(self.config.n_layer):
k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
kv_cache.insert_kv(layer_idx, k, v)
kv_cache.advance(T)
# Uniform logits -> equal probability for all tokens
logits = torch.zeros(B, T, self.vocab_size)
return logits
@ -85,16 +81,11 @@ class ByteTokenizer:
byte_tokens = [t for t in tokens if t < 256]
return bytes(byte_tokens).decode("utf-8", errors="replace")
def test_kv_cache_resize():
"""
The KV cache was not resized correctly, more information here:
https://github.com/karpathy/nanochat/pull/186
This test reproduces the issue and will be merged alongside the fix.
"""
def test_kv_cache_basic():
"""Test basic KVCache functionality for FA3."""
batch_size = 2
num_heads = 3
seq_len = 4
seq_len = 64
head_dim = 5
num_layers = 6
@ -103,45 +94,64 @@ def test_kv_cache_resize():
num_heads=num_heads,
seq_len=seq_len,
head_dim=head_dim,
num_layers=num_layers
num_layers=num_layers,
device="cpu",
)
# Insert a single token with a distinct fill value to all layers
def insert_token(token_idx):
for layer_idx in range(num_layers):
k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
kv_cache.insert_kv(layer_idx, k, v)
# Check initial state
assert kv_cache.get_pos() == 0
assert kv_cache.k_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
assert kv_cache.v_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
# Insert 4 tokens (fills the initial seq_len=4)
for i in range(4):
insert_token(i)
# Test advance
kv_cache.advance(10)
assert kv_cache.get_pos() == 10
# Record the original state of the cache
original_cache = kv_cache.kv_cache.clone()
original_seq_len = original_cache.shape[4]
kv_cache.advance(5)
assert kv_cache.get_pos() == 15
# Insert the 5th token, which will trigger a resize
insert_token(4)
# Verify that the cache actually resized
new_seq_len = kv_cache.kv_cache.shape[4]
assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
# Test reset
kv_cache.reset()
assert kv_cache.get_pos() == 0
# Verify that the original 4 tokens are still intact after resize
for layer_idx in range(num_layers):
for token_idx in range(4):
# Check that resized cache matches expected values
expected_k = float(token_idx)
expected_v = float(token_idx * 100)
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
# And that the original cache matches resized cache
original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
# Test get_layer_cache returns correct views
k_layer0, v_layer0 = kv_cache.get_layer_cache(0)
assert k_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
assert v_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
def test_kv_cache_prefill():
"""Test KVCache.prefill() copies data correctly."""
batch_size = 1
num_heads = 4
head_dim = 8
num_layers = 2
# Create source cache and advance it
src_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=32,
head_dim=head_dim, num_layers=num_layers, device="cpu",
)
# Write some data to source cache
src_cache.k_cache[0, 0, :16, :, :] = 1.0
src_cache.v_cache[0, 0, :16, :, :] = 2.0
src_cache.advance(16)
# Create destination cache with larger seq_len
dst_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=64,
head_dim=head_dim, num_layers=num_layers, device="cpu",
)
# Prefill
dst_cache.prefill(src_cache)
# Check position was copied
assert dst_cache.get_pos() == 16
# Check data was copied
assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all()
assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all()
def test_multi_sample_first_token_diversity():

17
uv.lock
View File

@ -1089,6 +1089,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" },
]
[[package]]
name = "kernels"
version = "0.11.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "packaging" },
{ name = "pyyaml" },
{ name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d6/c8/2d4fea16366d34069af6d4c4f61218f55e5d0daea5d4c24d58849e9fd626/kernels-0.11.7.tar.gz", hash = "sha256:99c3aa518965518902f4dc26053d6051f06abc904ae33d9486c28674a2ea0fa5", size = 50282, upload-time = "2026-01-08T15:41:57.383Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ab/49/e62183353374ec71306ef354781233ac8d12fdfd1cf3d47c875055a99603/kernels-0.11.7-py3-none-any.whl", hash = "sha256:1421791b1e501fcb0a7f0a4d763c5385591756d9d6ed12ed8baa1e0d71bcd21a", size = 46501, upload-time = "2026-01-08T15:41:55.784Z" },
]
[[package]]
name = "kiwisolver"
version = "1.4.9"
@ -1478,6 +1493,7 @@ dependencies = [
{ name = "datasets" },
{ name = "fastapi" },
{ name = "ipykernel" },
{ name = "kernels" },
{ name = "matplotlib" },
{ name = "psutil" },
{ name = "python-dotenv" },
@ -1518,6 +1534,7 @@ requires-dist = [
{ name = "datasets", specifier = ">=4.0.0" },
{ name = "fastapi", specifier = ">=0.117.1" },
{ name = "ipykernel", specifier = ">=7.1.0" },
{ name = "kernels", specifier = ">=0.11.7" },
{ name = "matplotlib", specifier = ">=3.10.8" },
{ name = "psutil", specifier = ">=7.1.0" },
{ name = "python-dotenv", specifier = ">=1.2.1" },