mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-10 01:40:17 +00:00
Add minimal bigram speedrun recipe
This commit is contained in:
parent
dc54a1a307
commit
aab331dfd4
66
dev/bigram_speedrun_results.md
Normal file
66
dev/bigram_speedrun_results.md
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
# Bigram Speedrun Verification Notes
|
||||
|
||||
This branch is based on upstream nanochat master at `dc54a1a` and keeps the
|
||||
submission implementation focused on the winning recipe:
|
||||
|
||||
- per-layer hashed bigram residual embeddings
|
||||
- Muon+ post-orthogonalization normalization
|
||||
- row equilibration before Muon orthogonalization
|
||||
- lower scalar LR (`--scalar-lr=0.3`)
|
||||
- batched training logging (`--train-log-every=50`)
|
||||
- `torch.compile(..., mode="max-autotune-no-cudagraphs")` for the speedrun script
|
||||
|
||||
It intentionally excludes the experimental branches that were not part of the
|
||||
final candidate: sparse layers, MoE/TOP losses, train-time logit bias losses,
|
||||
post-hoc fitting, NorMuon, and checkpoint merging.
|
||||
|
||||
## Reproduction Sanity Check
|
||||
|
||||
The minimal PR branch, in a short d4/20-step sanity run, matched the prior experimental branch:
|
||||
|
||||
| Run | Step 0 BPB | Step 10 BPB | Final BPB |
|
||||
| --- | ---: | ---: | ---: |
|
||||
| Prior candidate branch | `3.237224` | `3.234722` | `3.223259` |
|
||||
| Minimal PR branch | `3.237224` | `3.234722` | `3.223286` |
|
||||
|
||||
The final difference is `0.000027` BPB on a tiny run, consistent with small
|
||||
compile/graph differences after removing unused experimental code.
|
||||
|
||||
## Full d16 Verification
|
||||
|
||||
Both runs used d16, FP8, target param/data ratio 8, total batch `524288`, and
|
||||
device batch `32` on the same machine.
|
||||
|
||||
| Run | Final BPB | Train time | Avg logged tok/s, excluding first | Avg logged step time, excluding first |
|
||||
| --- | ---: | ---: | ---: | ---: |
|
||||
| Upstream master dense | `0.800673` | `94.64m` | `329,904` | `1589.232ms` |
|
||||
| Bigram/Muon+ candidate | `0.798000` | `93.61m` | `333,507` | `1572.058ms` |
|
||||
|
||||
Candidate delta versus upstream master dense:
|
||||
|
||||
- BPB: `-0.002673`
|
||||
- train time: `-1.03m` (`1.09%` faster)
|
||||
- logged throughput: `+3,603 tok/s` (`1.09%` higher)
|
||||
|
||||
Important caveat: this is a full recipe comparison, not an architecture-only
|
||||
comparison. The candidate also uses `--train-log-every=50` and
|
||||
`--compile-mode=max-autotune-no-cudagraphs`, while upstream master logs every
|
||||
step and uses the default compile mode.
|
||||
|
||||
## Compile Mode Probe
|
||||
|
||||
Short d16/40 throughput probes on the minimal branch:
|
||||
|
||||
| Compile mode | Avg logged tok/s, excluding first | Avg logged step time, excluding first | Total time |
|
||||
| --- | ---: | ---: | ---: |
|
||||
| default `torch.compile` | `324,995` | `1613.250ms` | `0.78m` |
|
||||
| `max-autotune-no-cudagraphs` | `333,261` | `1573.250ms` | `0.76m` |
|
||||
|
||||
On this d16 probe, `max-autotune-no-cudagraphs` was about `2.5%` faster than
|
||||
the default compile mode. The speedrun script keeps this compile mode for that
|
||||
reason.
|
||||
|
||||
## Test Status
|
||||
|
||||
- `python -m pytest tests/test_engine.py -q`: `9 passed`
|
||||
- `python -m py_compile nanochat/gpt.py nanochat/optim.py scripts/base_train.py nanochat/engine.py`: passed
|
||||
|
|
@ -102,11 +102,14 @@ class KVCache:
|
|||
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
# Previous token's normalized embedding for smear (set by model forward pass)
|
||||
self.prev_embedding = None
|
||||
# Previous token id for hashed bigram embeddings (set by model forward pass)
|
||||
self.prev_token = None
|
||||
|
||||
def reset(self):
    """Return the cache to an empty state so it can be reused for a new sequence."""
    # Drop the carried-over smear/bigram state from the previous forward pass.
    self.prev_embedding = None
    self.prev_token = None
    # Zeroing the per-sequence lengths discards all cached K/V entries logically.
    self.cache_seqlens.zero_()
|
||||
|
||||
def get_pos(self):
|
||||
"""Get current position (assumes all batch elements at same position)."""
|
||||
|
|
@ -135,6 +138,8 @@ class KVCache:
|
|||
# Copy smear state: expand batch=1 prev_embedding to num_samples
|
||||
if other.prev_embedding is not None:
|
||||
self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()
|
||||
if other.prev_token is not None:
|
||||
self.prev_token = other.prev_token.expand(self.batch_size, -1).clone()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@torch.inference_mode()
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ class GPTConfig:
|
|||
# Characters: L=long (full context), S=short (quarter context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
bigram_embed_factor: int = 0
|
||||
bigram_lambda_init: float = 0.05
|
||||
|
||||
|
||||
def norm(x):
|
||||
|
|
@ -172,6 +174,8 @@ class GPT(nn.Module):
|
|||
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||
})
|
||||
self.bigram_vocab_size = int(config.vocab_size * max(0, int(config.bigram_embed_factor)))
|
||||
self.bigram_embed = nn.Embedding(self.bigram_vocab_size, config.n_embd) if self.bigram_vocab_size > 0 else None
|
||||
self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
||||
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
||||
|
|
@ -179,6 +183,10 @@ class GPT(nn.Module):
|
|||
# Separate parameters so they can have different optimizer treatment
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
if self.bigram_embed is not None:
|
||||
self.bigram_lambdas = nn.Parameter(torch.zeros(config.n_layer))
|
||||
else:
|
||||
self.register_buffer("bigram_lambdas", torch.zeros(0), persistent=False)
|
||||
# Smear: mix previous token's embedding into current token (cheap bigram-like info)
|
||||
self.smear_gate = Linear(24, 1, bias=False)
|
||||
self.smear_lambda = nn.Parameter(torch.zeros(1))
|
||||
|
|
@ -216,6 +224,8 @@ class GPT(nn.Module):
|
|||
|
||||
# Embedding and unembedding
|
||||
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
|
||||
if self.bigram_embed is not None:
|
||||
torch.nn.init.zeros_(self.bigram_embed.weight)
|
||||
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
|
||||
|
||||
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
|
||||
|
|
@ -237,6 +247,8 @@ class GPT(nn.Module):
|
|||
# Decaying x0 init: earlier layers get more input embedding blending
|
||||
for i in range(n_layer):
|
||||
self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1))
|
||||
if self.bigram_embed is not None:
|
||||
torch.nn.init.constant_(self.bigram_lambdas, self.config.bigram_lambda_init)
|
||||
|
||||
# Smear/backout scalars and smear gate must be explicitly initialized
|
||||
torch.nn.init.zeros_(self.smear_lambda)
|
||||
|
|
@ -262,9 +274,25 @@ class GPT(nn.Module):
|
|||
# because GradScaler cannot unscale fp16 gradients.
|
||||
if COMPUTE_DTYPE != torch.float16:
|
||||
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
|
||||
if self.bigram_embed is not None:
|
||||
self.bigram_embed.to(dtype=COMPUTE_DTYPE)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=COMPUTE_DTYPE)
|
||||
|
||||
def _bigram_hash(self, idx, prev_idx=None):
|
||||
mod = self.bigram_vocab_size - 1
|
||||
if mod <= 0:
|
||||
raise RuntimeError("bigram hash requested with disabled bigram embedding")
|
||||
idx_i32 = idx.to(torch.int32)
|
||||
out = torch.empty_like(idx_i32)
|
||||
if prev_idx is None:
|
||||
out[:, :1].fill_(mod)
|
||||
out[:, 1:] = torch.bitwise_xor(36313 * idx_i32[:, 1:], 27191 * idx_i32[:, :-1]) % mod
|
||||
else:
|
||||
prev_i32 = prev_idx.to(torch.int32)
|
||||
out[:] = torch.bitwise_xor(36313 * idx_i32, 27191 * prev_i32) % mod
|
||||
return out.to(torch.long)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
# autodetect the device from model embeddings
|
||||
|
|
@ -329,8 +357,9 @@ class GPT(nn.Module):
|
|||
nparams = sum(p.numel() for p in self.parameters())
|
||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
|
||||
bigram_embed_numel = self.bigram_embed.weight.numel() if self.bigram_embed is not None else 0
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel() +
|
||||
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
|
|
@ -356,14 +385,17 @@ class GPT(nn.Module):
|
|||
"""
|
||||
# Count each group separately (mirrors the grouping in setup_optimizers)
|
||||
wte = sum(p.numel() for p in self.transformer.wte.parameters())
|
||||
bigram_embed = self.bigram_embed.weight.numel() if self.bigram_embed is not None else 0
|
||||
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
||||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
bigram_lambdas = self.bigram_lambdas.numel() if isinstance(self.bigram_lambdas, nn.Parameter) else 0
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + bigram_lambdas + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
|
||||
total = wte + bigram_embed + value_embeds + lm_head + transformer_matrices + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
return {
|
||||
'wte': wte,
|
||||
'bigram_embed': bigram_embed,
|
||||
'value_embeds': value_embeds,
|
||||
'lm_head': lm_head,
|
||||
'transformer_matrices': transformer_matrices,
|
||||
|
|
@ -371,40 +403,60 @@ class GPT(nn.Module):
|
|||
'total': total,
|
||||
}
|
||||
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
|
||||
def setup_optimizer(
|
||||
self,
|
||||
unembedding_lr=0.004,
|
||||
embedding_lr=0.2,
|
||||
bigram_embedding_lr_mult=1.0,
|
||||
bigram_lambda_lr=0.004,
|
||||
matrix_lr=0.02,
|
||||
weight_decay=0.0,
|
||||
scalar_lr=0.5,
|
||||
muon_plus=False,
|
||||
muon_eq_axis=0,
|
||||
):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
|
||||
# Separate out all parameters into groups
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
value_embeds_params = list(self.value_embeds.parameters())
|
||||
bigram_embed_params = list(self.bigram_embed.parameters()) if self.bigram_embed is not None else []
|
||||
embedding_params = list(self.transformer.wte.parameters())
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
bigram_lambda_params = [self.bigram_lambdas] if isinstance(self.bigram_lambdas, nn.Parameter) else []
|
||||
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(bigram_embed_params) + len(resid_params) + len(x0_params) + len(bigram_lambda_params) + len(smear_params)
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
|
||||
# Build param_groups with all required fields explicit
|
||||
# AdamW groups (embeddings, lm_head, scalars)
|
||||
param_groups = [
|
||||
# AdamW groups (embeddings, lm_head, scalars)
|
||||
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
|
||||
dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
|
||||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
|
||||
]
|
||||
if bigram_embed_params:
|
||||
param_groups.append(dict(kind='adamw', params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale * bigram_embedding_lr_mult, betas=(0.75, 0.95), eps=1e-10, weight_decay=0.01))
|
||||
param_groups.extend([
|
||||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
|
||||
]
|
||||
])
|
||||
if bigram_embed_params:
|
||||
param_groups.append(dict(kind='adamw', params=bigram_lambda_params, lr=bigram_lambda_lr * dmodel_lr_scale, betas=(0.9, 0.95), eps=1e-10, weight_decay=0.0))
|
||||
param_groups.append(dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0))
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
group_params = [p for p in matrix_params if p.shape == shape]
|
||||
param_groups.append(dict(
|
||||
kind='muon', params=group_params, lr=matrix_lr,
|
||||
momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
|
||||
muon_plus=muon_plus, muon_eq_axis=muon_eq_axis,
|
||||
))
|
||||
|
||||
Factory = DistMuonAdamW if ddp else MuonAdamW
|
||||
|
|
@ -448,6 +500,19 @@ class GPT(nn.Module):
|
|||
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
|
||||
x = x + gate * x_pre_smear
|
||||
|
||||
# Optional hashed bigram embedding residual. During KV-cache decoding we need the
|
||||
# previous token id because the sequence length is one.
|
||||
if self.bigram_embed is not None:
|
||||
if kv_cache is None or T > 1:
|
||||
bigram_idx = self._bigram_hash(idx)
|
||||
else:
|
||||
bigram_idx = self._bigram_hash(idx, kv_cache.prev_token)
|
||||
x0_bigram = self.bigram_embed(bigram_idx).to(x.dtype)
|
||||
else:
|
||||
x0_bigram = None
|
||||
if kv_cache is not None:
|
||||
kv_cache.prev_token = idx[:, -1:].clone()
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
n_layer = self.config.n_layer
|
||||
|
|
@ -455,6 +520,8 @@ class GPT(nn.Module):
|
|||
x_backout = None
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
if x0_bigram is not None:
|
||||
x = x + self.bigram_lambdas[i].to(x.dtype) * x0_bigram
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
if i == backout_layer:
|
||||
|
|
|
|||
|
|
@ -100,6 +100,8 @@ def muon_step_fused(
|
|||
beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment
|
||||
ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations
|
||||
red_dim: int, # -1 or -2 - reduction dimension for variance
|
||||
muon_plus: bool, # add one Frobenius renormalization after orthogonalization
|
||||
muon_eq_axis: int, # 0 none, 1 row, 2 column equilibration before orthogonalization
|
||||
) -> None:
|
||||
"""
|
||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||
|
|
@ -115,6 +117,14 @@ def muon_step_fused(
|
|||
# Polar express
|
||||
# Cast to bf16 for speed when available; skip cast otherwise (fp16 is unstable here due to limited exponent range)
|
||||
X = g.bfloat16() if COMPUTE_DTYPE == torch.bfloat16 else g
|
||||
if muon_eq_axis == 1:
|
||||
target = X.float().norm(dim=(-2, -1), keepdim=True) / (X.size(-2) ** 0.5)
|
||||
row_norm = X.float().norm(dim=-1, keepdim=True).clamp_min(1e-6)
|
||||
X = X * (target / row_norm).to(X.dtype)
|
||||
elif muon_eq_axis == 2:
|
||||
target = X.float().norm(dim=(-2, -1), keepdim=True) / (X.size(-1) ** 0.5)
|
||||
col_norm = X.float().norm(dim=-2, keepdim=True).clamp_min(1e-6)
|
||||
X = X * (target / col_norm).to(X.dtype)
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
|
||||
if g.size(-2) > g.size(-1): # Tall matrix
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
|
|
@ -127,6 +137,10 @@ def muon_step_fused(
|
|||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
g = X
|
||||
if muon_plus:
|
||||
target_norm = min(g.size(-2), g.size(-1)) ** 0.5
|
||||
current_norm = g.float().norm(dim=(-2, -1), keepdim=True).clamp_min(1e-6)
|
||||
g = g * (target_norm / current_norm).to(g.dtype)
|
||||
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
|
|
@ -277,6 +291,8 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
self._muon_beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
group.get("muon_plus", False),
|
||||
group.get("muon_eq_axis", 0),
|
||||
)
|
||||
|
||||
# Copy back to original params
|
||||
|
|
@ -486,7 +502,7 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
grad_chunk[:num_owned], stacked_owned,
|
||||
state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
|
||||
self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
|
||||
group["ns_steps"], red_dim,
|
||||
group["ns_steps"], red_dim, group.get("muon_plus", False), group.get("muon_eq_axis", 0),
|
||||
)
|
||||
updated_params[:num_owned].copy_(stacked_owned)
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,19 @@ echo "Waiting for dataset download to complete..."
|
|||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8)
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --device-batch-size=16 --fp8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
|
||||
--depth=24 \
|
||||
--target-param-data-ratio=8 \
|
||||
--device-batch-size=16 \
|
||||
--total-batch-size=1048576 \
|
||||
--fp8 \
|
||||
--compile-mode=max-autotune-no-cudagraphs \
|
||||
--muon-plus \
|
||||
--muon-eq=row \
|
||||
--bigram-embed-factor=5 \
|
||||
--scalar-lr=0.3 \
|
||||
--train-log-every=50 \
|
||||
--run=$WANDB_RUN
|
||||
# evaluate the model: CORE metric, BPB on train/val, and draw samples
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16
|
||||
|
||||
|
|
|
|||
|
|
@ -41,17 +41,23 @@ print_banner()
|
|||
parser = argparse.ArgumentParser(description="Pretrain base model")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
parser.add_argument("--train-log-every", type=int, default=1, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
# FP8 training
|
||||
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
|
||||
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
|
||||
parser.add_argument("--compile-mode", type=str, default="", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode")
|
||||
# Model architecture
|
||||
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
|
||||
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
|
||||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
parser.add_argument("--bigram-embed-factor", type=int, default=0, help="if >0, add a hashed bigram embedding residual")
|
||||
parser.add_argument("--bigram-lambda-init", type=float, default=0.05, help="initial layer residual scale for --bigram-embed-factor")
|
||||
parser.add_argument("--bigram-embedding-lr-mult", type=float, default=1.0, help="bigram embedding LR multiplier relative to --embedding-lr")
|
||||
parser.add_argument("--bigram-lambda-lr", type=float, default=0.004, help="AdamW LR for bigram layer lambdas before dmodel schedule scaling")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
|
|
@ -64,6 +70,8 @@ parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learnin
|
|||
parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
|
||||
parser.add_argument("--muon-plus", action="store_true", help="apply Muon+ style post-orthogonalization Frobenius renormalization")
|
||||
parser.add_argument("--muon-eq", type=str, default="none", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization")
|
||||
parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
|
||||
|
|
@ -71,6 +79,7 @@ parser.add_argument("--resume-from-step", type=int, default=-1, help="resume tra
|
|||
# Evaluation
|
||||
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--skip-initial-eval", action="store_true", help="skip the step 0 validation pass; final validation still runs")
|
||||
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
|
||||
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
|
||||
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
|
||||
|
|
@ -79,6 +88,14 @@ parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints
|
|||
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
if args.train_log_every <= 0:
|
||||
parser.error("--train-log-every must be positive")
|
||||
if args.bigram_embed_factor < 0:
|
||||
parser.error("--bigram-embed-factor must be non-negative")
|
||||
if args.bigram_lambda_lr < 0:
|
||||
parser.error("--bigram-lambda-lr must be non-negative")
|
||||
if args.bigram_embedding_lr_mult <= 0:
|
||||
parser.error("--bigram-embedding-lr-mult must be positive")
|
||||
# -----------------------------------------------------------------------------
|
||||
# Compute init and wandb logging
|
||||
|
||||
|
|
@ -137,6 +154,8 @@ def build_model_meta(depth):
|
|||
sequence_len=args.max_seq_len, vocab_size=vocab_size,
|
||||
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
||||
window_pattern=args.window_pattern,
|
||||
bigram_embed_factor=args.bigram_embed_factor,
|
||||
bigram_lambda_init=args.bigram_lambda_init,
|
||||
)
|
||||
with torch.device("meta"):
|
||||
model_meta = GPT(config)
|
||||
|
|
@ -243,7 +262,10 @@ def disable_fp8(model):
|
|||
# Compile the model
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
compile_kwargs = {"dynamic": False}
|
||||
if args.compile_mode:
|
||||
compile_kwargs["mode"] = args.compile_mode
|
||||
model = torch.compile(model, **compile_kwargs) # the inputs to model will never change shape so dynamic=False is safe
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
|
||||
|
|
@ -305,14 +327,20 @@ if weight_decay_scaled != args.weight_decay:
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
muon_eq_axis = {"none": 0, "row": 1, "col": 2}[args.muon_eq]
|
||||
print0(f"Muon options: muon_plus={args.muon_plus}, muon_eq={args.muon_eq}")
|
||||
optimizer = model.setup_optimizer(
|
||||
# AdamW hyperparameters
|
||||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
bigram_embedding_lr_mult=args.bigram_embedding_lr_mult,
|
||||
bigram_lambda_lr=args.bigram_lambda_lr * batch_lr_scale,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
# Muon hyperparameters
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
weight_decay=weight_decay_scaled,
|
||||
muon_plus=args.muon_plus,
|
||||
muon_eq_axis=muon_eq_axis,
|
||||
)
|
||||
|
||||
if resuming:
|
||||
|
|
@ -411,6 +439,11 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
|||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
train_log_every = args.train_log_every
|
||||
batched_train_timing = train_log_every > 1
|
||||
train_timing_interval_start = None
|
||||
train_timing_interval_first_step = step
|
||||
train_log_count = 0
|
||||
|
||||
# Go!
|
||||
while True:
|
||||
|
|
@ -418,7 +451,7 @@ while True:
|
|||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
|
||||
if args.eval_every > 0 and (last_step or (step % args.eval_every == 0 and (step > 0 or not args.skip_initial_eval))):
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
|
|
@ -505,8 +538,14 @@ while True:
|
|||
# -------------------------------------------------------------------------
|
||||
# single training step
|
||||
# evaluate the gradient
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
if batched_train_timing:
|
||||
if train_timing_interval_start is None:
|
||||
synchronize()
|
||||
train_timing_interval_start = time.time()
|
||||
train_timing_interval_first_step = step
|
||||
else:
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
|
|
@ -538,46 +577,66 @@ while True:
|
|||
else:
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
should_log_train = step == 0 or (step + 1) % train_log_every == 0 or (step + 1) == num_iterations
|
||||
if batched_train_timing:
|
||||
if should_log_train:
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
interval_steps = step - train_timing_interval_first_step + 1
|
||||
interval_dt = t1 - train_timing_interval_start
|
||||
dt = interval_dt / interval_steps
|
||||
counted_start = max(train_timing_interval_first_step, 11)
|
||||
counted_steps = max(0, step - counted_start + 1)
|
||||
if counted_steps > 0:
|
||||
total_training_time += interval_dt * counted_steps / interval_steps
|
||||
train_loss_f = train_loss.item()
|
||||
train_timing_interval_start = None
|
||||
else:
|
||||
dt = None
|
||||
train_loss_f = None
|
||||
else:
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# logging (CPU action only)
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
# Calculate ETA based on average time per step (excluding first 10 steps)
|
||||
steps_done = step - 10
|
||||
if steps_done > 0:
|
||||
avg_time_per_step = total_training_time / steps_done
|
||||
remaining_steps = num_iterations - step
|
||||
eta_seconds = remaining_steps * avg_time_per_step
|
||||
eta_str = f" | eta: {eta_seconds/60:.1f}m"
|
||||
else:
|
||||
eta_str = ""
|
||||
epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}"
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
if step % 100 == 0:
|
||||
log_data = {
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"train/loss": debiased_smooth_loss,
|
||||
"train/lrm": lrm,
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": epoch,
|
||||
}
|
||||
wandb_run.log(log_data)
|
||||
if should_log_train:
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
|
||||
train_log_count += 1
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**train_log_count) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
|
||||
# Calculate ETA based on average time per step (excluding first 10 steps)
|
||||
steps_done = step - 10
|
||||
if steps_done > 0:
|
||||
avg_time_per_step = total_training_time / steps_done
|
||||
remaining_steps = num_iterations - step
|
||||
eta_seconds = remaining_steps * avg_time_per_step
|
||||
eta_str = f" | eta: {eta_seconds/60:.1f}m"
|
||||
else:
|
||||
eta_str = ""
|
||||
epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}"
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
if step % 100 == 0 or (step + 1) % 100 == 0:
|
||||
log_data = {
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"train/loss": debiased_smooth_loss,
|
||||
"train/lrm": lrm,
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": epoch,
|
||||
}
|
||||
wandb_run.log(log_data)
|
||||
|
||||
# state update
|
||||
first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
|
||||
|
|
|
|||
|
|
@ -47,6 +47,25 @@ class MockModel:
|
|||
return logits
|
||||
|
||||
|
||||
class BigramStateModel(MockModel):
    """Mock model whose greedy next token depends on current and previous token ids."""

    def forward(self, ids, kv_cache=None):
        batch, seq_len = ids.shape
        # Right-shifted copy of ids: previous token at each position, zero-padded at t=0.
        shifted = torch.cat([torch.zeros(batch, 1, dtype=ids.dtype), ids[:, :-1]], dim=1)
        if kv_cache is None:
            prev = shifted
        else:
            # Single-token decode step with recorded state: use the cached previous token.
            # Prefill (seq_len > 1) or a fresh cache falls back to the shifted sequence.
            use_cached = seq_len == 1 and kv_cache.prev_token is not None
            prev = kv_cache.prev_token if use_cached else shifted
            kv_cache.prev_token = ids[:, -1:].clone()
            kv_cache.advance(seq_len)
        # Deterministic "next token" rule so greedy decoding is fully predictable.
        target = ((ids + prev + 1) % 256).long()
        logits = torch.full((batch, seq_len, self.vocab_size), -1000.0)
        logits.scatter_(2, target.unsqueeze(-1), 1000.0)
        return logits
|
||||
|
||||
|
||||
class ByteTokenizer:
|
||||
"""
|
||||
Simple byte-level tokenizer for testing.
|
||||
|
|
@ -114,6 +133,7 @@ def test_kv_cache_basic():
|
|||
# Test reset
|
||||
kv_cache.reset()
|
||||
assert kv_cache.get_pos() == 0
|
||||
assert kv_cache.prev_token is None
|
||||
|
||||
# Test get_layer_cache returns correct views
|
||||
k_layer0, v_layer0 = kv_cache.get_layer_cache(0)
|
||||
|
|
@ -136,6 +156,7 @@ def test_kv_cache_prefill():
|
|||
# Write some data to source cache
|
||||
src_cache.k_cache[0, 0, :16, :, :] = 1.0
|
||||
src_cache.v_cache[0, 0, :16, :, :] = 2.0
|
||||
src_cache.prev_token = torch.tensor([[123]])
|
||||
src_cache.advance(16)
|
||||
|
||||
# Create destination cache with larger seq_len
|
||||
|
|
@ -153,6 +174,29 @@ def test_kv_cache_prefill():
|
|||
# Check data was copied
|
||||
assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all()
|
||||
assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all()
|
||||
assert dst_cache.prev_token.tolist() == [[123]]
|
||||
|
||||
|
||||
def test_engine_preserves_bigram_prev_token_state():
    """Engine KV-cache generation should match naive generation for previous-token state."""
    model = BigramStateModel()
    tokenizer = ByteTokenizer()
    engine = Engine(model, tokenizer)
    prompt = [261, 17, 23, 42]
    max_tokens = 8

    def naive_generate(tokens):
        # Reference path: re-run the whole sequence through the model each step,
        # never touching a KV cache, and greedily take the argmax token.
        seq = torch.tensor([tokens], dtype=torch.long)
        for _ in range(max_tokens):
            step_logits = model.forward(seq)
            choice = int(step_logits[:, -1, :].argmax(dim=-1).item())
            seq = torch.cat([seq, torch.tensor([[choice]], dtype=torch.long)], dim=1)
        return tokens + seq[0, len(tokens):].tolist()

    results, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=max_tokens)
    assert results[0] == naive_generate(prompt)
|
||||
|
||||
|
||||
def test_multi_sample_first_token_diversity():
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user