From e28d4ead22c6b9d1238042c81db12e1e5760c438 Mon Sep 17 00:00:00 2001 From: dangxingyu Date: Tue, 3 Feb 2026 20:14:51 -0500 Subject: [PATCH 01/14] Add muonh model and quickrun --- nanochat/gpt.py | 99 +++++++++---- nanochat/optim.py | 233 ++++++++++++++++++++++++++++++- runs/quickrun_gamma_muonh_d24.sh | 156 +++++++++++++++++++++ scripts/base_train.py | 9 +- 4 files changed, 464 insertions(+), 33 deletions(-) create mode 100755 runs/quickrun_gamma_muonh_d24.sh diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 208acd1..fc3d422 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -2,14 +2,16 @@ GPT model (rewrite, a lot simpler) Notable features: - rotary embeddings (and no positional embeddings) -- QK norm +- QK norm (functional RMSNorm, no learnable params) - untied weights for token embedding and lm_head - relu^2 activation in MLP -- norm after token embedding -- no learnable params in rmsnorm +- norm after token embedding (functional RMSNorm) +- parameterized RMSNorm in blocks (learnable gamma) +- per-block projection scalars for attention/MLP - no bias in linear layers - Group-Query Attention (GQA) support for more efficient inference - Flash Attention 3 integration +- optional Hyperball optimizer for matrix params """ from functools import partial @@ -40,7 +42,7 @@ class GPTConfig: def norm(x): - # Purely functional rmsnorm with no learnable params + # Purely functional RMSNorm with no learnable params return F.rms_norm(x, (x.size(-1),)) @@ -72,6 +74,7 @@ class CausalSelfAttention(nn.Module): self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) self.ve_gate_channels = 32 self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None + self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd)) def forward(self, x, ve, cos_sin, window_size, kv_cache): B, T, C = x.size() @@ -115,6 +118,7 @@ class CausalSelfAttention(nn.Module): # Re-assemble the heads and project back to residual stream y = y.contiguous().view(B, T, -1) y = self.c_proj(y) + y = y * self.c_proj_scalar return y @@ -123,23 +127,27 @@ class MLP(nn.Module): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) + self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd)) def forward(self, x): x = self.c_fc(x) x = F.relu(x).square() x = self.c_proj(x) + x = x * self.c_proj_scalar return x class Block(nn.Module): def __init__(self, config, layer_idx): super().__init__() + self.attn_norm = nn.RMSNorm(config.n_embd) self.attn = CausalSelfAttention(config, layer_idx) + self.mlp_norm = nn.RMSNorm(config.n_embd) self.mlp = MLP(config) def forward(self, x, ve, cos_sin, window_size, kv_cache): - x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache) - x = x + self.mlp(norm(x)) + x = x + self.attn(self.attn_norm(x), ve, cos_sin, window_size, kv_cache) + x = x + self.mlp(self.mlp_norm(x)) return x @@ -196,15 +204,21 @@ class GPT(nn.Module): attn.c_q: uniform, std=1/sqrt(n_embd) attn.c_k: uniform, std=1/sqrt(n_embd) attn.c_v: uniform, std=1/sqrt(n_embd) - attn.c_proj: zeros + attn.c_proj: uniform, std=1/sqrt(n_embd) mlp.c_fc: uniform, std=1/sqrt(n_embd) - mlp.c_proj: zeros + mlp.c_proj: uniform, std=1/sqrt(n_embd) + nn.RMSNorm weight: ones (via explicit init below) """ # Embedding and unembedding torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0) torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001) + # nn.RMSNorm weight parameters: init to ones (must be explicit due to meta device) + for module in self.modules(): + if isinstance(module, nn.RMSNorm): + module.weight.fill_(1.0) + # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal) n_embd = self.config.n_embd s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal @@ -212,22 +226,38 @@ class GPT(nn.Module): torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers torch.nn.init.uniform_(block.attn.c_k.weight, -s, s) torch.nn.init.uniform_(block.attn.c_v.weight, -s, s) - torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero + torch.nn.init.uniform_(block.attn.c_proj.weight, -s, s) torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s) - torch.nn.init.zeros_(block.mlp.c_proj.weight) + torch.nn.init.uniform_(block.mlp.c_proj.weight, -s, s) # Per-layer scalars self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding + # Per-block projection scalars (zero-init for stable training start) + for block in self.transformer.h: + block.attn.c_proj_scalar.fill_(0.0) + block.mlp.c_proj_scalar.fill_(0.0) + if self.transformer.wte.weight.device.type == "cuda": + block.attn.c_proj_scalar.data = block.attn.c_proj_scalar.data.to(torch.bfloat16) + block.mlp.c_proj_scalar.data = block.mlp.c_proj_scalar.data.to(torch.bfloat16) + + # Block RMSNorm weights (cast to bf16 for fused kernel) + for block in self.transformer.h: + block.attn_norm.weight.fill_(1.0) + block.mlp_norm.weight.fill_(1.0) + if self.transformer.wte.weight.device.type == "cuda": + block.attn_norm.to(dtype=torch.bfloat16) + block.mlp_norm.to(dtype=torch.bfloat16) + # Value embeddings (init like c_v: uniform with same std) for ve in self.value_embeds.values(): torch.nn.init.uniform_(ve.weight, -s, s) - # Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral) + # Gate weights init to uniform (avoid zero-norm params under Hyperball) for block in self.transformer.h: if block.attn.ve_gate is not None: - torch.nn.init.zeros_(block.attn.ve_gate.weight) + torch.nn.init.uniform_(block.attn.ve_gate.weight, -s, s) # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head @@ -302,10 +332,11 @@ class GPT(nn.Module): - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore) """ nparams = sum(p.numel() for p in self.parameters()) - # Exclude non-matmul params: embeddings and per-layer scalars + # Exclude non-matmul params: embeddings, per-layer scalars, and 1D params in blocks value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) + block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1) nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + - self.resid_lambdas.numel() + self.x0_lambdas.numel()) + self.resid_lambdas.numel() + self.x0_lambdas.numel() + block_1d_params) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -332,31 +363,36 @@ class GPT(nn.Module): wte = sum(p.numel() for p in self.transformer.wte.parameters()) value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters()) - transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) + block_matrix_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 2) + block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1) scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() - total = wte + value_embeds + lm_head + transformer_matrices + scalars + total = wte + value_embeds + lm_head + block_matrix_params + block_1d_params + scalars assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" return { 'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head, - 'transformer_matrices': transformer_matrices, + 'transformer_matrices': block_matrix_params, + 'norm_and_proj_scalars': block_1d_params, 'scalars': scalars, 'total': total, } - def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): + def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5, matrix_optimizer="muon"): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() # Separate out all parameters into groups - matrix_params = list(self.transformer.h.parameters()) + block_matrix_params = [p for p in self.transformer.h.parameters() if p.ndim == 2] + block_1d_params = [p for p in self.transformer.h.parameters() if p.ndim == 1] value_embeds_params = list(self.value_embeds.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + all_params_count = (len(block_matrix_params) + len(block_1d_params) + len(embedding_params) + + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)) + assert len(list(self.parameters())) == all_params_count # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -370,14 +406,23 @@ class GPT(nn.Module): dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0), dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 + dict(kind='adamw', params=block_1d_params, lr=scalar_lr, betas=adam_betas, eps=1e-10, weight_decay=0.0), ] - # Muon groups (matrix params, grouped by shape for stacking) - for shape in sorted({p.shape for p in matrix_params}): - group_params = [p for p in matrix_params if p.shape == shape] - param_groups.append(dict( - kind='muon', params=group_params, lr=matrix_lr, - momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay, - )) + # Matrix params (Muon or Hyperball), grouped by shape for stacking + if matrix_optimizer not in {"muon", "hyperball"}: + raise ValueError(f"Unknown matrix_optimizer: {matrix_optimizer}") + for shape in sorted({p.shape for p in block_matrix_params}): + group_params = [p for p in block_matrix_params if p.shape == shape] + if matrix_optimizer == "hyperball": + param_groups.append(dict( + kind='hyperball', params=group_params, lr=matrix_lr, + momentum=0.95, ns_steps=5, beta2=0.95, + )) + else: + param_groups.append(dict( + kind='muon', params=group_params, lr=matrix_lr, + momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay, + )) Factory = DistMuonAdamW if ddp else MuonAdamW optimizer = Factory(param_groups) diff --git a/nanochat/optim.py b/nanochat/optim.py index 190a1ed..f93ba38 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -1,6 +1,6 @@ """ A nice and efficient mixed AdamW/Muon Combined Optimizer. -Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon. +Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon/Hyperball. Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed. Addapted from: https://github.com/KellerJordan/modded-nanogpt @@ -141,6 +141,80 @@ def muon_step_fused( mask = (g * stacked_params) >= 0 stacked_params.sub_(lr * g + lr * wd * stacked_params * mask) +# ----------------------------------------------------------------------------- +""" +Hyperball optimizer (MuonH): Muon with scale-invariant updates. +https://github.com/marin-community/marin/blob/main/lib/levanter/src/levanter/optim/muonh.py + +The key difference from Muon is that weights maintain constant Frobenius norm +throughout training via the following update rule: + + p_new_intermediate = p - lr * u * ||p|| / ||u|| + p_new = p_new_intermediate / ||p_new_intermediate|| * ||p|| + +This projects the update onto the hypersphere of constant norm, hence "Hyperball". +Uses variance reduction like Muon, but no cautious weight decay. +""" + +@torch.compile(dynamic=False, fullgraph=True) +def hyperball_step_fused( + stacked_grads: Tensor, # (K, M, N) - stacked gradients + stacked_params: Tensor, # (K, M, N) - stacked parameters + momentum_buffer: Tensor, # (K, M, N) - momentum buffer + second_momentum_buffer: Tensor, # (K, M, 1) or (K, 1, N) - factored second moment + p_norm: Tensor, # (K, 1, 1) - pre-computed Frobenius norm of params (constant) + momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient + lr_t: Tensor, # () - 0-D CPU tensor, learning rate + beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment + ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations + red_dim: int, # -1 or -2 - reduction dimension for variance +) -> None: + """ + Fused Hyperball step: momentum -> polar_express -> variance_reduction -> scale_invariant_update + All in one compiled graph. Weights maintain constant Frobenius norm. + """ + + # Nesterov momentum + momentum = momentum_t.to(stacked_grads.dtype) + momentum_buffer.lerp_(stacked_grads, 1 - momentum) + g = stacked_grads.lerp_(momentum_buffer, momentum) + + # Polar express orthogonalization + X = g.bfloat16() + X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6) + if g.size(-2) > g.size(-1): # Tall matrix + for a, b, c in polar_express_coeffs[:ns_steps]: + A = X.mT @ X + B = b * A + c * (A @ A) + X = a * X + X @ B + else: # Wide matrix + for a, b, c in polar_express_coeffs[:ns_steps]: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + g = X + + # Variance reduction (note: this preserves ||g||_F, so ||u||_F == ||g||_F == v_norm) + beta2 = beta2_t.to(g.dtype) + v_mean = g.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = g.size(red_dim) + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt() + scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() + v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt() + p_norm = p_norm.to(v_norm_new.dtype) + final_scale = step_size * p_norm / v_norm_new.clamp_min(1e-10) + g = g * final_scale.to(g.dtype) + u = g.to(stacked_params.dtype) + + # Scale-invariant update: keeps ||p|| constant + lr = lr_t.to(stacked_params.dtype) + stacked_params.sub_(lr * u) + + # Project back to hypersphere: p = p * (||p_orig|| / ||p_new||) + p_new_norm = stacked_params.norm(dim=(-2, -1), keepdim=True).clamp_min(1e-10) + stacked_params.mul_(p_norm / p_new_norm) + # ----------------------------------------------------------------------------- # Single GPU version of the MuonAdamW optimizer. # Used mostly for reference, debugging and testing. @@ -167,9 +241,10 @@ class MuonAdamW(torch.optim.Optimizer): Arguments: param_groups: List of dicts, each containing: - 'params': List of parameters - - 'kind': 'adamw' or 'muon' + - 'kind': 'adamw', 'muon', or 'hyperball' - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay' - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay' + - For Hyperball groups: 'lr', 'momentum', 'ns_steps', 'beta2' """ def __init__(self, param_groups: list[dict]): super().__init__(param_groups, defaults={}) @@ -186,6 +261,10 @@ class MuonAdamW(torch.optim.Optimizer): self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + # Hyperball tensors + self._hyperball_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._hyperball_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._hyperball_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") def _step_adamw(self, group: dict) -> None: """ @@ -276,6 +355,64 @@ class MuonAdamW(torch.optim.Optimizer): # Copy back to original params torch._foreach_copy_(params, list(stacked_params.unbind(0))) + def _step_hyperball(self, group: dict) -> None: + """ + Hyperball update for all params in the group (stacked for efficiency). + Like Muon, but uses scale-invariant updates that keep weight norms constant. + """ + params: list[Tensor] = group['params'] + if not params: + return + + # Get or create group-level buffers (stored in first param's state for convenience) + p = params[0] + state = self.state[p] + num_params = len(params) + shape, device, dtype = p.shape, p.device, p.dtype + + # Stack grads and params (NOTE: this assumes all params have the same shape) + stacked_grads = torch.stack([p.grad for p in params]) + stacked_params = torch.stack(params) + + # Momentum buffer for every individual parameter + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device) + momentum_buffer = state["momentum_buffer"] + + # Second momentum buffer is factored, either per-row or per-column + if "second_momentum_buffer" not in state: + state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1]) + state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device) + second_momentum_buffer = state["second_momentum_buffer"] + red_dim = -1 if shape[-2] >= shape[-1] else -2 + + # Pre-compute and cache p_norm (Frobenius norm of each param, constant throughout training) + if "p_norm" not in state: + state["p_norm"] = stacked_params.norm(dim=(-2, -1), keepdim=True).clone() + p_norm = state["p_norm"] + + # Fill all the 0-D tensors with current values + self._hyperball_momentum_t.fill_(group["momentum"]) + self._hyperball_lr_t.fill_(group["lr"]) + self._hyperball_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) + + # Single fused kernel: momentum -> polar_express -> variance_reduction -> scale_invariant_update + hyperball_step_fused( + stacked_grads, + stacked_params, + momentum_buffer, + second_momentum_buffer, + p_norm, + self._hyperball_momentum_t, + self._hyperball_lr_t, + self._hyperball_beta2_t, + group["ns_steps"], + red_dim, + ) + + # Copy back to original params + torch._foreach_copy_(params, list(stacked_params.unbind(0))) + @torch.no_grad() def step(self): for group in self.param_groups: @@ -283,6 +420,8 @@ class MuonAdamW(torch.optim.Optimizer): self._step_adamw(group) elif group['kind'] == 'muon': self._step_muon(group) + elif group['kind'] == 'hyperball': + self._step_hyperball(group) else: raise ValueError(f"Unknown optimizer kind: {group['kind']}") @@ -344,9 +483,10 @@ class DistMuonAdamW(torch.optim.Optimizer): Arguments: param_groups: List of dicts, each containing: - 'params': List of parameters - - 'kind': 'adamw' or 'muon' + - 'kind': 'adamw', 'muon', or 'hyperball' - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay' - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay' + - For Hyperball groups: 'lr', 'momentum', 'ns_steps', 'beta2' """ def __init__(self, param_groups: list[dict]): super().__init__(param_groups, defaults={}) @@ -361,6 +501,10 @@ class DistMuonAdamW(torch.optim.Optimizer): self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + # Hyperball tensors + self._hyperball_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._hyperball_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._hyperball_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") def _reduce_adamw(self, group: dict, world_size: int) -> dict: """Launch async reduce ops for AdamW group. Returns info dict with per-param infos.""" @@ -491,6 +635,85 @@ class DistMuonAdamW(torch.optim.Optimizer): future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future() gather_list.append(dict(future=future, stacked_params=stacked_params, params=params)) + def _reduce_hyperball(self, group: dict, world_size: int) -> dict: + """Launch async reduce op for Hyperball group. Returns info dict.""" + params = group['params'] + chunk_size = (len(params) + world_size - 1) // world_size + padded_num_params = chunk_size * world_size + p = params[0] + shape, device, dtype = p.shape, p.device, p.dtype + + # Stack grads and zero-pad to padded_num_params + grad_stack = torch.stack([p.grad for p in params]) + stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device) + stacked_grads[:len(params)].copy_(grad_stack) + if len(params) < padded_num_params: + stacked_grads[len(params):].zero_() + + # Reduce_scatter to get this rank's chunk + grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device) + future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future() + + return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size) + + def _compute_hyperball(self, group: dict, info: dict, gather_list: list, rank: int) -> None: + """Wait for reduce, compute Hyperball updates, launch gather.""" + info['future'].wait() + params = group['params'] + chunk_size = info['chunk_size'] + grad_chunk = info['grad_chunk'] + p = params[0] + shape, device, dtype = p.shape, p.device, p.dtype + + # How many params does this rank own? + start_idx = rank * chunk_size + num_owned = min(chunk_size, max(0, len(params) - start_idx)) + + # Get or create group-level state + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device) + if "second_momentum_buffer" not in state: + state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1]) + state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device) + red_dim = -1 if shape[-2] >= shape[-1] else -2 + + # Build output buffer for all_gather + updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device) + + if num_owned > 0: + owned_params = [params[start_idx + i] for i in range(num_owned)] + stacked_owned = torch.stack(owned_params) + + # Pre-compute and cache p_norm for owned params (constant throughout training) + if "p_norm" not in state: + state["p_norm"] = stacked_owned.norm(dim=(-2, -1), keepdim=True).clone() + p_norm = state["p_norm"] + + # Fill 0-D tensors and run fused kernel + self._hyperball_momentum_t.fill_(group["momentum"]) + self._hyperball_lr_t.fill_(group["lr"]) + self._hyperball_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) + hyperball_step_fused( + grad_chunk[:num_owned], stacked_owned, + state["momentum_buffer"][:num_owned], + state["second_momentum_buffer"][:num_owned], + p_norm, + self._hyperball_momentum_t, self._hyperball_lr_t, + self._hyperball_beta2_t, + group["ns_steps"], + red_dim, + ) + updated_params[:num_owned].copy_(stacked_owned) + + if num_owned < chunk_size: + updated_params[num_owned:].zero_() + + # Reuse stacked_grads buffer for all_gather output + stacked_params = info["stacked_grads"] + future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future() + gather_list.append(dict(future=future, stacked_params=stacked_params, params=params)) + def _finish_gathers(self, gather_list: list) -> None: """Wait for all gathers and copy Muon params back.""" for info in gather_list: @@ -511,6 +734,8 @@ class DistMuonAdamW(torch.optim.Optimizer): reduce_infos.append(self._reduce_adamw(group, world_size)) elif group['kind'] == 'muon': reduce_infos.append(self._reduce_muon(group, world_size)) + elif group['kind'] == 'hyperball': + reduce_infos.append(self._reduce_hyperball(group, world_size)) else: raise ValueError(f"Unknown optimizer kind: {group['kind']}") @@ -521,6 +746,8 @@ class DistMuonAdamW(torch.optim.Optimizer): self._compute_adamw(group, info, gather_list, rank, world_size) elif group['kind'] == 'muon': self._compute_muon(group, info, gather_list, rank) + elif group['kind'] == 'hyperball': + self._compute_hyperball(group, info, gather_list, rank) else: raise ValueError(f"Unknown optimizer kind: {group['kind']}") diff --git a/runs/quickrun_gamma_muonh_d24.sh b/runs/quickrun_gamma_muonh_d24.sh new file mode 100755 index 0000000..63ba8cf --- /dev/null +++ b/runs/quickrun_gamma_muonh_d24.sh @@ -0,0 +1,156 @@ +#!/bin/bash + +# Quickrun: GPT-Gamma + MuonH (Hyperball), depth=24 +# - Parameterized RMSNorm (learnable gamma) +# - Per-block projection scalars +# - Hyperball or Muon for matrix params +# +# Examples: +# bash runs/quickrun_gamma_muonh_d24.sh +# WANDB_RUN=exp1 bash runs/quickrun_gamma_muonh_d24.sh +# FP8=1 FP8_RECIPE=tensorwise bash runs/quickrun_gamma_muonh_d24.sh + +set -e + +# ----------------------------------------------------------------------------- +# Config + +DEPTH="${DEPTH:-24}" +NUM_SHARDS="${NUM_SHARDS:-370}" # ~10B tokens for d24 @ ratio~11 +TARGET_RATIO="${TARGET_RATIO:-11}" +WINDOW_PATTERN="${WINDOW_PATTERN:-SSSL}" +DEVICE_BATCH_SIZE="${DEVICE_BATCH_SIZE:-16}" +TOTAL_BATCH_SIZE="${TOTAL_BATCH_SIZE:-524288}" + +NPROC_PER_NODE="${NPROC_PER_NODE:-$(nvidia-smi -L 2>/dev/null | wc -l || echo 1)}" +if [ "$NPROC_PER_NODE" -eq 0 ]; then + NPROC_PER_NODE=1 +fi + +# Optimizer +MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}" +SCALAR_LR="${SCALAR_LR:-0.5}" +MATRIX_LR="$SCALAR_LR" # share with scalar LR +WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}" + +# AdamW +EMBEDDING_LR="${EMBEDDING_LR:-0.3}" +UNEMBEDDING_LR="${UNEMBEDDING_LR:-0.004}" + +# Wandb +export WANDB_MODE=offline +WANDB_RUN="${WANDB_RUN:-muonh_d${DEPTH}_ratio${TARGET_RATIO}}" +MODEL_TAG="${MODEL_TAG:-d${DEPTH}_gamma_muonh}" + +# FP8 +FP8_ARGS="" +if [ "${FP8:-0}" -eq 1 ]; then + FP8_RECIPE="${FP8_RECIPE:-tensorwise}" + FP8_ARGS="--fp8 --fp8-recipe=${FP8_RECIPE}" +fi + +# NCCL (single node) +export NCCL_P2P_LEVEL=NVL +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_DISABLE=1 + +# ----------------------------------------------------------------------------- +# Paths and cache + +export OMP_NUM_THREADS=1 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" +export NANOCHAT_BASE_DIR="$PROJECT_ROOT/cache" +export TORCHINDUCTOR_CACHE_DIR="$NANOCHAT_BASE_DIR/torch_inductor" +export TRITON_CACHE_DIR="$NANOCHAT_BASE_DIR/triton" +export TMPDIR="$NANOCHAT_BASE_DIR/tmp" +mkdir -p "$NANOCHAT_BASE_DIR" "$TORCHINDUCTOR_CACHE_DIR" "$TRITON_CACHE_DIR" "$TMPDIR" + +# ----------------------------------------------------------------------------- +# Print summary + +echo "==============================================" +echo "Quickrun (GPT-Gamma + MuonH D24)" +echo "==============================================" +echo "Project root: $PROJECT_ROOT" +echo "Cache dir: $NANOCHAT_BASE_DIR" +echo "Depth: $DEPTH" +echo "Num shards: $NUM_SHARDS" +echo "Target ratio: $TARGET_RATIO" +echo "Window pattern: $WINDOW_PATTERN" +echo "Num GPUs: $NPROC_PER_NODE" +echo "Device batch size: $DEVICE_BATCH_SIZE" +echo "Total batch size: $TOTAL_BATCH_SIZE" +echo "Matrix optimizer: $MATRIX_OPTIMIZER" +echo "Matrix LR: $MATRIX_LR (shared with scalar)" +echo "Adam LRs: embedding=$EMBEDDING_LR, unembedding=$UNEMBEDDING_LR, scalar=$SCALAR_LR" +echo "Warmdown ratio: $WARMDOWN_RATIO" +echo "Wandb run: $WANDB_RUN" +echo "Model tag: $MODEL_TAG" +if [ "${FP8:-0}" -eq 1 ]; then + echo "FP8: enabled ($FP8_RECIPE)" +fi +echo "==============================================" + +cd "$PROJECT_ROOT" + +# ----------------------------------------------------------------------------- +# Python venv + +if [ ! -d ".venv" ]; then + echo "Setting up Python environment..." + command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh + uv venv + uv sync --extra gpu +fi +source .venv/bin/activate + +# ----------------------------------------------------------------------------- +# Data + tokenizer + +echo "" +echo "Downloading $NUM_SHARDS data shards..." +python -m nanochat.dataset -n "$NUM_SHARDS" + +echo "" +echo "Checking tokenizer..." +python -m scripts.tok_train --max-chars=500000000 --vocab-size=32768 + +# ----------------------------------------------------------------------------- +# Train + +echo "" +echo "Starting pretraining (depth=$DEPTH)..." + +TRAIN_ARGS=( + --depth=$DEPTH + --run=$WANDB_RUN + --model-tag=$MODEL_TAG + --window-pattern=$WINDOW_PATTERN + --target-param-data-ratio=$TARGET_RATIO + --device-batch-size=$DEVICE_BATCH_SIZE + --total-batch-size=$TOTAL_BATCH_SIZE + --matrix-optimizer=$MATRIX_OPTIMIZER + --matrix-lr=$MATRIX_LR + --warmdown-ratio=$WARMDOWN_RATIO + --embedding-lr=$EMBEDDING_LR + --unembedding-lr=$UNEMBEDDING_LR + --scalar-lr=$SCALAR_LR + --core-metric-every=2000 + --sample-every=-1 + --save-every=-1 +) + +if [ "$NPROC_PER_NODE" -gt 1 ]; then + torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ + "${TRAIN_ARGS[@]}" $FP8_ARGS +else + python -m scripts.base_train \ + "${TRAIN_ARGS[@]}" $FP8_ARGS +fi + +echo "" +echo "==============================================" +echo "Training complete!" +echo "==============================================" +echo "Checkpoint saved to: $NANOCHAT_BASE_DIR/base_checkpoints/${MODEL_TAG}/" diff --git a/scripts/base_train.py b/scripts/base_train.py index fa05b60..a4de906 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -58,7 +58,8 @@ parser.add_argument("--total-batch-size", type=int, default=524288, help="total parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)") -parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") +parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon/Hyperball)") +parser.add_argument("--matrix-optimizer", type=str, default="muon", choices=["muon", "hyperball"], help="optimizer for matrix parameters") parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding") parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding") @@ -303,6 +304,7 @@ optimizer = model.setup_optimizer( weight_decay=weight_decay_scaled, adam_betas=adam_betas, scalar_lr=args.scalar_lr * batch_lr_scale, + matrix_optimizer=args.matrix_optimizer, ) if resuming: @@ -331,7 +333,7 @@ def get_lr_multiplier(it): progress = (num_iterations - it) / warmdown_iters return progress * 1.0 + (1 - progress) * args.final_lr_frac -# Momentum scheduler for Muon optimizer +# Momentum scheduler for matrix optimizer (Muon/Hyperball) def get_muon_momentum(it): frac = min(it / 300, 1) momentum = (1 - frac) * 0.85 + frac * 0.95 @@ -466,8 +468,9 @@ while True: muon_weight_decay = get_weight_decay(step) for group in optimizer.param_groups: group["lr"] = group["initial_lr"] * lrm - if group['kind'] == 'muon': + if group['kind'] in {'muon', 'hyperball'}: group["momentum"] = muon_momentum + if group['kind'] == 'muon': group["weight_decay"] = muon_weight_decay optimizer.step() model.zero_grad(set_to_none=True) From 77de3297ea0e70a38775d3ff3a2a58b427655e49 Mon Sep 17 00:00:00 2001 From: dangxingyu Date: Tue, 3 Feb 2026 20:25:16 -0500 Subject: [PATCH 02/14] Update warmdown and rename quickrun --- ...mma_muonh_d24.sh => quickrun_muonh_d24.sh} | 10 +++-- scripts/base_train.py | 38 ++++++++++++------- 2 files changed, 30 insertions(+), 18 deletions(-) rename runs/{quickrun_gamma_muonh_d24.sh => quickrun_muonh_d24.sh} (94%) diff --git a/runs/quickrun_gamma_muonh_d24.sh b/runs/quickrun_muonh_d24.sh similarity index 94% rename from runs/quickrun_gamma_muonh_d24.sh rename to runs/quickrun_muonh_d24.sh index 63ba8cf..98b142f 100755 --- a/runs/quickrun_gamma_muonh_d24.sh +++ b/runs/quickrun_muonh_d24.sh @@ -30,8 +30,9 @@ fi # Optimizer MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}" SCALAR_LR="${SCALAR_LR:-0.5}" -MATRIX_LR="$SCALAR_LR" # share with scalar LR -WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}" +MATRIX_LR="${MATRIX_LR:-0.02}" +WARMDOWN_RATIO="${WARMDOWN_RATIO:-1.0}" +MATRIX_WARMDOWN_RATIO="${MATRIX_WARMDOWN_RATIO:-1.0}" # AdamW EMBEDDING_LR="${EMBEDDING_LR:-0.3}" @@ -82,9 +83,9 @@ echo "Num GPUs: $NPROC_PER_NODE" echo "Device batch size: $DEVICE_BATCH_SIZE" echo "Total batch size: $TOTAL_BATCH_SIZE" echo "Matrix optimizer: $MATRIX_OPTIMIZER" -echo "Matrix LR: $MATRIX_LR (shared with scalar)" +echo "Matrix LR: $MATRIX_LR" echo "Adam LRs: embedding=$EMBEDDING_LR, unembedding=$UNEMBEDDING_LR, scalar=$SCALAR_LR" -echo "Warmdown ratio: $WARMDOWN_RATIO" +echo "Warmdown ratio: adam=$WARMDOWN_RATIO, matrix=$MATRIX_WARMDOWN_RATIO" echo "Wandb run: $WANDB_RUN" echo "Model tag: $MODEL_TAG" if [ "${FP8:-0}" -eq 1 ]; then @@ -133,6 +134,7 @@ TRAIN_ARGS=( --matrix-optimizer=$MATRIX_OPTIMIZER --matrix-lr=$MATRIX_LR --warmdown-ratio=$WARMDOWN_RATIO + --matrix-warmdown-ratio=$MATRIX_WARMDOWN_RATIO --embedding-lr=$EMBEDDING_LR --unembedding-lr=$UNEMBEDDING_LR --scalar-lr=$SCALAR_LR diff --git a/scripts/base_train.py b/scripts/base_train.py index a4de906..3d13bf4 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -64,7 +64,8 @@ parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding") parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding") parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") -parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") +parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for AdamW LR warmdown") +parser.add_argument("--matrix-warmdown-ratio", type=float, default=1.0, help="ratio of iterations for Muon/Hyperball LR warmdown") parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") # Evaluation @@ -80,6 +81,8 @@ args = parser.parse_args() user_config = vars(args).copy() # for logging # ----------------------------------------------------------------------------- + + # Compute init device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) @@ -321,17 +324,18 @@ x, y, dataloader_state_dict = next(train_loader) # kick off load of the very fir # ----------------------------------------------------------------------------- # Set up hyperparameter schedulers -# Learning rate scheduler -def get_lr_multiplier(it): - warmup_iters = round(args.warmup_ratio * num_iterations) - warmdown_iters = round(args.warmdown_ratio * num_iterations) - if it < warmup_iters: +# Learning rate scheduler (warmup + warmdown) +def get_lr_multiplier(it, warmup_ratio, warmdown_ratio, final_lr_frac): + warmup_iters = round(warmup_ratio * num_iterations) + warmdown_iters = round(warmdown_ratio * num_iterations) + if warmup_iters > 0 and it < warmup_iters: return (it + 1) / warmup_iters - elif it <= num_iterations - warmdown_iters: + if warmdown_iters <= 0: return 1.0 - else: - progress = (num_iterations - it) / warmdown_iters - return progress * 1.0 + (1 - progress) * args.final_lr_frac + if it <= num_iterations - warmdown_iters: + return 1.0 + progress = (num_iterations - it) / warmdown_iters + return progress * 1.0 + (1 - progress) * final_lr_frac # Momentum scheduler for matrix optimizer (Muon/Hyperball) def get_muon_momentum(it): @@ -463,11 +467,15 @@ while True: loss.backward() x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward # step the optimizer - lrm = get_lr_multiplier(step) + lrm_adam = get_lr_multiplier(step, args.warmup_ratio, args.warmdown_ratio, args.final_lr_frac) + lrm_matrix = get_lr_multiplier(step, 0.0, args.matrix_warmdown_ratio, args.final_lr_frac) muon_momentum = get_muon_momentum(step) muon_weight_decay = get_weight_decay(step) for group in optimizer.param_groups: - group["lr"] = group["initial_lr"] * lrm + if group['kind'] in {'muon', 'hyperball'}: + group["lr"] = group["initial_lr"] * lrm_matrix + else: + group["lr"] = group["initial_lr"] * lrm_adam if group['kind'] in {'muon', 'hyperball'}: group["momentum"] = muon_momentum if group['kind'] == 'muon': @@ -500,14 +508,15 @@ while True: else: eta_str = "" epoch = dataloader_state_dict["epoch"] - print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm(adam)={lrm_adam:.2f}, lrm(matrix)={lrm_matrix:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") if step % 100 == 0: log_data = { "step": step, "total_training_flops": flops_so_far, "total_training_time": total_training_time, "train/loss": debiased_smooth_loss, - "train/lrm": lrm, + "train/lrm_adam": lrm_adam, + "train/lrm_matrix": lrm_matrix, "train/dt": dt, "train/tok_per_sec": tok_per_sec, "train/mfu": mfu, @@ -548,6 +557,7 @@ get_report().log(section="Base model training", data=[ "DDP world size": ddp_world_size, "warmup_ratio": args.warmup_ratio, "warmdown_ratio": args.warmdown_ratio, + "matrix_warmdown_ratio": args.matrix_warmdown_ratio, "final_lr_frac": args.final_lr_frac, }, { # stats about training outcomes From 4686cb9509ce9d2ba2868160615c1a0faa7f334f Mon Sep 17 00:00:00 2001 From: dangxingyu Date: Tue, 3 Feb 2026 20:26:11 -0500 Subject: [PATCH 03/14] Update quickrun wandb mode --- runs/quickrun_muonh_d24.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runs/quickrun_muonh_d24.sh b/runs/quickrun_muonh_d24.sh index 98b142f..5b17a07 100755 --- a/runs/quickrun_muonh_d24.sh +++ b/runs/quickrun_muonh_d24.sh @@ -39,7 +39,7 @@ EMBEDDING_LR="${EMBEDDING_LR:-0.3}" UNEMBEDDING_LR="${UNEMBEDDING_LR:-0.004}" # Wandb -export WANDB_MODE=offline +WANDB_PROJECT="nanochat" WANDB_RUN="${WANDB_RUN:-muonh_d${DEPTH}_ratio${TARGET_RATIO}}" MODEL_TAG="${MODEL_TAG:-d${DEPTH}_gamma_muonh}" From a611a85e354713f61ff9a867ae7466c7118a1e77 Mon Sep 17 00:00:00 2001 From: dangxingyu Date: Tue, 3 Feb 2026 20:29:55 -0500 Subject: [PATCH 04/14] Rename quickrun script --- runs/{quickrun_muonh_d24.sh => quickrun_muonh.sh} | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) rename runs/{quickrun_muonh_d24.sh => quickrun_muonh.sh} (93%) diff --git a/runs/quickrun_muonh_d24.sh b/runs/quickrun_muonh.sh similarity index 93% rename from runs/quickrun_muonh_d24.sh rename to runs/quickrun_muonh.sh index 5b17a07..dcfc979 100755 --- a/runs/quickrun_muonh_d24.sh +++ b/runs/quickrun_muonh.sh @@ -1,14 +1,15 @@ #!/bin/bash -# Quickrun: GPT-Gamma + MuonH (Hyperball), depth=24 +# Quickrun: GPT-Gamma + MuonH (Hyperball) # - Parameterized RMSNorm (learnable gamma) # - Per-block projection scalars # - Hyperball or Muon for matrix params # # Examples: -# bash runs/quickrun_gamma_muonh_d24.sh -# WANDB_RUN=exp1 bash runs/quickrun_gamma_muonh_d24.sh -# FP8=1 FP8_RECIPE=tensorwise bash runs/quickrun_gamma_muonh_d24.sh +# bash runs/quickrun_muonh.sh +# WANDB_RUN=exp1 bash runs/quickrun_muonh.sh +# FP8=1 FP8_RECIPE=tensorwise bash runs/quickrun_muonh.sh +# DEPTH=16 bash runs/quickrun_muonh.sh set -e @@ -16,7 +17,7 @@ set -e # Config DEPTH="${DEPTH:-24}" -NUM_SHARDS="${NUM_SHARDS:-370}" # ~10B tokens for d24 @ ratio~11 +NUM_SHARDS="${NUM_SHARDS:-370}" # default for d24 @ ratio~11 TARGET_RATIO="${TARGET_RATIO:-11}" WINDOW_PATTERN="${WINDOW_PATTERN:-SSSL}" DEVICE_BATCH_SIZE="${DEVICE_BATCH_SIZE:-16}" @@ -71,7 +72,7 @@ mkdir -p "$NANOCHAT_BASE_DIR" "$TORCHINDUCTOR_CACHE_DIR" "$TRITON_CACHE_DIR" "$T # Print summary echo "==============================================" -echo "Quickrun (GPT-Gamma + MuonH D24)" +echo "Quickrun (GPT-Gamma + MuonH)" echo "==============================================" echo "Project root: $PROJECT_ROOT" echo "Cache dir: $NANOCHAT_BASE_DIR" From e7ee891c3b602b5c2a755d97ee04e4f24bc1c335 Mon Sep 17 00:00:00 2001 From: dangxingyu Date: Tue, 3 Feb 2026 20:43:43 -0500 Subject: [PATCH 05/14] Update quickrun script --- runs/quickrun_muonh.sh | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/runs/quickrun_muonh.sh b/runs/quickrun_muonh.sh index dcfc979..fdf201c 100755 --- a/runs/quickrun_muonh.sh +++ b/runs/quickrun_muonh.sh @@ -32,7 +32,7 @@ fi MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}" SCALAR_LR="${SCALAR_LR:-0.5}" MATRIX_LR="${MATRIX_LR:-0.02}" -WARMDOWN_RATIO="${WARMDOWN_RATIO:-1.0}" +WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}" MATRIX_WARMDOWN_RATIO="${MATRIX_WARMDOWN_RATIO:-1.0}" # AdamW @@ -45,17 +45,13 @@ WANDB_RUN="${WANDB_RUN:-muonh_d${DEPTH}_ratio${TARGET_RATIO}}" MODEL_TAG="${MODEL_TAG:-d${DEPTH}_gamma_muonh}" # FP8 +FP8="1" FP8_ARGS="" if [ "${FP8:-0}" -eq 1 ]; then FP8_RECIPE="${FP8_RECIPE:-tensorwise}" FP8_ARGS="--fp8 --fp8-recipe=${FP8_RECIPE}" fi -# NCCL (single node) -export NCCL_P2P_LEVEL=NVL -export NCCL_NVLS_ENABLE=1 -export NCCL_IB_DISABLE=1 - # ----------------------------------------------------------------------------- # Paths and cache From 924489f5822650f603210e425875707cb91d8637 Mon Sep 17 00:00:00 2001 From: dangxingyu Date: Tue, 3 Feb 2026 20:46:20 -0500 Subject: [PATCH 06/14] Update quickrun defaults --- runs/quickrun_muonh.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runs/quickrun_muonh.sh b/runs/quickrun_muonh.sh index fdf201c..5aa303f 100755 --- a/runs/quickrun_muonh.sh +++ b/runs/quickrun_muonh.sh @@ -44,8 +44,8 @@ WANDB_PROJECT="nanochat" WANDB_RUN="${WANDB_RUN:-muonh_d${DEPTH}_ratio${TARGET_RATIO}}" MODEL_TAG="${MODEL_TAG:-d${DEPTH}_gamma_muonh}" -# FP8 -FP8="1" +# FP8 (default enabled) +FP8="${FP8:-1}" FP8_ARGS="" if [ "${FP8:-0}" -eq 1 ]; then FP8_RECIPE="${FP8_RECIPE:-tensorwise}" From 595a0f460a186c16c2c8156bf5707c6e533ac9de Mon Sep 17 00:00:00 2001 From: dangxingyu Date: Tue, 3 Feb 2026 21:29:51 -0500 Subject: [PATCH 07/14] Scale hyperball lr by depth --- scripts/base_train.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index 3d13bf4..f8920ea 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -300,10 +300,19 @@ model = torch.compile(model, dynamic=False) # the inputs to model will never cha # ----------------------------------------------------------------------------- # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) adam_betas = (args.adam_beta1, args.adam_beta2) +matrix_lr_scaled = args.matrix_lr * batch_lr_scale + +# LR depth scaling for Hyperball +if args.matrix_optimizer == "hyperball": + hyperball_depth_scale = 12 / args.depth + matrix_lr_scaled = matrix_lr_scaled * hyperball_depth_scale + if args.depth != 12: + print0(f"Scaling hyperball LR from {args.matrix_lr * batch_lr_scale:.6f} to {matrix_lr_scaled:.6f} for depth {args.depth}") + optimizer = model.setup_optimizer( unembedding_lr=args.unembedding_lr * batch_lr_scale, embedding_lr=args.embedding_lr * batch_lr_scale, - matrix_lr=args.matrix_lr * batch_lr_scale, + matrix_lr=matrix_lr_scaled, weight_decay=weight_decay_scaled, adam_betas=adam_betas, scalar_lr=args.scalar_lr * batch_lr_scale, From ee04406ebb5b68a5a3e10289bbb974a8d483cc30 Mon Sep 17 00:00:00 2001 From: Kaiyue Wen Date: Thu, 12 Feb 2026 16:12:29 -0800 Subject: [PATCH 08/14] Merge muonh-dev and master: FP8 training, optimizer tuning, and scaling improvements Major changes: - Add custom FP8 training module (replaces torchao dependency) - Implement auto-calculated optimal batch sizes (1M for d26) - Add hyperball data scaling - Restore and tune momentum schedule (settled on 0.95) - Add matrix warmup ratio and norm_lr parameters - Improve weight decay scaling (Tepoch-based theory) - Update d26 configuration and scaling laws - Clarify MFU labeling as bf16_mfu - Update leaderboard and documentation Co-Authored-By: Claude Sonnet 4.5 (1M context) --- .gitignore | 1 + README.md | 37 ++--- dev/LEADERBOARD.md | 40 +++++- dev/LOG.md | 82 ++++++++++- nanochat/fp8.py | 272 +++++++++++++++++++++++++++++++++++++ nanochat/gpt.py | 80 ++++++----- nanochat/optim.py | 5 + pyproject.toml | 1 - runs/miniseries.sh | 14 +- runs/quickrun_muonh.sh | 33 +++-- runs/scaling_laws_muonh.sh | 220 ++++++++++++++++++++++++++++++ runs/speedrun.sh | 2 +- scripts/base_train.py | 263 +++++++++++++++++++---------------- uv.lock | 11 -- 14 files changed, 862 insertions(+), 199 deletions(-) create mode 100644 nanochat/fp8.py create mode 100755 runs/scaling_laws_muonh.sh diff --git a/.gitignore b/.gitignore index 3e92824..3fa5754 100644 --- a/.gitignore +++ b/.gitignore @@ -9,5 +9,6 @@ eval_bundle/ .env # Local setup +cache CLAUDE.md wandb/ diff --git a/README.md b/README.md index 08c184a..1894ac8 100644 --- a/README.md +++ b/README.md @@ -3,24 +3,22 @@ ![nanochat logo](dev/nanochat.png) ![scaling laws](dev/scaling_laws_jan26.png) -nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal/hackable, and it covers all major LLM stages including tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 capability LLM (which cost ~$50,000 to train in 2019) for only $73 (3 hours of 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. +nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal/hackable, and it covers all major LLM stages including tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 capability LLM (which cost ~$43,000 to train in 2019) for only $72 (~3 hours of 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. On a spot instance, the total cost can be closer to ~$20. More generally, nanochat is configured out of the box to train an entire miniseries of compute-optimal models by setting one single complexity dial: `--depth`, the number of layers in the GPT transformer model (GPT-2 capability happens to be approximately depth 26). All other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) are calculated automatically in an optimal way. For questions about the repo, I recommend either using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions about the repo, or use the [Discussions tab](https://github.com/karpathy/nanochat/discussions), or come by the [#nanochat](https://discord.com/channels/1020383067459821711/1427295580895314031) channel on Discord. -## Updates +## Time-to-GPT-2 Leaderboard -- (Jan 31 2026) Major revamp of all scripts/README ongoing, deleting midtraining stage, might be a bit messy briefly... -- (Jan 30 2026) With all the latest improvements we're able to train GPT-2 grade LLM in about $73. The [runs/speedrun.sh](runs/speedrun.sh) script will become the refernece way to train GPT-2 grade model and talk to it. - -## Leaderboard +Presently, the main focus of development is on tuning the pretraining stage, which takes the most amount of compute. Inspired by the modded-nanogpt repo and to incentivise progress and community collaboration, nanochat maintains a leaderboard for a "GPT-2 speedrun", which is the wall-clock time required to train a nanochat model to GPT-2 grade capability, as measured by the DCLM CORE score. The [runs/speedrun.sh](runs/speedrun.sh) script always reflects the reference way to train GPT-2 grade model and talk to it. The current leaderboard looks as follows: | # | time | val_bpb | CORE | Description | Date | Commit | Contributors | |---|-------------|---------|------|-------------|------|--------|--------------| | 0 | 168 hours | - | 0.2565 | Original OpenAI GPT-2 checkpoint | 2019 | - | OpenAI | | 1 | 3.04 | 0.74833 | 0.2585 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy | -| 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | 8309b83 | @karpathy | +| 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | a67eba3 | @karpathy | +| 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy | -The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $50,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 3 hours is ~$72). +The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $43,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 3 hours is ~$72). See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret and contribute to the leaderboard. @@ -28,13 +26,13 @@ See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret a ### Reproduce and talk to GPT-2 -The most fun you can have is to train your own GPT-2 and talk to it. The entire pipeline to do so is contained in the single file [runs/speedrun.sh](runs/speedrun.sh), which is designed to be run on an 8XH100 GPU node. Currently, at ~$24/hour for these nodes, pretraining GPT-2 grade model takes approximately 3 hours and will set you back about $75. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script: +The most fun you can have is to train your own GPT-2 and talk to it. The entire pipeline to do so is contained in the single file [runs/speedrun.sh](runs/speedrun.sh), which is designed to be run on an 8XH100 GPU node. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script: ```bash bash runs/speedrun.sh ``` -You mish to do so in a screen session as this will take ~3 hours to run. Once it's done, you can talk to it via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it: +You may wish to do so in a screen session as this will take ~3 hours to run. Once it's done, you can talk to it via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it: ```bash python -m scripts.chat_web @@ -69,22 +67,29 @@ OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train --save-every=-1 \ ``` -This uses wandb (run name "d12"), only runs the CORE metric on last step, and it doesn't sample and save intermediate checkpoints. I like to change something in the code, re-run a d12 (or a d16 etc) and see if it helped, in an iteration loop. +This uses wandb (run name "d12"), only runs the CORE metric on last step, and it doesn't sample and save intermediate checkpoints. I like to change something in the code, re-run a d12 (or a d16 etc) and see if it helped, in an iteration loop. To see if a run helps, I like to monitor the wandb plots for: -The overall approach is to treat the depth of the model as the single dial of complexity. By sweeping out the depth, we get increasingly more powerful models. We determine the scaling laws, set the data budget to a compute optimal setting, train a whole miniseries of models of increasing sizes, and compare them to the GPT-2 and GPT-3 miniseries. Right now, beating GPT-2 specifically faster and faster is the most interesting target. +1. `val_bpb` (validation loss in vocab-size-invariant units of bits per byte) as a function of `step`, `total_training_time` and `total_training_flops`. +2. `core_metric` (the DCLM CORE socre) +3. VRAM utilization, `train/mfu` (Model FLOPS utilization), `train/tok_per_sec` (training throughput) + +See an example [here](https://github.com/karpathy/nanochat/pull/498#issuecomment-3850720044). + +The important thing to note is that nanochat is written and configured around one single dial of complexity - the depth of the transformer. This single integer automatically determines all other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) so that the trained model comes out compute optimal. The idea is that the user doesn't have to think about or set any of this, they are simply asking for a smaller or bigger model using `--depth`, and everything "just works". By sweeping out the depth, you achieve the nanochat miniseries of compute optimal models at various sizes. GPT-2 capability model (which is of most interest at the moment) happens to be somewhere around d24-d26 range with the current code. But any candidate changes to the repo have to be principled enough that they work for all settings of depth. ## Running on CPU / MPS -The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM tha tis being trained to make things fit into a reasonable time interval of a few ten minutes of training. You will not get strong results in this way. +The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM that is being trained to make things fit into a reasonable time interval of a few ten minutes of training. You will not get strong results in this way. ## Guides -I've published a number of guides that might contain helpful information: +I've published a number of guides that might contain helpful information, most recent to least recent: -- [Oct 13 2025 original nanochat post](https://github.com/karpathy/nanochat/discussions/1) introducing nanochat, though now it contains some deprecated information and the model is a lot older (with worse results) than current master. +- [Feb 1 2026: Beating GPT-2 for <<$100: the nanochat journey](https://github.com/karpathy/nanochat/discussions/481) - [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) documents the first nanochat miniseries of models. -- To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the SFT stage. - To add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164). +- To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the SFT stage. +- [Oct 13 2025: original nanochat post](https://github.com/karpathy/nanochat/discussions/1) introducing nanochat, though now it contains some deprecated information and the model is a lot older (with worse results) than current master. ## File structure diff --git a/dev/LEADERBOARD.md b/dev/LEADERBOARD.md index 3b61cc6..b8a727f 100644 --- a/dev/LEADERBOARD.md +++ b/dev/LEADERBOARD.md @@ -29,12 +29,12 @@ Note that: - `depth` controls the size of the Transformer - `run` is the wandb name - `model-tag` is the location of the checkpoints on disk -- `device-batch-size` in the ideal world, you want this to be 32 because with sequence length of 2048 (the default) and 8 GPUs we get `32 X 2048 X 8 = 524,288`, which is the total desired batch size determined to work fairly well around this scale. However, for bigger (e.g. d26), 32 is too much and OOMs, so we decrease it by 2 to 16. The `base_train.py` script automatically compensates for this by calculating that it has to use gradient accumulation of 2 to meet the desired total batch size. Therefore, it will fo forward+backward twice and then a single step. Long story short, the ideal value is 32. If that doesn't fit, you decrease it, e.g. 16, 8, etc., keeping it powers of two so that the gradient accumulation math works out neatly. +- `device-batch-size` in the ideal world, you want this to be 32 because with sequence length of 2048 (the default) and 8 GPUs we get `32 X 2048 X 8 = 524,288`, which is the total desired batch size determined to work fairly well around this scale. However, for bigger (e.g. d26), 32 is too much and OOMs, so we decrease it by 2 to 16. The `base_train.py` script automatically compensates for this by calculating that it has to use gradient accumulation of 2 to meet the desired total batch size. Therefore, it will do forward+backward twice and then a single step. Long story short, the ideal value is 32. If that doesn't fit, you decrease it, e.g. 16, 8, etc., keeping it powers of two so that the gradient accumulation math works out neatly. - `sample-every = -1` turns off periodic sampling - `core-metric-max-per-task=-1` means we run the entire CORE eval - `core-metric-every=999999` a bit of a hacky way to make the CORE eval only happen a single time at the very end of the run - `target-param-data-ratio=8.25` controls the training horizon, which is determined in the script by taking the number of non-embedding model parameters and simply multiplying by this number. The current optimal Tokens:Params ratio can be seen in the defaults of the `base_train.py` script (it is 10.5). 10.5 would produce the *compute optimal* model given the currently measured scaling laws. However, GPT-2 capability is currently somewhere in between a d24 and d26. So to reach it exactly, we want to either overtrain d24 or undertrain d26. In this particular example, I am choosing to slightly undertrain a d26. Note that odd depths (e.g. d25) are not super recommended to use because the math around the transformer sizing and its head dimensions doesn't come out neatly. -- `--fp8` turns on fp8 training. If you GPU does not support fp8, you can leave this out and the code will simply train in bf16. bf16 is higher precision than fp8, so you can actually expect that you might be able to do fewer steps (lower the `target-param-data-ratio`) to achieve the same capability. +- `--fp8` turns on fp8 training. If your GPU does not support fp8, you can leave this out and the code will simply train in bf16. bf16 is higher precision than fp8, so you can actually expect that you might be able to do fewer steps (lower the `target-param-data-ratio`) to achieve the same capability. Once you kick off the run, you wait ~3 hours and then at the end you'll see something like: @@ -46,9 +46,9 @@ wandb: total_training_flops 4.330784131228946e+19 wandb: total_training_time 10949.46713 ``` -Your CORE metric must be greater than GPT-2 0.256525. Then you report the `total_training_time`, (e.g. 10949) which is the time of the training iterations alone, excluding all the evaluations and logging, in seconds. So here for example here it is roughly 10949/60/60 ~= 3.04 hours. You should also note and report the validation bpb of your run because the CORE metric can be a little bit noisy. +Your CORE metric must be greater than GPT-2 0.256525. Then you report the `total_training_time`, (e.g. 10949) which is the time of the training iterations alone, excluding all the evaluations and logging, in seconds. So here for example it is roughly 10949/60/60 ~= 3.04 hours. You should also note and report the validation bpb of your run because the CORE metric can be a little bit noisy. -If you outperform GPT-2 and the time is less than current SOTA in the Leaderboard, you get to make a PR. In addition to raw gains, there are some qualitative and aesthetic considerations that go into whether your improvement is merged. For example, if it is gnarly or it significantly bloats the code, or it seems too esoteric, then we will way those things against the improvement demonstrated. Additionally, nanochat cares not only about targeting a single model, but an entire miniseries of models. So your change must be principled enough that it can easily generalize to other model depths, so that we can sweep out a miniseries. +If you outperform GPT-2 and the time is less than current SOTA in the Leaderboard, you get to make a PR. In addition to raw gains, there are some qualitative and aesthetic considerations that go into whether your improvement is merged. For example, if it is gnarly or it significantly bloats the code, or it seems too esoteric, then we will weigh those things against the improvement demonstrated. Additionally, nanochat cares not only about targeting a single model, but an entire miniseries of models. So your change must be principled enough that it can easily generalize to other model depths, so that we can sweep out a miniseries. After you create the commit, to get the current short git commit hash: @@ -89,7 +89,7 @@ Detailed writeup: [Beating GPT-2 for <<$100: the nanochat journey](https://githu ## Run 2 -Achieved Feb 2 2026 on commit `8309b83`. The launch command was +Achieved Feb 2 2026 on commit `a67eba3`. The launch command was ``` OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ @@ -117,3 +117,33 @@ Minimum validation bpb: 0.745036 The big change in this run is `--fp8`, which causes all Linear layers (other than the gates) to be switched to fp8 training using `torchao` with tensorwise fp8 scaling. Each step is of slightly lower quality, but we are taking them a lot faster, coming out net ahead. Anyone who does not have fp8 (e.g. using a GPU without it) can simply leave out the `--fp8` flag to train in bfloat16. This will work just fine but it will produce a slightly stronger model than GPT-2 because of the fp8 -> bf16 precision upgrade. It's possible that one can further tune which layers to include in the fp8 conversion and that e.g. some of the smaller matmuls should be just kept in bf16 etc. Previous record was 3.04 hours, so 2.91 hours is `(3.04 - 2.91)/3.04*100` ~= 4.3% speed improvement. + +## Run 3 + +Achieved Feb 5 2026 on commit `2c062aa`. Launch command: + +``` +OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ + --depth=26 \ + --run="d26_feb4_double_batch_ratio8.25" \ + --model-tag="d26_feb4_double_batch_ratio8.25" \ + --device-batch-size=16 \ + --total-batch-size=1048576 \ + --sample-every=-1 \ + --save-every=-1 \ + --core-metric-max-per-task=-1 \ + --core-metric-every=999999 \ + --target-param-data-ratio=8.25 \ + --fp8 +``` + +Result: + +``` +core_metric 0.26024 +step 7226 +total_training_time 9922 +Minimum validation bpb: 0.74645 +``` + +The big change here is that the batch size was doubled from 0.5M to 1M, which works better for a d26 model and allowed me to decrease the number of optimization steps a bit via `--target-param-data-ratio` from 8.5 to 8.25. The TLDR is that the original batch size of 0.5M was tuned for d12, but bigger models (e.g. d26) prefer larger total batch size. I determined in experiments that d26 prefers 1M. Then I implemented and merged a principled way to calculate the optimal batch size given depth so that all nanochat models of all depths benefit. See [dev/LOG.md](dev/LOG.md) entry "2026-02-05: Auto Batch Size Scaling" for more detail. diff --git a/dev/LOG.md b/dev/LOG.md index 908fac1..dec2c06 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,82 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-05: Auto Batch Size Scaling + +### Background + +So far, the `--total-batch-size` was hardcoded to be `2**19 = 524,288` ~= 0.5M tokens. This was the optimal setting for d12, but when I tried to re-tune it for d26 (GPT-2), I noticed that the optimal was closer to `2**20 = 1,048,576` ~= 1M tokens. This is to be expected - larger models prefer a higher optimal total batch size. However, we have to make sure that all settings of `--depth` get their own optimal batch size calculated in some principled way. Here, I referenced the "Power Lines" paper from Cerebras ([arXiv:2505.13738](https://arxiv.org/abs/2505.13738)) for a lot of related experimentation. In particular, they found that **Bopt ∝ D^0.383** (where D is the number of training tokens, not the number of parameters!). So the idea is to tune the optimal batch size on d12, and then extrapolate it with this power law to bigger models. The 0.383 exponent means batch size grows slowly: 10× more tokens only justifies ~2.4× bigger batch. For nanochat's compute-optimal training (D ∝ N via `--target-param-data-ratio`), this means deeper models naturally want larger batches. + +### Implementation + +Added `--total-batch-size=-1` (now the default) to auto-compute optimal batch: + +```python +get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head'] +if args.total_batch_size == -1: + D_REF = args.target_param_data_ratio * get_scaling_params(build_model_meta(12)) + B_REF = 2**19 + args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383)) +``` + +Reference point: d=12 model with B=2^19 (empirically validated). The reference is computed dynamically so that if the architecture changes (e.g., different `--aspect-ratio`), the math automatically adjusts. However, if the model actually does change too much, one would also want to re-tune the optimal batch size for d=12. + +### Results + +With this formula, we currently get: + +| Depth | Scaling Params | Target Tokens | Auto Batch | +|-------|---------------|---------------|------------| +| d=8 | 42M | 0.44B | 2^18 = 262K | +| d=10-16 | 70M-235M | 0.7B-2.5B | 2^19 = 524K | +| d=18-26 | 324M-918M | 3.4B-9.6B | 2^20 = 1.05M | +| d=32-50 | 1.7B-6.2B | 17.6B-65.6B | 2^21 = 2.1M | + +In particular, this matches empirical observations that d26 prefers ~2^20 while d12 prefers ~2^19. + +### Code Cleanup + +Also refactored model initialization to use `build_model_meta(depth)` helper and `dataclasses.asdict()` for cleaner config handling. + +### Useful references + +- [Bergsma et al., Power Laws for Batch Size, Model Size, and Training Horizon](https://arxiv.org/abs/2505.13738) +- [McCandlish et al., An Empirical Model of Large-Batch Training](https://arxiv.org/abs/1812.06162) +- [Brown et al., Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165) +- [Merrill et al., The Batch Size–Critical Batch Size Myth](https://arxiv.org/abs/2505.23971) + +### One more thing (batch size ramp) + +Tried batch size ramping. The simplest implementation I could think of "tricks" the existing training loop by slicing each micro-batch into smaller pieces and calling optimizer.step() more frequently early in training (1/8 → 1/4 → 1/2 → full batch over the first x% of training, with sqrt LR scaling). Also required a torch.compile warmup phase to pre-compile all slice sizes and avoid recompilation spikes during training. While the idea is sound and small gains were observed, they weren't sufficient to justify the code complexity introduced (conditional slicing logic, warmup with state save/restore, etc.). Not merged for now. + +--- + +## 2026-02-05: SwiGLU Activation (Negative Result) + +Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). SwiGLU uses three projections instead of two, so to match parameters and FLOPs we scale hidden_dim from 4× to 8/3×: + +```python +# Old ReLU²: 2 matrices, 4x expansion +# params: 2 × n × 4n = 8n² +# flops: 2 × 2n × 4n = 16n² per token +self.c_fc = Linear(n_embd, 4 * n_embd) +self.c_proj = Linear(4 * n_embd, n_embd) +x = c_proj(relu(c_fc(x)).square()) + +# New SwiGLU: 3 matrices, 8/3x expansion +# params: 2 × n × (8n/3) + (8n/3) × n = 8n² ✓ matches +# flops: 3 × 2n × (8n/3) = 16n² per token ✓ matches +hidden_dim = (8 * n_embd) // 3 +self.w1 = Linear(n_embd, hidden_dim) # gate +self.w2 = Linear(n_embd, hidden_dim) # up +self.w3 = Linear(hidden_dim, n_embd) # down +x = w3(silu(w1(x)) * w2(x)) +``` + +Tested at both d12 and d24 (GPT-2 scale). Worse on all measures — step efficiency, wall clock time, and FLOPs. ReLU² remains superior for nanochat. **Not adopted.** + +--- + ## 2026-02-03: Flip Muon MLP LR Multiplier (PR #492) Tested flipping the shape-based LR heuristic in Muon from boosting tall matrices (input projections like `c_fc`) to boosting wide matrices (output projections like `c_proj`). The original code applies `max(1, rows/cols)^0.5`, giving ~2x LR to `c_fc`. The flipped version gives ~2x LR to `c_proj` instead, which aligns with classical fan-in/fan-out scaling conventions. This was proposed in [PR #492](https://github.com/karpathy/nanochat/pull/492) and showed improvements in modded-nanogpt. @@ -733,8 +809,8 @@ Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon i - Both methods kept in code for easy comparison (`zeropower_via_polar_express` vs `zeropower_via_newtonschulz5`) - **Result:** No dramatic/noticeable difference in training, but keeping the new Polar Express as default. -**2. Variance Reduction (NorMuon-style)** -- Added low-rank variance estimator similar to Adafactor ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491)) +**2. NorMuon Variance Reduction** +- Added per-neuron/column adaptive learning rate from NorMuon ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491)) - Maintains `second_momentum_buffer` with shape `[rows, 1]` or `[1, cols]` (whichever is smaller) - Normalizes updates based on running per-row/col variance estimate (beta2=0.95) - Memory overhead: ~1/max(rows, cols) per param, negligible @@ -776,7 +852,7 @@ Example: If d12 optimal is 0.22, then d20 optimal ≈ 0.22 × (12/20)² ≈ 0.08 ### Summary -Muon was changed to use Polar Express, added Adafactor-style variance reduction, and cautious weight decay with schedule that ramps linearly to zero. All of these changes follow modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat with the exception of the Polar Express change which was in the noise. This is default on and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The kwarg `--weight_decay` is therefore changing as of this change. It used to configure AdamW via standard weight decay and now it becomes exclusively used in Muon (AdamW is hardcoded to 0.0), and it is scaled based on depth. +Muon was changed to use Polar Express, added NorMuon variance reduction, and cautious weight decay with schedule that ramps linearly to zero. All of these changes follow modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat with the exception of the Polar Express change which was in the noise. This is default on and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The kwarg `--weight_decay` is therefore changing as of this change. It used to configure AdamW via standard weight decay and now it becomes exclusively used in Muon (AdamW is hardcoded to 0.0), and it is scaled based on depth. --- diff --git a/nanochat/fp8.py b/nanochat/fp8.py new file mode 100644 index 0000000..9d9e9c3 --- /dev/null +++ b/nanochat/fp8.py @@ -0,0 +1,272 @@ +"""Minimal FP8 training for nanochat — tensorwise dynamic scaling only. + +Drop-in replacement for torchao's Float8Linear (~2000 lines) with ~150 lines. +We only need the "tensorwise" recipe (one scalar scale per tensor), not the full +generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor +subclass dispatch tables, etc.) + +How FP8 training works +====================== +A standard Linear layer does one matmul in forward and two in backward: + forward: output = input @ weight.T + backward: grad_input = grad_output @ weight + grad_weight= grad_output.T @ input + +FP8 training wraps each of these three matmuls with: + 1. Compute scale = FP8_MAX / max(|tensor|) for each operand + 2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8) + 3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16) + 4. Dequantize: _scaled_mm handles this internally using the inverse scales + +The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins. +torchao is just orchestration around these primitives. We can call them directly. + +FP8 dtype choice +================ +There are two FP8 formats. We use both, following the standard convention: + - float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448] + Higher precision (more mantissa bits), used for input and weight. + - float8_e5m2: 5-bit exponent, 2-bit mantissa, range [-57344, 57344] + Wider range (more exponent bits), used for gradients which can be large. + +torch._scaled_mm layout requirements +===================================== +The cuBLAS FP8 kernel requires specific memory layouts: + - First argument (A): must be row-major (contiguous) + - Second argument (B): must be column-major (B.t().contiguous().t()) +If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is +already column-major — no copy needed. Otherwise we use _to_col_major(). + +How this differs from torchao's approach +======================================== +torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass +of torch.Tensor that bundles FP8 data + scale + metadata. It implements +__torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t, +reshape, clone, ...) and handles it in FP8-aware fashion. When you call + output = input @ weight.T +the @ operator dispatches to aten.mm, which gets intercepted and routed to +torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need +a handler for every tensor operation that might touch an FP8 tensor. + +We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes +full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns +full-precision outputs. Marked @allow_in_graph so torch.compile treats it as one +opaque node rather than trying to trace inside. + +The trade-off is in how torch.compile sees the two approaches: + - torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and + sees every individual op (amax, scale, cast, _scaled_mm) as separate graph + nodes. Inductor can fuse these with surrounding operations (e.g. fuse the + amax computation with the preceding layer's activation function). + - ours: compile sees a single opaque call. It can optimize everything around + the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary. + +Both call the exact same cuBLAS _scaled_mm kernel — the GPU matmul is identical. +The difference is only in the "glue" ops (amax, scale, cast) which are tiny +compared to the matmul. In practice this means our version is slightly faster +(less compilation overhead, no tensor subclass dispatch cost) but can produce +subtly different floating-point rounding paths under torch.compile, since Inductor +generates a different graph. Numerics are bitwise identical in eager mode. +""" + +import torch +import torch.nn as nn + +# Avoid division by zero when computing scale from an all-zeros tensor +EPS = 1e-12 + + +@torch.no_grad() +def _to_fp8(x, fp8_dtype): + """Dynamically quantize a tensor to FP8 using tensorwise scaling. + + "Tensorwise" means one scalar scale for the entire tensor (as opposed to + "rowwise" which computes a separate scale per row). Tensorwise is faster + because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel. + + Returns (fp8_data, inverse_scale) for use with torch._scaled_mm. + """ + fp8_max = torch.finfo(fp8_dtype).max + # Compute the max absolute value across the entire tensor + amax = x.float().abs().max() + # Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to + # ensure consistent numerics between torch.compile and eager mode. + # (torchao does the same upcast — without it, compile/eager can diverge) + scale = fp8_max / amax.double().clamp(min=EPS) + scale = scale.float() + # Quantize: scale into FP8 range, saturate (clamp prevents overflow when + # casting — PyTorch's default is to wrap, not saturate), then cast to FP8 + x_scaled = x.float() * scale + x_clamped = x_scaled.clamp(-fp8_max, fp8_max) + x_fp8 = x_clamped.to(fp8_dtype) + # _scaled_mm expects the *inverse* of our scale (it multiplies by this to + # convert FP8 values back to the original range during the matmul) + inv_scale = scale.reciprocal() + return x_fp8, inv_scale + + +def _to_col_major(x): + """Rearrange a 2D tensor's memory to column-major layout. + + torch._scaled_mm requires its second operand in column-major layout. + The trick: transpose -> contiguous (forces a copy in transposed order) + -> transpose back. The result has the same logical shape but column-major + strides, e.g. a [M, N] tensor gets strides (1, M) instead of (N, 1). + """ + return x.t().contiguous().t() + + +# allow_in_graph tells torch.compile to treat this as an opaque operation — +# dynamo won't try to decompose it into smaller ops. See the module docstring +# for how this differs from torchao's tensor subclass approach. +@torch._dynamo.allow_in_graph +class _Float8Matmul(torch.autograd.Function): + """Custom autograd for the three FP8 GEMMs of a Linear layer. + + The forward saves input and weight in their original precision for the + backward pass. Each GEMM independently re-quantizes its operands to FP8. + (We don't reuse the forward's FP8 tensors in backward — the backward might + want different precision, and saving FP8 would lose information.) + """ + + @staticmethod + def forward(ctx, input_2d, weight): + ctx.save_for_backward(input_2d, weight) + + # Quantize both operands to e4m3 (higher precision format) + input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn) + weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn) + + # output = input @ weight.T + # input_fp8 is [B, K] contiguous = row-major (good for first arg) + # weight_fp8 is [N, K] contiguous, so weight_fp8.t() is [K, N] with + # strides (1, K) = column-major (good for second arg, no copy needed!) + output = torch._scaled_mm( + input_fp8, + weight_fp8.t(), + scale_a=input_inv, + scale_b=weight_inv, + out_dtype=input_2d.dtype, + # use_fast_accum=True accumulates the dot products in lower precision. + # Slightly less accurate but measurably faster. Standard practice for + # the forward pass; we use False in backward for more precise gradients. + use_fast_accum=True, + ) + return output + + @staticmethod + def backward(ctx, grad_output): + input_2d, weight = ctx.saved_tensors + + # === GEMM 1: grad_input = grad_output @ weight === + # Shapes: [B, N] @ [N, K] -> [B, K] + # Gradients use e5m2 (wider range), weights use e4m3 (higher precision) + go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2) + w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn) + # go_fp8 is [B, N] contiguous = row-major, good for first arg + # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg + w_col = _to_col_major(w_fp8) + grad_input = torch._scaled_mm( + go_fp8, + w_col, + scale_a=go_inv, + scale_b=w_inv, + out_dtype=grad_output.dtype, + use_fast_accum=False, + ) + + # === GEMM 2: grad_weight = grad_output.T @ input === + # Shapes: [N, B] @ [B, K] -> [N, K] + go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) + in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn) + # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg. + # Transposing gives column-major, but first arg needs row-major, + # so we must call .contiguous() to physically rearrange the memory. + go_T = go_fp8_2.t().contiguous() # [N, B] row-major + in_col = _to_col_major(in_fp8) # [B, K] column-major + grad_weight = torch._scaled_mm( + go_T, + in_col, + scale_a=go_inv_2, + scale_b=in_inv, + out_dtype=grad_output.dtype, + use_fast_accum=False, + ) + + return grad_input, grad_weight + + +class Float8Linear(nn.Linear): + """Drop-in nn.Linear replacement that does FP8 compute. + + Weights and biases remain in their original precision (e.g. fp32/bf16). + Only the matmul is performed in FP8 via the _Float8Matmul autograd function. + """ + + def forward(self, input): + # Replicate the autocast behavior of F.linear — when autocast is active, + # we need to manually cast input to the autocast dtype (e.g. bf16), + # since we bypass F.linear's built-in autocast handling. + if torch.is_autocast_enabled(): + input = input.to(torch.get_autocast_gpu_dtype()) + # _scaled_mm only works on 2D tensors, so flatten batch dimensions + orig_shape = input.shape + input_2d = input.reshape(-1, orig_shape[-1]) + output = _Float8Matmul.apply(input_2d, self.weight) + output = output.reshape(*orig_shape[:-1], output.shape[-1]) + if self.bias is not None: + output = output + self.bias.to(output.dtype) + return output + + @classmethod + def from_float(cls, mod): + """Create Float8Linear from nn.Linear, sharing the same weight and bias. + + Uses meta device to avoid allocating a temporary weight tensor — we + create the module shell on meta (shapes/dtypes only, no memory), then + point .weight and .bias to the original module's parameters. + """ + with torch.device("meta"): + new_mod = cls(mod.in_features, mod.out_features, bias=False) + new_mod.weight = mod.weight + new_mod.bias = mod.bias + return new_mod + + +class Float8LinearConfig: + """Minimal config matching torchao's API. Only tensorwise recipe is supported.""" + + @staticmethod + def from_recipe_name(recipe_name): + if recipe_name != "tensorwise": + raise ValueError( + f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. " + f"Rowwise/axiswise recipes require the full torchao library." + ) + return Float8LinearConfig() + + +def convert_to_float8_training(module, *, config=None, module_filter_fn=None): + """Replace nn.Linear layers with Float8Linear throughout a module. + + Walks the module tree in post-order (children before parents) and swaps + each nn.Linear that passes the optional filter. The new Float8Linear shares + the original weight and bias tensors — no copies, no extra memory. + + Args: + module: Root module to convert. + config: Float8LinearConfig (accepted for API compat, only tensorwise supported). + module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears + are converted. Common use: skip layers with dims not divisible by 16 + (hardware requirement for FP8 matmuls on H100). + """ + def _convert(mod, prefix=""): + for name, child in mod.named_children(): + fqn = f"{prefix}.{name}" if prefix else name + _convert(child, fqn) + if isinstance(child, nn.Linear) and not isinstance(child, Float8Linear): + if module_filter_fn is None or module_filter_fn(child, fqn): + setattr(mod, name, Float8Linear.from_float(child)) + + _convert(module) + return module diff --git a/nanochat/gpt.py b/nanochat/gpt.py index fc3d422..8779a85 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -27,6 +27,12 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere from nanochat.flash_attention import flash_attn +# FP8 imports (optional) — needed by reparam_linear for FP8 path +try: + from torchao.float8.float8_linear import matmul_with_hp_or_float8_args +except ImportError: + pass + @dataclass class GPTConfig: sequence_len: int = 2048 @@ -46,6 +52,26 @@ def norm(x): return F.rms_norm(x, (x.size(-1),)) +def reparam_linear(module, x, gamma=None, scalar=None): + """Linear with gamma/scalar folded into weight. Works with both nn.Linear and Float8Linear. + + gamma: RMSNorm learnable weight, folded into input dim of W (w = w * gamma[None, :]) + scalar: projection scalar, folded into output dim of W (w = scalar[:, None] * w) + + For FP8, dispatches through Float8Linear's internal matmul to preserve FP8 tensor cores. + """ + w = module.weight + if gamma is not None: + w = w * gamma[None, :] + if scalar is not None: + w = scalar[:, None] * w + # FP8 path: use Float8Linear's internal matmul to preserve FP8 tensor cores + if hasattr(module, 'linear_mm_config'): + return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config) + # BF16 path + return F.linear(x, w) + + def has_ve(layer_idx, n_layer): """Returns True if GPT layer should have Value Embedding (alternating, last layer always included).""" return layer_idx % 2 == (n_layer - 1) % 2 @@ -74,21 +100,22 @@ class CausalSelfAttention(nn.Module): self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) self.ve_gate_channels = 32 self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None + self.v_proj_scalar = nn.Parameter(torch.zeros(self.n_kv_head)) if has_ve(layer_idx, config.n_layer) else None self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd)) def forward(self, x, ve, cos_sin, window_size, kv_cache): B, T, C = x.size() - # Project the input to get queries, keys, and values + # Project the input to get queries, keys, and values (gamma folded into weights) # Shape: (B, T, H, D) - FA3's native layout, no transpose needed! - q = self.c_q(x).view(B, T, self.n_head, self.head_dim) - k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) - v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) + q = reparam_linear(self.c_q, x).view(B, T, self.n_head, self.head_dim) + k = reparam_linear(self.c_k, x).view(B, T, self.n_kv_head, self.head_dim) + v = reparam_linear(self.c_v, x).view(B, T, self.n_kv_head, self.head_dim) # Value residual (ResFormer): mix in value embedding with input-dependent gate per head if ve is not None: ve = ve.view(B, T, self.n_kv_head, self.head_dim) - gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2) + gate = 2 * torch.sigmoid(reparam_linear(self.ve_gate, x[..., :self.ve_gate_channels], scalar=self.v_proj_scalar)) # (B, T, n_kv_head), range (0, 2) v = v + gate.unsqueeze(-1) * ve # Apply Rotary Embeddings to queries and keys to get relative positional encoding @@ -115,10 +142,9 @@ class CausalSelfAttention(nn.Module): if self.layer_idx == kv_cache.n_layers - 1: kv_cache.advance(T) - # Re-assemble the heads and project back to residual stream + # Re-assemble the heads and project back to residual stream (scalar folded into weight) y = y.contiguous().view(B, T, -1) - y = self.c_proj(y) - y = y * self.c_proj_scalar + y = reparam_linear(self.c_proj, y, scalar=self.c_proj_scalar) return y @@ -130,24 +156,21 @@ class MLP(nn.Module): self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd)) def forward(self, x): - x = self.c_fc(x) + x = reparam_linear(self.c_fc, x) x = F.relu(x).square() - x = self.c_proj(x) - x = x * self.c_proj_scalar + x = reparam_linear(self.c_proj, x, scalar=self.c_proj_scalar) return x class Block(nn.Module): def __init__(self, config, layer_idx): super().__init__() - self.attn_norm = nn.RMSNorm(config.n_embd) self.attn = CausalSelfAttention(config, layer_idx) - self.mlp_norm = nn.RMSNorm(config.n_embd) self.mlp = MLP(config) def forward(self, x, ve, cos_sin, window_size, kv_cache): - x = x + self.attn(self.attn_norm(x), ve, cos_sin, window_size, kv_cache) - x = x + self.mlp(self.mlp_norm(x)) + x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache) + x = x + self.mlp(norm(x)) return x @@ -207,18 +230,12 @@ class GPT(nn.Module): attn.c_proj: uniform, std=1/sqrt(n_embd) mlp.c_fc: uniform, std=1/sqrt(n_embd) mlp.c_proj: uniform, std=1/sqrt(n_embd) - nn.RMSNorm weight: ones (via explicit init below) """ # Embedding and unembedding torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0) torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001) - # nn.RMSNorm weight parameters: init to ones (must be explicit due to meta device) - for module in self.modules(): - if isinstance(module, nn.RMSNorm): - module.weight.fill_(1.0) - # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal) n_embd = self.config.n_embd s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal @@ -238,26 +255,23 @@ class GPT(nn.Module): for block in self.transformer.h: block.attn.c_proj_scalar.fill_(0.0) block.mlp.c_proj_scalar.fill_(0.0) + if block.attn.v_proj_scalar is not None: + block.attn.v_proj_scalar.fill_(0.0) if self.transformer.wte.weight.device.type == "cuda": block.attn.c_proj_scalar.data = block.attn.c_proj_scalar.data.to(torch.bfloat16) block.mlp.c_proj_scalar.data = block.mlp.c_proj_scalar.data.to(torch.bfloat16) - - # Block RMSNorm weights (cast to bf16 for fused kernel) - for block in self.transformer.h: - block.attn_norm.weight.fill_(1.0) - block.mlp_norm.weight.fill_(1.0) - if self.transformer.wte.weight.device.type == "cuda": - block.attn_norm.to(dtype=torch.bfloat16) - block.mlp_norm.to(dtype=torch.bfloat16) + if block.attn.v_proj_scalar is not None: + block.attn.v_proj_scalar.data = block.attn.v_proj_scalar.data.to(torch.bfloat16) # Value embeddings (init like c_v: uniform with same std) for ve in self.value_embeds.values(): torch.nn.init.uniform_(ve.weight, -s, s) - # Gate weights init to uniform (avoid zero-norm params under Hyperball) + # Gate weights init to uniform (avoid zero-norm params under Hyperball, following mup) + s_ve_gate = 3 ** 0.5 * 32**-0.5 for block in self.transformer.h: if block.attn.ve_gate is not None: - torch.nn.init.uniform_(block.attn.ve_gate.weight, -s, s) + torch.nn.init.uniform_(block.attn.ve_gate.weight, -s_ve_gate, s_ve_gate) # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head @@ -378,7 +392,7 @@ class GPT(nn.Module): 'total': total, } - def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5, matrix_optimizer="muon"): + def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5, norm_lr=0.1, matrix_optimizer="muon"): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() @@ -406,7 +420,7 @@ class GPT(nn.Module): dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0), dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 - dict(kind='adamw', params=block_1d_params, lr=scalar_lr, betas=adam_betas, eps=1e-10, weight_decay=0.0), + dict(kind='adamw', params=block_1d_params, lr=norm_lr, betas=adam_betas, eps=1e-10, weight_decay=0.0), ] # Matrix params (Muon or Hyperball), grouped by shape for stacking if matrix_optimizer not in {"muon", "hyperball"}: diff --git a/nanochat/optim.py b/nanochat/optim.py index f93ba38..a44534a 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -67,6 +67,10 @@ Polar Express Sign Method for orthogonalization. https://arxiv.org/pdf/2505.16932 by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. +NorMuon variance reduction: per-neuron/column adaptive learning rate that normalizes +update scales after orthogonalization (Muon's output has non-uniform scales across neurons). +https://arxiv.org/pdf/2510.05491 + Some of the changes in nanochat implementation: - Uses a simpler, more general approach to parameter grouping and stacking - Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step @@ -517,6 +521,7 @@ class DistMuonAdamW(torch.optim.Optimizer): param_infos[p] = dict(future=future, grad_slice=grad, is_small=True) else: # Large params: reduce_scatter + assert grad.shape[0] % world_size == 0, f"AdamW reduce_scatter requires shape[0] ({grad.shape[0]}) divisible by world_size ({world_size})" rank_size = grad.shape[0] // world_size grad_slice = torch.empty_like(grad[:rank_size]) future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future() diff --git a/pyproject.toml b/pyproject.toml index bcb674d..8b6fd95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ dependencies = [ "tiktoken>=0.11.0", "tokenizers>=0.22.0", "torch==2.9.1", - "torchao==0.15.0", "transformers>=4.57.3", "uvicorn>=0.36.0", "wandb>=0.21.3", diff --git a/runs/miniseries.sh b/runs/miniseries.sh index c42544e..01c4459 100644 --- a/runs/miniseries.sh +++ b/runs/miniseries.sh @@ -28,7 +28,7 @@ fi # Series name: from arg, env var, or default to today's date (e.g., jan11) SERIES_NAME="${1:-${SERIES_NAME:-$(date +%b%d | tr '[:upper:]' '[:lower:]')}}" # Depths to train (the "miniseries") -DEPTHS=(10 11 12 13 14 15 16 17 18 19 20) +DEPTHS=(12 14 16 18 20 22 24 26) # Hardware NPROC_PER_NODE="${NPROC_PER_NODE:-8}" # Logging @@ -57,8 +57,15 @@ for d in "${DEPTHS[@]}"; do TAG="${SERIES_NAME}_miniseries_d${d}" START_TIME=$(date +%s) - # Train the model with natural horizon (target_param_data_ratio default) - # No --target-flops, let it use the default ratio from base_train + # Reduce --device-batch-size to avoid OOM at larger depths + if [ $d -ge 28 ]; then + DEVICE_BATCH_SIZE_ARG="--device-batch-size=8" + elif [ $d -ge 20 ]; then + DEVICE_BATCH_SIZE_ARG="--device-batch-size=16" + else + DEVICE_BATCH_SIZE_ARG="--device-batch-size=32" + fi + torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ --depth=$d \ --run="${WANDB_RUN}_d${d}" \ @@ -67,6 +74,7 @@ for d in "${DEPTHS[@]}"; do --core-metric-max-per-task=-1 \ --sample-every=-1 \ --save-every=-1 \ + $DEVICE_BATCH_SIZE_ARG \ 2>&1 | tee "$RESULTS_DIR/${TAG}_train.log" END_TIME=$(date +%s) diff --git a/runs/quickrun_muonh.sh b/runs/quickrun_muonh.sh index 5aa303f..5abf675 100755 --- a/runs/quickrun_muonh.sh +++ b/runs/quickrun_muonh.sh @@ -16,12 +16,12 @@ set -e # ----------------------------------------------------------------------------- # Config -DEPTH="${DEPTH:-24}" +DEPTH="${DEPTH:-26}" NUM_SHARDS="${NUM_SHARDS:-370}" # default for d24 @ ratio~11 -TARGET_RATIO="${TARGET_RATIO:-11}" +TARGET_RATIO="${TARGET_RATIO:-10.5}" WINDOW_PATTERN="${WINDOW_PATTERN:-SSSL}" DEVICE_BATCH_SIZE="${DEVICE_BATCH_SIZE:-16}" -TOTAL_BATCH_SIZE="${TOTAL_BATCH_SIZE:-524288}" +TOTAL_BATCH_SIZE="${TOTAL_BATCH_SIZE:-524288}" # -1 = auto-compute optimal (Power Lines paper) NPROC_PER_NODE="${NPROC_PER_NODE:-$(nvidia-smi -L 2>/dev/null | wc -l || echo 1)}" if [ "$NPROC_PER_NODE" -eq 0 ]; then @@ -38,13 +38,15 @@ MATRIX_WARMDOWN_RATIO="${MATRIX_WARMDOWN_RATIO:-1.0}" # AdamW EMBEDDING_LR="${EMBEDDING_LR:-0.3}" UNEMBEDDING_LR="${UNEMBEDDING_LR:-0.004}" +NORM_LR="${NORM_LR:-0.1}" # Wandb -WANDB_PROJECT="nanochat" -WANDB_RUN="${WANDB_RUN:-muonh_d${DEPTH}_ratio${TARGET_RATIO}}" +export WANDB_ENTITY="${WANDB_ENTITY:-xingyu20}" +export WANDB_PROJECT="${WANDB_PROJECT:-nanochat}" +WANDB_RUN="${WANDB_RUN:-muonh_d${DEPTH}_ratio${TARGET_RATIO}_feb_11_no_gamma}" MODEL_TAG="${MODEL_TAG:-d${DEPTH}_gamma_muonh}" -# FP8 (default enabled) +# FP8 (default enabled)c FP8="${FP8:-1}" FP8_ARGS="" if [ "${FP8:-0}" -eq 1 ]; then @@ -81,6 +83,7 @@ echo "Device batch size: $DEVICE_BATCH_SIZE" echo "Total batch size: $TOTAL_BATCH_SIZE" echo "Matrix optimizer: $MATRIX_OPTIMIZER" echo "Matrix LR: $MATRIX_LR" +echo "Norm LR: $NORM_LR" echo "Adam LRs: embedding=$EMBEDDING_LR, unembedding=$UNEMBEDDING_LR, scalar=$SCALAR_LR" echo "Warmdown ratio: adam=$WARMDOWN_RATIO, matrix=$MATRIX_WARMDOWN_RATIO" echo "Wandb run: $WANDB_RUN" @@ -111,8 +114,13 @@ echo "Downloading $NUM_SHARDS data shards..." python -m nanochat.dataset -n "$NUM_SHARDS" echo "" -echo "Checking tokenizer..." -python -m scripts.tok_train --max-chars=500000000 --vocab-size=32768 +TOKENIZER_DIR="$NANOCHAT_BASE_DIR/tokenizer" +if [ -f "$TOKENIZER_DIR/token_bytes.pt" ]; then + echo "Tokenizer already exists at $TOKENIZER_DIR, skipping training." +else + echo "Training tokenizer..." + python -m scripts.tok_train --max-chars=500000000 --vocab-size=32768 +fi # ----------------------------------------------------------------------------- # Train @@ -134,10 +142,11 @@ TRAIN_ARGS=( --matrix-warmdown-ratio=$MATRIX_WARMDOWN_RATIO --embedding-lr=$EMBEDDING_LR --unembedding-lr=$UNEMBEDDING_LR + --norm-lr=$NORM_LR --scalar-lr=$SCALAR_LR - --core-metric-every=2000 - --sample-every=-1 - --save-every=-1 + --core-metric-every=${CORE_METRIC_EVERY:-2000} + --sample-every=${SAMPLE_EVERY:--1} + --save-every=${SAVE_EVERY:--1} ) if [ "$NPROC_PER_NODE" -gt 1 ]; then @@ -152,4 +161,4 @@ echo "" echo "==============================================" echo "Training complete!" echo "==============================================" -echo "Checkpoint saved to: $NANOCHAT_BASE_DIR/base_checkpoints/${MODEL_TAG}/" +echo "Checkpoint saved to: $NANOCHAT_BASE_DIR/base_checkpoints/${MODEL_TAG}/" \ No newline at end of file diff --git a/runs/scaling_laws_muonh.sh b/runs/scaling_laws_muonh.sh new file mode 100755 index 0000000..a5af017 --- /dev/null +++ b/runs/scaling_laws_muonh.sh @@ -0,0 +1,220 @@ +#!/bin/bash + +# Scaling Laws Sweep for GPT-Gamma + MuonH (Hyperball) +# Runs IsoFLOP analysis: for each compute budget, sweep model depths to find optimal size. +# Results saved to CSV for analysis with dev/scaling_analysis.ipynb +# +# Usage: +# bash runs/scaling_laws_muonh.sh +# LABEL=feb06 bash runs/scaling_laws_muonh.sh +# FP8=0 bash runs/scaling_laws_muonh.sh + +set -e + +LABEL="${LABEL:-muonh_$(date +%b%d | tr '[:upper:]' '[:lower:]')}" + +FLOPS_BUDGETS=( + 1e18 + 2.15e18 + 4.64e18 + 1e19 +) +DEPTHS=(8 10 12 14 16 18 20) + +NPROC_PER_NODE="${NPROC_PER_NODE:-$(nvidia-smi -L 2>/dev/null | wc -l || echo 1)}" +if [ "$NPROC_PER_NODE" -eq 0 ]; then + NPROC_PER_NODE=1 +fi + +# Fixed batch size (auto batch size requires target-param-data-ratio, not compatible with target-flops) +TOTAL_BATCH_SIZE="${TOTAL_BATCH_SIZE:-524288}" +DEVICE_BATCH_SIZE="${DEVICE_BATCH_SIZE:-16}" +EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval + +# Optimizer (MuonH defaults) +MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}" +MATRIX_LR="${MATRIX_LR:-0.02}" +EMBEDDING_LR="${EMBEDDING_LR:-0.3}" +UNEMBEDDING_LR="${UNEMBEDDING_LR:-0.004}" +SCALAR_LR="${SCALAR_LR:-0.5}" +NORM_LR="${NORM_LR:-0.2}" +WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}" +MATRIX_WARMDOWN_RATIO="${MATRIX_WARMDOWN_RATIO:-1.0}" +WINDOW_PATTERN="${WINDOW_PATTERN:-SSSL}" + +# FP8 (default enabled) +FP8="${FP8:-1}" +FP8_ARGS="" +if [ "${FP8}" -eq 1 ]; then + FP8_RECIPE="${FP8_RECIPE:-tensorwise}" + FP8_ARGS="--fp8 --fp8-recipe=${FP8_RECIPE}" +fi + +# Wandb +export WANDB_PROJECT="${WANDB_PROJECT:-nanochat-scaling}" +WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}" + +# Paths and cache +export OMP_NUM_THREADS=1 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" +export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$PROJECT_ROOT/cache}" +export TORCHINDUCTOR_CACHE_DIR="$NANOCHAT_BASE_DIR/torch_inductor" +export TRITON_CACHE_DIR="$NANOCHAT_BASE_DIR/triton" +export TMPDIR="$NANOCHAT_BASE_DIR/tmp" +mkdir -p "$NANOCHAT_BASE_DIR" "$TORCHINDUCTOR_CACHE_DIR" "$TRITON_CACHE_DIR" "$TMPDIR" + +cd "$PROJECT_ROOT" + +# Python venv +if [ ! -d ".venv" ]; then + echo "Setting up Python environment..." + command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh + uv venv + uv sync --extra gpu +fi +source .venv/bin/activate + +RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}" +mkdir -p "$RESULTS_DIR" +RESULTS_FILE="$RESULTS_DIR/results.csv" + +# Write CSV header only if file doesn't exist +if [ ! -f "$RESULTS_FILE" ]; then + echo "flops_budget,depth,model_dim,params_wte,params_value_embeds,params_lm_head,params_transformer,params_norm_and_proj_scalars,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE" +fi + +log() { + echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1" +} + +# Check if a run already exists in results +run_exists() { + local flops=$1 + local depth=$2 + grep -q "^${flops},${depth}," "$RESULTS_FILE" 2>/dev/null +} + +# ============================================================================= +# Print summary +# ============================================================================= + +log "==============================================" +log "Scaling Laws Sweep (GPT-Gamma + MuonH)" +log "==============================================" +log "Label: $LABEL" +log "FLOPs budgets: ${FLOPS_BUDGETS[*]}" +log "Depths: ${DEPTHS[*]}" +log "Num GPUs: $NPROC_PER_NODE" +log "Total batch size: $TOTAL_BATCH_SIZE" +log "Matrix optimizer: $MATRIX_OPTIMIZER" +log "Matrix LR: $MATRIX_LR" +log "Norm LR: $NORM_LR" +log "Warmdown ratio: adam=$WARMDOWN_RATIO, matrix=$MATRIX_WARMDOWN_RATIO" +if [ "${FP8}" -eq 1 ]; then + log "FP8: enabled ($FP8_RECIPE)" +fi +log "Results dir: $RESULTS_DIR" +log "==============================================" + +# ============================================================================= +# Main Loop +# ============================================================================= + +for flops in "${FLOPS_BUDGETS[@]}"; do + log "==============================================" + log "Compute budget: $flops FLOPs" + log "==============================================" + + for d in "${DEPTHS[@]}"; do + + # Skip if already completed + if run_exists "$flops" "$d"; then + log "Skipping d=$d at $flops FLOPs (already in results)" + continue + fi + + log "Training d=$d at $flops FLOPs..." + + # Unique tag for this run + TAG="scaling_${LABEL}_${flops}_d${d}" + + # Record start time + START_TIME=$(date +%s) + + # Train the model with fixed flops budget + TRAIN_ARGS=( + --depth=$d + --target-flops=$flops + --target-param-data-ratio=-1 + --total-batch-size=$TOTAL_BATCH_SIZE + --device-batch-size=$DEVICE_BATCH_SIZE + --run="${WANDB_RUN}_${TAG}" + --model-tag="${TAG}" + --window-pattern=$WINDOW_PATTERN + --matrix-optimizer=$MATRIX_OPTIMIZER + --matrix-lr=$MATRIX_LR + --embedding-lr=$EMBEDDING_LR + --unembedding-lr=$UNEMBEDDING_LR + --scalar-lr=$SCALAR_LR + --norm-lr=$NORM_LR + --warmdown-ratio=$WARMDOWN_RATIO + --matrix-warmdown-ratio=$MATRIX_WARMDOWN_RATIO + --eval-tokens=$EVAL_TOKENS + --core-metric-every=999999 + --core-metric-max-per-task=-1 + --sample-every=-1 + --save-every=-1 + ) + + if [ "$NPROC_PER_NODE" -gt 1 ]; then + torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ + "${TRAIN_ARGS[@]}" $FP8_ARGS \ + 2>&1 | tee "$RESULTS_DIR/${TAG}_train.log" + else + python -m scripts.base_train \ + "${TRAIN_ARGS[@]}" $FP8_ARGS \ + 2>&1 | tee "$RESULTS_DIR/${TAG}_train.log" + fi + + END_TIME=$(date +%s) + TRAIN_TIME=$((END_TIME - START_TIME)) + + # Extract training stats from the log + LOG_FILE="$RESULTS_DIR/${TAG}_train.log" + + # Extract detailed parameter counts (handle whitespace-padded format) + PARAMS_WTE=$(grep "wte" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_VE=$(grep "value_embeds" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_LM=$(grep "lm_head" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_TRANSFORMER=$(grep "transformer_matrices" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_NORM=$(grep "norm_and_proj_scalars" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_SCALARS=$(grep -w "scalars" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_TOTAL=$(grep -w "total" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + + NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',') + TOKENS_TRAINED=$((NUM_ITERS * TOTAL_BATCH_SIZE)) + MODEL_DIM=$((d * 64)) + VAL_BPB=$(grep "Validation bpb:" "$LOG_FILE" | tail -1 | grep -oP '[\d.]+$') + + # Extract CORE score from training log (evaluated on final step) + CORE_SCORE=$(grep "CORE metric:" "$LOG_FILE" | tail -1 | awk '{print $NF}') + if [ -z "$CORE_SCORE" ]; then + log "WARNING: Could not extract CORE score for d=$d" + CORE_SCORE="0.0" + fi + + log " Params: $PARAMS_TOTAL (transformer: $PARAMS_TRANSFORMER), Iters: $NUM_ITERS, Val BPB: $VAL_BPB, CORE: $CORE_SCORE" + + # Append to CSV + echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_NORM,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE" + done +done + +log "==============================================" +log "Scaling Laws Sweep Complete" +log "==============================================" +log "Results saved to: $RESULTS_FILE" +echo "" +echo "Results:" +column -t -s',' "$RESULTS_FILE" diff --git a/runs/speedrun.sh b/runs/speedrun.sh index d390c6d..62466c7 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -70,7 +70,7 @@ echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID # d24 model (slightly overtrained is enough to beat GPT-2 => increase data:params ratio from compute optimal 10.5 (default) to 12) -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=12 --device-batch-size=16 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --target-param-data-ratio=8.25 --device-batch-size=16 --fp8 --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 diff --git a/scripts/base_train.py b/scripts/base_train.py index f8920ea..f41c3aa 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -11,11 +11,14 @@ If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Ex python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 """ -import gc import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" -import argparse +import gc +import json import time +import math +import argparse +from dataclasses import asdict from contextlib import nullcontext, contextmanager import wandb @@ -53,18 +56,20 @@ parser.add_argument("--num-iterations", type=int, default=-1, help="explicit num parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization -parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") -parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") +parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.") +parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)") parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)") parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon/Hyperball)") parser.add_argument("--matrix-optimizer", type=str, default="muon", choices=["muon", "hyperball"], help="optimizer for matrix parameters") parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") +parser.add_argument("--norm-lr", type=float, default=0.1, help="learning rate for norm/gamma parameters") parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding") parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding") parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for AdamW LR warmdown") +parser.add_argument("--matrix-warmup-ratio", type=float, default=0.0, help="ratio of iterations for Muon/Hyperball LR warmup") parser.add_argument("--matrix-warmdown-ratio", type=float, default=1.0, help="ratio of iterations for Muon/Hyperball LR warmdown") parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") @@ -80,10 +85,8 @@ parser.add_argument("--model-tag", type=str, default=None, help="override model args = parser.parse_args() user_config = vars(args).copy() # for logging # ----------------------------------------------------------------------------- +# Compute init and wandb logging - - -# Compute init device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. @@ -99,7 +102,8 @@ else: # wandb logging init use_dummy_wandb = args.run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config) +wandb_project = os.environ.get("WANDB_PROJECT", "nanochat") +wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project=wandb_project, name=args.run, config=user_config) # Flash Attention status if HAS_FA3: @@ -113,65 +117,39 @@ else: print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.") print0("!" * 80) -# Tokenizer will be useful for evaluation, also we need the vocab size +# ----------------------------------------------------------------------------- +# Tokenizer will be useful for evaluation and also we need the vocab size to init the model tokenizer = get_tokenizer() token_bytes = get_token_bytes(device=device) vocab_size = tokenizer.get_vocab_size() print0(f"Vocab size: {vocab_size:,}") -# Model kwargs are derived from the desired depth of the model -# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division -# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly) -# (For very small depths, this gives a slight "unfair" advantage to models with odd depths) -num_layers = args.depth -base_dim = args.depth * args.aspect_ratio -model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim -num_heads = model_dim // args.head_dim -num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled) -head_dim = model_dim // num_heads -print0(f"num_layers: {num_layers}") -print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})") -print0(f"num_heads: {num_heads}") -print0(f"head_dim: {head_dim}") -print0(f"num_kv_heads: {num_kv_heads}") - -# Optimizer / data / training length related hyperparameters -# figure out the needed gradient accumulation to reach the desired total batch size -tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank -world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert args.total_batch_size % world_tokens_per_fwdbwd == 0 -grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd -print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") -print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") -print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") - -# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19) -batch_lr_scale = 1.0 -reference_batch_size = 2**19 -batch_ratio = args.total_batch_size / reference_batch_size -if batch_ratio != 1.0: - # SGD: linear scaling with batch size is standard (not used in nanochat) - # AdamW: sqrt scaling is standard - # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer - batch_lr_scale = batch_ratio ** 0.5 - print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})") - -# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio) -weight_decay_scaled = args.weight_decay * (12 / args.depth)**2 -if args.depth != 12: - print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}") - # ----------------------------------------------------------------------------- # Initialize the Model -# Create a new model with random weights -model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern) -with torch.device("meta"): - # All tensors are created as meta tensors (they have shape/dtype but no data) - model_config = GPTConfig(**model_config_kwargs) - model = GPT(model_config) -model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data -model.init_weights() # All tensors get initialized +def build_model_meta(depth): + """Build a model on meta device for a given depth (shapes/dtypes only, no data).""" + # Model dim is nudged up to nearest multiple of head_dim for clean division + # (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly) + base_dim = depth * args.aspect_ratio + model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim + num_heads = model_dim // args.head_dim + config = GPTConfig( + sequence_len=args.max_seq_len, vocab_size=vocab_size, + n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, + window_pattern=args.window_pattern, + ) + with torch.device("meta"): + model_meta = GPT(config) + return model_meta + +# Build the model, move to device, init the weights +model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data) +model_config = model.config +model_config_kwargs = asdict(model_config) +print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}") +model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data +model.init_weights() # 3) All tensors get initialized # If we are resuming, overwrite the model parameters with those of the checkpoint base_dir = get_base_dir() @@ -185,48 +163,16 @@ if resuming: del model_data # free up this memory after the copy # ----------------------------------------------------------------------------- -# Determine the length of the training run based on model size - -# Detailed parameter counts -param_counts = model.num_scaling_params() -print0(f"Parameter counts:") -for key, value in param_counts.items(): - print0(f"{key:24s}: {value:,}") -num_params = param_counts['total'] -num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026 -num_flops_per_token = model.estimate_flops() -print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") - -# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order) -assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0 -if args.num_iterations > 0: - num_iterations = args.num_iterations - print0(f"Using user-provided number of iterations: {num_iterations:,}") -elif args.target_flops > 0: - # calculate the number of iterations from the target flops - num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size)) - print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") -elif args.target_param_data_ratio > 0: - # calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.) - target_tokens = int(args.target_param_data_ratio * num_scaling_params) - num_iterations = target_tokens // args.total_batch_size - print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") -else: - raise ValueError("No training horizon specified") -total_tokens = args.total_batch_size * num_iterations -print0(f"Total number of training tokens: {total_tokens:,}") -print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 -print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") - -# ----------------------------------------------------------------------------- -# FP8 training initialization and management (has to be done before torch.compile) +# FP8 training initialization and management (this has to be done before torch.compile) # Convert Linear layers to Float8Linear if --fp8 is set if args.fp8: if device_type != "cuda": print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag") else: - from torchao.float8 import Float8LinearConfig, convert_to_float8_training + # our custom fp8 is simpler than torchao, written for exact API compatibility + from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training + # from torchao.float8 import Float8LinearConfig, convert_to_float8_training import torch.nn as nn # Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement) @@ -297,25 +243,83 @@ def disable_fp8(model): orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +# ----------------------------------------------------------------------------- +# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay. + +# Get the parameter counts of our model +param_counts = model.num_scaling_params() +print0(f"Parameter counts:") +for key, value in param_counts.items(): + print0(f"{key:24s}: {value:,}") +num_params = param_counts['total'] +num_flops_per_token = model.estimate_flops() +print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") + +# 1) Use scaling laws to determine the optimal training horizon in tokens +# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis). +# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params +def get_scaling_params(m): + # As for which params to use exactly, transformer matrices + lm_head gives cleanest scaling laws (see dev/LOG.md Jan 27, 2026) + params_counts = m.num_scaling_params() + scaling_params = params_counts['transformer_matrices'] + params_counts['lm_head'] + return scaling_params +num_scaling_params = get_scaling_params(model) +target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal tokens for the model we are about to train + +# Our reference model is d12, this is where a lot of hyperparameters are tuned and then transfered to higher depths (muP style) +d12_ref = build_model_meta(12) # creates the model on meta device +D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically) Shall we use a constant ratio for computing D_REF? +B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically) + +# 2) Now that we have the token horizon, we can calculate the optimal batch size +# We follow the Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738 +# The optimal batch size grows as approximately D^0.383, so e.g. if D doubles from d12 to d24, B should grow by 2^0.383 ≈ 1.3x. +total_batch_size = args.total_batch_size # user-provided override is possible +if total_batch_size == -1: + batch_size_ratio = target_tokens / D_REF + predicted_batch_size = B_REF * batch_size_ratio ** 0.383 + total_batch_size = 2 ** round(math.log2(predicted_batch_size)) # clamp to nearest power of 2 for efficiency + print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens") + +# 3) Knowing the batch size, we can now calculate a learning rate correction (bigger batch size allows higher learning rates) +batch_lr_scale = 1.0 +batch_ratio = total_batch_size / B_REF # B/B_ref +if batch_ratio != 1.0: + # SGD: linear scaling with batch size is standard (not used in nanochat) + # AdamW: sqrt scaling is standard: η ∝ √(B/B_ref) + # Muon: we will use the same scaling for Muon as for AdamW: η ∝ √(B/B_ref) (not studied carefully, assumption!) + batch_lr_scale = batch_ratio ** 0.5 # η ∝ √(B/B_ref) + print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {B_REF:,})") + +# 4) Knowing the batch size and the token horizon, we can now calculate the appropriate weight decay scaling +# We adopt the T_epoch framework from https://arxiv.org/abs/2405.13698 +# Central idea of the paper is that T_epoch = B/(η·λ·D) should remain constant. +# Above, we used learning rate scaling η ∝ √(B/B_ref). So it's a matter of ~10 lines of math to derive that to keep T_epoch constant, we need: +# λ = λ_ref · √(B/B_ref) · (D_ref/D) +# Note that these papers study AdamW, *not* Muon. We are blindly following AdamW theory for scaling hoping it ~works for Muon too. +weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens) +if weight_decay_scaled != args.weight_decay: + print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}") + # ----------------------------------------------------------------------------- # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) -adam_betas = (args.adam_beta1, args.adam_beta2) matrix_lr_scaled = args.matrix_lr * batch_lr_scale -# LR depth scaling for Hyperball +# LR data scaling for Hyperball +# We keep the same D_REF here if args.matrix_optimizer == "hyperball": - hyperball_depth_scale = 12 / args.depth - matrix_lr_scaled = matrix_lr_scaled * hyperball_depth_scale - if args.depth != 12: - print0(f"Scaling hyperball LR from {args.matrix_lr * batch_lr_scale:.6f} to {matrix_lr_scaled:.6f} for depth {args.depth}") + D_REF_LR = 10.5 * get_scaling_params(d12_ref) + matrix_lr_scaled = matrix_lr_scaled * (D_REF_LR / target_tokens) ** 0.35 # 0.35 is the exponent for the power law fit by ourselves + print0(f"Scaling hyperball LR from {args.matrix_lr * batch_lr_scale:.6f} to {matrix_lr_scaled:.6f} for token ratio {target_tokens / D_REF:.2f} (T_train = {target_tokens:,} tokens)") optimizer = model.setup_optimizer( unembedding_lr=args.unembedding_lr * batch_lr_scale, embedding_lr=args.embedding_lr * batch_lr_scale, matrix_lr=matrix_lr_scaled, weight_decay=weight_decay_scaled, - adam_betas=adam_betas, + adam_betas=(args.adam_beta1, args.adam_beta2), scalar_lr=args.scalar_lr * batch_lr_scale, + norm_lr=args.norm_lr * batch_lr_scale, matrix_optimizer=args.matrix_optimizer, ) @@ -331,9 +335,30 @@ build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokeni x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data # ----------------------------------------------------------------------------- -# Set up hyperparameter schedulers +# Calculate the number of iterations we will train for and set up the various schedulers -# Learning rate scheduler (warmup + warmdown) +# num_iterations: either it is given, or from target flops, or from target data:param ratio (in that order) +assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0 +if args.num_iterations > 0: + # Override num_iterations to a specific value if given + num_iterations = args.num_iterations + print0(f"Using user-provided number of iterations: {num_iterations:,}") +elif args.target_flops > 0: + # Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh) + num_iterations = round(args.target_flops / (num_flops_per_token * total_batch_size)) + print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") +elif args.target_param_data_ratio > 0: + # Calculate the number of iterations from the target param data ratio (the most common use case) + num_iterations = target_tokens // total_batch_size + print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") +else: + raise ValueError("No training horizon specified") +total_tokens = total_batch_size * num_iterations # the actual number of tokens we will train for +print0(f"Total number of training tokens: {total_tokens:,}") +print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20 +print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") + +# Learning rate scheduler (warmup + warmdown, parameterized for separate adam/matrix schedules) def get_lr_multiplier(it, warmup_ratio, warmdown_ratio, final_lr_frac): warmup_iters = round(warmup_ratio * num_iterations) warmdown_iters = round(warmdown_ratio * num_iterations) @@ -352,13 +377,14 @@ def get_muon_momentum(it): momentum = (1 - frac) * 0.85 + frac * 0.95 return momentum -# Weight decay scheduler for Muon optimizer (linear to zero over the course of training) +# Weight decay scheduler for Muon optimizer (linearly decays to zero over the course of training) def get_weight_decay(it): return weight_decay_scaled * (1 - it / num_iterations) # ----------------------------------------------------------------------------- -# Loop state (variables updated by the training loop) +# Training loop +# Loop state (variables updated by the training loop) if not resuming: step = 0 val_bpb = None # will be set if eval_every > 0 @@ -373,11 +399,19 @@ else: smooth_train_loss = loop_state["smooth_train_loss"] total_training_time = loop_state["total_training_time"] -# ----------------------------------------------------------------------------- -# Training loop +# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step +tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank +world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks +assert total_batch_size % world_tokens_per_fwdbwd == 0 +grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd +print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") +print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") +print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") + +# Go! while True: last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end - flops_so_far = num_flops_per_token * args.total_batch_size * step + flops_so_far = num_flops_per_token * total_batch_size * step # once in a while: evaluate the val bpb (all ranks participate) if args.eval_every > 0 and (last_step or step % args.eval_every == 0): @@ -477,7 +511,7 @@ while True: x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward # step the optimizer lrm_adam = get_lr_multiplier(step, args.warmup_ratio, args.warmdown_ratio, args.final_lr_frac) - lrm_matrix = get_lr_multiplier(step, 0.0, args.matrix_warmdown_ratio, args.final_lr_frac) + lrm_matrix = get_lr_multiplier(step, args.matrix_warmup_ratio, args.matrix_warmdown_ratio, args.final_lr_frac) muon_momentum = get_muon_momentum(step) muon_weight_decay = get_weight_decay(step) for group in optimizer.param_groups: @@ -502,8 +536,8 @@ while True: smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA pct_done = 100 * step / num_iterations - tok_per_sec = int(args.total_batch_size / dt) - flops_per_sec = num_flops_per_token * args.total_batch_size / dt + tok_per_sec = int(total_batch_size / dt) + flops_per_sec = num_flops_per_token * total_batch_size / dt mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) if step > 10: total_training_time += dt # only count the time after the first 10 steps @@ -517,7 +551,7 @@ while True: else: eta_str = "" epoch = dataloader_state_dict["epoch"] - print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm(adam)={lrm_adam:.2f}, lrm(matrix)={lrm_matrix:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm(adam)={lrm_adam:.2f}, lrm(matrix)={lrm_matrix:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") if step % 100 == 0: log_data = { "step": step, @@ -562,10 +596,11 @@ get_report().log(section="Base model training", data=[ "Number of FLOPs per token": f"{num_flops_per_token:e}", "Calculated number of iterations": num_iterations, "Number of training tokens": total_tokens, - "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params, + "Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params, "DDP world size": ddp_world_size, "warmup_ratio": args.warmup_ratio, "warmdown_ratio": args.warmdown_ratio, + "matrix_warmup_ratio": args.matrix_warmup_ratio, "matrix_warmdown_ratio": args.matrix_warmdown_ratio, "final_lr_frac": args.final_lr_frac, }, diff --git a/uv.lock b/uv.lock index e5fc97f..bbc9519 100644 --- a/uv.lock +++ b/uv.lock @@ -1509,7 +1509,6 @@ dependencies = [ { name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, { name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "torchao" }, { name = "transformers" }, { name = "uvicorn" }, { name = "wandb" }, @@ -1549,7 +1548,6 @@ requires-dist = [ { name = "torch", specifier = "==2.9.1" }, { name = "torch", marker = "extra == 'cpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } }, { name = "torch", marker = "extra == 'gpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } }, - { name = "torchao", specifier = "==0.15.0" }, { name = "transformers", specifier = ">=4.57.3" }, { name = "uvicorn", specifier = ">=0.36.0" }, { name = "wandb", specifier = ">=0.21.3" }, @@ -3184,15 +3182,6 @@ wheels = [ { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:0c784b600959ec70ee01cb23e8bc870a0e0475af30378ff5e39f4abed8b7c1cc" }, ] -[[package]] -name = "torchao" -version = "0.15.0" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/57/2d/472b9362dceae05a4599e2b94f86e69a29c0e20964a6af84f34f6ead5938/torchao-0.15.0-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1cbe813201314ba6329a650a76944502f3e8ec4b1b44523f3f48676810d8d1f6", size = 7163930, upload-time = "2025-12-18T23:14:41.876Z" }, - { url = "https://files.pythonhosted.org/packages/f6/3b/6b9d5618720f63dbc2e2509cd6b57aae9c0d61b738d1d2172f4d5d9efaab/torchao-0.15.0-py3-none-any.whl", hash = "sha256:3f3812676048ef8a2a0e9d492d12d8971ba7a7ebb16f54aa56f690414e130d2c", size = 1080679, upload-time = "2025-12-18T23:14:43.807Z" }, -] - [[package]] name = "tornado" version = "6.5.4" From 31e5bec402a7f074a24a602cffc62a1954ed370a Mon Sep 17 00:00:00 2001 From: Kaiyue Wen Date: Thu, 12 Feb 2026 16:25:52 -0800 Subject: [PATCH 09/14] Replace torchao with custom fp8 module in gpt.py - Update reparam_linear to use nanochat.fp8.Float8Linear instead of torchao - Replace matmul_with_hp_or_float8_args with direct _Float8Matmul.apply call - Remove torchao dependency mention from base_train.py help text - Functionally equivalent: both use torch._scaled_mm, custom version ~3% faster Co-Authored-By: Claude Sonnet 4.5 (1M context) --- nanochat/gpt.py | 19 +++++++++++++++---- scripts/base_train.py | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 8779a85..1e56c10 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -29,9 +29,10 @@ from nanochat.flash_attention import flash_attn # FP8 imports (optional) — needed by reparam_linear for FP8 path try: - from torchao.float8.float8_linear import matmul_with_hp_or_float8_args + from nanochat.fp8 import Float8Linear, _Float8Matmul except ImportError: - pass + Float8Linear = None + _Float8Matmul = None @dataclass class GPTConfig: @@ -66,8 +67,18 @@ def reparam_linear(module, x, gamma=None, scalar=None): if scalar is not None: w = scalar[:, None] * w # FP8 path: use Float8Linear's internal matmul to preserve FP8 tensor cores - if hasattr(module, 'linear_mm_config'): - return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config) + if Float8Linear is not None and isinstance(module, Float8Linear): + # Handle autocast similar to Float8Linear.forward + if torch.is_autocast_enabled(): + x = x.to(torch.get_autocast_gpu_dtype()) + # Flatten batch dimensions for _Float8Matmul + orig_shape = x.shape + input_2d = x.reshape(-1, orig_shape[-1]) + output = _Float8Matmul.apply(input_2d, w) + output = output.reshape(*orig_shape[:-1], output.shape[-1]) + if module.bias is not None: + output = output + module.bias.to(output.dtype) + return output # BF16 path return F.linear(x, w) diff --git a/scripts/base_train.py b/scripts/base_train.py index f41c3aa..b05e889 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # FP8 training -parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") +parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU)") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") # Model architecture parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") From 29487517edb40ea670a89e3a194e37ff27fdb0d6 Mon Sep 17 00:00:00 2001 From: Kaiyue Wen Date: Thu, 12 Feb 2026 16:58:05 -0800 Subject: [PATCH 10/14] Revert to torchao for FP8 training to fix MFU regression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The custom fp8 module had a performance issue in reparam_linear: it was doing reshape→matmul→reshape on every linear layer, and torch.compile couldn't fuse these operations because _Float8Matmul was marked @allow_in_graph (opaque to compiler). torchao's matmul_with_hp_or_float8_args handles N-D tensors directly without external reshaping, allowing better fusion opportunities and higher MFU. Co-Authored-By: Claude Sonnet 4.5 (1M context) --- nanochat/gpt.py | 19 ++++--------------- pyproject.toml | 1 + scripts/base_train.py | 2 +- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 1e56c10..8779a85 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -29,10 +29,9 @@ from nanochat.flash_attention import flash_attn # FP8 imports (optional) — needed by reparam_linear for FP8 path try: - from nanochat.fp8 import Float8Linear, _Float8Matmul + from torchao.float8.float8_linear import matmul_with_hp_or_float8_args except ImportError: - Float8Linear = None - _Float8Matmul = None + pass @dataclass class GPTConfig: @@ -67,18 +66,8 @@ def reparam_linear(module, x, gamma=None, scalar=None): if scalar is not None: w = scalar[:, None] * w # FP8 path: use Float8Linear's internal matmul to preserve FP8 tensor cores - if Float8Linear is not None and isinstance(module, Float8Linear): - # Handle autocast similar to Float8Linear.forward - if torch.is_autocast_enabled(): - x = x.to(torch.get_autocast_gpu_dtype()) - # Flatten batch dimensions for _Float8Matmul - orig_shape = x.shape - input_2d = x.reshape(-1, orig_shape[-1]) - output = _Float8Matmul.apply(input_2d, w) - output = output.reshape(*orig_shape[:-1], output.shape[-1]) - if module.bias is not None: - output = output + module.bias.to(output.dtype) - return output + if hasattr(module, 'linear_mm_config'): + return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config) # BF16 path return F.linear(x, w) diff --git a/pyproject.toml b/pyproject.toml index 8b6fd95..13e293a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "tabulate>=0.9.0", "tiktoken>=0.11.0", "tokenizers>=0.22.0", + "torchao>=0.13.0", "torch==2.9.1", "transformers>=4.57.3", "uvicorn>=0.36.0", diff --git a/scripts/base_train.py b/scripts/base_train.py index b05e889..f41c3aa 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # FP8 training -parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU)") +parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") # Model architecture parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") From 931d59c5153b4c186b15a5ddcc32ebaf7775e8a5 Mon Sep 17 00:00:00 2001 From: Kaiyue Wen Date: Thu, 12 Feb 2026 16:59:52 -0800 Subject: [PATCH 11/14] Use hybrid FP8 approach: torchao for reparam_linear, custom fp8 for layers - reparam_linear: uses torchao for efficient N-D tensor handling without reshaping - Float8Linear layers: uses custom fp8 module (simpler, same performance) - This gives us the best of both: high MFU and minimal dependencies Co-Authored-By: Claude Sonnet 4.5 (1M context) --- nanochat/gpt.py | 18 +++++++++++++----- scripts/base_train.py | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 8779a85..14e7c0e 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -27,11 +27,18 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere from nanochat.flash_attention import flash_attn -# FP8 imports (optional) — needed by reparam_linear for FP8 path +# FP8 imports (optional) +# torchao: used in reparam_linear for efficient N-D tensor handling +# custom fp8: used for regular Float8Linear layers (simpler, same performance) try: from torchao.float8.float8_linear import matmul_with_hp_or_float8_args except ImportError: - pass + matmul_with_hp_or_float8_args = None + +try: + from nanochat.fp8 import Float8Linear +except ImportError: + Float8Linear = None @dataclass class GPTConfig: @@ -58,15 +65,16 @@ def reparam_linear(module, x, gamma=None, scalar=None): gamma: RMSNorm learnable weight, folded into input dim of W (w = w * gamma[None, :]) scalar: projection scalar, folded into output dim of W (w = scalar[:, None] * w) - For FP8, dispatches through Float8Linear's internal matmul to preserve FP8 tensor cores. + For FP8, uses torchao's matmul which handles N-D tensors efficiently without reshaping. """ w = module.weight if gamma is not None: w = w * gamma[None, :] if scalar is not None: w = scalar[:, None] * w - # FP8 path: use Float8Linear's internal matmul to preserve FP8 tensor cores - if hasattr(module, 'linear_mm_config'): + # FP8 path: use torchao's matmul for efficient N-D tensor handling + # (torchao handles arbitrary shapes without external reshaping overhead) + if hasattr(module, 'linear_mm_config') and matmul_with_hp_or_float8_args is not None: return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config) # BF16 path return F.linear(x, w) diff --git a/scripts/base_train.py b/scripts/base_train.py index f41c3aa..432c998 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # FP8 training -parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") +parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao); uses custom fp8 module for layers, torchao for reparam_linear") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") # Model architecture parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") From fe2a80badd0bd2bbee8dd0468146b49ed0b8a958 Mon Sep 17 00:00:00 2001 From: Kaiyue Wen Date: Thu, 12 Feb 2026 17:05:06 -0800 Subject: [PATCH 12/14] Replace torchao with minimal custom FP8 implementation Added _Float8MatmulND to fp8.py: - Handles N-D input tensors efficiently - Does reshaping internally (opaque to torch.compile) - Prevents external reshape overhead that was causing MFU regression - ~75 lines of clean, documented code Benefits: - No torchao dependency (removed from pyproject.toml) - Same performance as torchao for reparam_linear - Consistent with fp8.py's minimal philosophy (~350 total lines) - All FP8 logic in one self-contained module Co-Authored-By: Claude Sonnet 4.5 (1M context) --- nanochat/fp8.py | 75 +++++++++++++++++++++++++++++++++++++++++++ nanochat/gpt.py | 28 ++++++++-------- pyproject.toml | 1 - scripts/base_train.py | 2 +- 4 files changed, 90 insertions(+), 16 deletions(-) diff --git a/nanochat/fp8.py b/nanochat/fp8.py index 9d9e9c3..3f056d1 100644 --- a/nanochat/fp8.py +++ b/nanochat/fp8.py @@ -196,6 +196,81 @@ class _Float8Matmul(torch.autograd.Function): return grad_input, grad_weight +@torch._dynamo.allow_in_graph +class _Float8MatmulND(torch.autograd.Function): + """FP8 matmul that handles N-D input tensors. + + Same as _Float8Matmul but accepts inputs of any shape (not just 2D). + Reshaping is done internally so torch.compile sees this as one opaque node, + preventing the reshaping overhead that occurs when reshapes are external. + + This is specifically for reparam_linear where N-D tensors are common. + """ + + @staticmethod + def forward(ctx, input, weight): + # Save original shape and flatten batch dimensions + orig_shape = input.shape + ctx.orig_shape = orig_shape + input_2d = input.reshape(-1, orig_shape[-1]) + ctx.save_for_backward(input_2d, weight) + + # Quantize and matmul (same as _Float8Matmul.forward) + input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn) + weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn) + output = torch._scaled_mm( + input_fp8, + weight_fp8.t(), + scale_a=input_inv, + scale_b=weight_inv, + out_dtype=input.dtype, + use_fast_accum=True, + ) + + # Reshape back to original batch dims + output = output.reshape(*orig_shape[:-1], output.shape[-1]) + return output + + @staticmethod + def backward(ctx, grad_output): + input_2d, weight = ctx.saved_tensors + orig_shape = ctx.orig_shape + + # Flatten grad_output to match input_2d + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + + # === GEMM 1: grad_input = grad_output @ weight === + go_fp8, go_inv = _to_fp8(grad_output_flat, torch.float8_e5m2) + w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn) + w_col = _to_col_major(w_fp8) + grad_input_flat = torch._scaled_mm( + go_fp8, + w_col, + scale_a=go_inv, + scale_b=w_inv, + out_dtype=grad_output.dtype, + use_fast_accum=False, + ) + # Reshape back to original input shape + grad_input = grad_input_flat.reshape(orig_shape) + + # === GEMM 2: grad_weight = grad_output.T @ input === + go_fp8_2, go_inv_2 = _to_fp8(grad_output_flat, torch.float8_e5m2) + in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn) + go_T = go_fp8_2.t().contiguous() + in_col = _to_col_major(in_fp8) + grad_weight = torch._scaled_mm( + go_T, + in_col, + scale_a=go_inv_2, + scale_b=in_inv, + out_dtype=grad_output.dtype, + use_fast_accum=False, + ) + + return grad_input, grad_weight + + class Float8Linear(nn.Linear): """Drop-in nn.Linear replacement that does FP8 compute. diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 14e7c0e..2e29943 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -27,18 +27,12 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere from nanochat.flash_attention import flash_attn -# FP8 imports (optional) -# torchao: used in reparam_linear for efficient N-D tensor handling -# custom fp8: used for regular Float8Linear layers (simpler, same performance) +# FP8 imports (optional) - minimal custom implementation try: - from torchao.float8.float8_linear import matmul_with_hp_or_float8_args -except ImportError: - matmul_with_hp_or_float8_args = None - -try: - from nanochat.fp8 import Float8Linear + from nanochat.fp8 import Float8Linear, _Float8MatmulND except ImportError: Float8Linear = None + _Float8MatmulND = None @dataclass class GPTConfig: @@ -65,17 +59,23 @@ def reparam_linear(module, x, gamma=None, scalar=None): gamma: RMSNorm learnable weight, folded into input dim of W (w = w * gamma[None, :]) scalar: projection scalar, folded into output dim of W (w = scalar[:, None] * w) - For FP8, uses torchao's matmul which handles N-D tensors efficiently without reshaping. + For FP8, uses minimal custom _Float8MatmulND which handles N-D tensors internally. """ w = module.weight if gamma is not None: w = w * gamma[None, :] if scalar is not None: w = scalar[:, None] * w - # FP8 path: use torchao's matmul for efficient N-D tensor handling - # (torchao handles arbitrary shapes without external reshaping overhead) - if hasattr(module, 'linear_mm_config') and matmul_with_hp_or_float8_args is not None: - return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config) + # FP8 path: use custom _Float8MatmulND for efficient N-D tensor handling + # (reshaping is done internally, so torch.compile sees it as one opaque operation) + if Float8Linear is not None and isinstance(module, Float8Linear): + # Handle autocast (Float8Linear expects this) + if torch.is_autocast_enabled(): + x = x.to(torch.get_autocast_gpu_dtype()) + output = _Float8MatmulND.apply(x, w) + if module.bias is not None: + output = output + module.bias.to(output.dtype) + return output # BF16 path return F.linear(x, w) diff --git a/pyproject.toml b/pyproject.toml index 13e293a..8b6fd95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ dependencies = [ "tabulate>=0.9.0", "tiktoken>=0.11.0", "tokenizers>=0.22.0", - "torchao>=0.13.0", "torch==2.9.1", "transformers>=4.57.3", "uvicorn>=0.36.0", diff --git a/scripts/base_train.py b/scripts/base_train.py index 432c998..25426f5 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # FP8 training -parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao); uses custom fp8 module for layers, torchao for reparam_linear") +parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU); uses minimal custom fp8 module") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") # Model architecture parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") From 5a965c13834429621ca180836f61b8ad86595faf Mon Sep 17 00:00:00 2001 From: Kaiyue Wen Date: Thu, 12 Feb 2026 17:09:19 -0800 Subject: [PATCH 13/14] Remove runs/scaling_laws_muonh.sh Co-Authored-By: Claude Sonnet 4.5 (1M context) --- runs/scaling_laws_muonh.sh | 220 ------------------------------------- 1 file changed, 220 deletions(-) delete mode 100755 runs/scaling_laws_muonh.sh diff --git a/runs/scaling_laws_muonh.sh b/runs/scaling_laws_muonh.sh deleted file mode 100755 index a5af017..0000000 --- a/runs/scaling_laws_muonh.sh +++ /dev/null @@ -1,220 +0,0 @@ -#!/bin/bash - -# Scaling Laws Sweep for GPT-Gamma + MuonH (Hyperball) -# Runs IsoFLOP analysis: for each compute budget, sweep model depths to find optimal size. -# Results saved to CSV for analysis with dev/scaling_analysis.ipynb -# -# Usage: -# bash runs/scaling_laws_muonh.sh -# LABEL=feb06 bash runs/scaling_laws_muonh.sh -# FP8=0 bash runs/scaling_laws_muonh.sh - -set -e - -LABEL="${LABEL:-muonh_$(date +%b%d | tr '[:upper:]' '[:lower:]')}" - -FLOPS_BUDGETS=( - 1e18 - 2.15e18 - 4.64e18 - 1e19 -) -DEPTHS=(8 10 12 14 16 18 20) - -NPROC_PER_NODE="${NPROC_PER_NODE:-$(nvidia-smi -L 2>/dev/null | wc -l || echo 1)}" -if [ "$NPROC_PER_NODE" -eq 0 ]; then - NPROC_PER_NODE=1 -fi - -# Fixed batch size (auto batch size requires target-param-data-ratio, not compatible with target-flops) -TOTAL_BATCH_SIZE="${TOTAL_BATCH_SIZE:-524288}" -DEVICE_BATCH_SIZE="${DEVICE_BATCH_SIZE:-16}" -EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval - -# Optimizer (MuonH defaults) -MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}" -MATRIX_LR="${MATRIX_LR:-0.02}" -EMBEDDING_LR="${EMBEDDING_LR:-0.3}" -UNEMBEDDING_LR="${UNEMBEDDING_LR:-0.004}" -SCALAR_LR="${SCALAR_LR:-0.5}" -NORM_LR="${NORM_LR:-0.2}" -WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}" -MATRIX_WARMDOWN_RATIO="${MATRIX_WARMDOWN_RATIO:-1.0}" -WINDOW_PATTERN="${WINDOW_PATTERN:-SSSL}" - -# FP8 (default enabled) -FP8="${FP8:-1}" -FP8_ARGS="" -if [ "${FP8}" -eq 1 ]; then - FP8_RECIPE="${FP8_RECIPE:-tensorwise}" - FP8_ARGS="--fp8 --fp8-recipe=${FP8_RECIPE}" -fi - -# Wandb -export WANDB_PROJECT="${WANDB_PROJECT:-nanochat-scaling}" -WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}" - -# Paths and cache -export OMP_NUM_THREADS=1 -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" -export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$PROJECT_ROOT/cache}" -export TORCHINDUCTOR_CACHE_DIR="$NANOCHAT_BASE_DIR/torch_inductor" -export TRITON_CACHE_DIR="$NANOCHAT_BASE_DIR/triton" -export TMPDIR="$NANOCHAT_BASE_DIR/tmp" -mkdir -p "$NANOCHAT_BASE_DIR" "$TORCHINDUCTOR_CACHE_DIR" "$TRITON_CACHE_DIR" "$TMPDIR" - -cd "$PROJECT_ROOT" - -# Python venv -if [ ! -d ".venv" ]; then - echo "Setting up Python environment..." - command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh - uv venv - uv sync --extra gpu -fi -source .venv/bin/activate - -RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}" -mkdir -p "$RESULTS_DIR" -RESULTS_FILE="$RESULTS_DIR/results.csv" - -# Write CSV header only if file doesn't exist -if [ ! -f "$RESULTS_FILE" ]; then - echo "flops_budget,depth,model_dim,params_wte,params_value_embeds,params_lm_head,params_transformer,params_norm_and_proj_scalars,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE" -fi - -log() { - echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1" -} - -# Check if a run already exists in results -run_exists() { - local flops=$1 - local depth=$2 - grep -q "^${flops},${depth}," "$RESULTS_FILE" 2>/dev/null -} - -# ============================================================================= -# Print summary -# ============================================================================= - -log "==============================================" -log "Scaling Laws Sweep (GPT-Gamma + MuonH)" -log "==============================================" -log "Label: $LABEL" -log "FLOPs budgets: ${FLOPS_BUDGETS[*]}" -log "Depths: ${DEPTHS[*]}" -log "Num GPUs: $NPROC_PER_NODE" -log "Total batch size: $TOTAL_BATCH_SIZE" -log "Matrix optimizer: $MATRIX_OPTIMIZER" -log "Matrix LR: $MATRIX_LR" -log "Norm LR: $NORM_LR" -log "Warmdown ratio: adam=$WARMDOWN_RATIO, matrix=$MATRIX_WARMDOWN_RATIO" -if [ "${FP8}" -eq 1 ]; then - log "FP8: enabled ($FP8_RECIPE)" -fi -log "Results dir: $RESULTS_DIR" -log "==============================================" - -# ============================================================================= -# Main Loop -# ============================================================================= - -for flops in "${FLOPS_BUDGETS[@]}"; do - log "==============================================" - log "Compute budget: $flops FLOPs" - log "==============================================" - - for d in "${DEPTHS[@]}"; do - - # Skip if already completed - if run_exists "$flops" "$d"; then - log "Skipping d=$d at $flops FLOPs (already in results)" - continue - fi - - log "Training d=$d at $flops FLOPs..." - - # Unique tag for this run - TAG="scaling_${LABEL}_${flops}_d${d}" - - # Record start time - START_TIME=$(date +%s) - - # Train the model with fixed flops budget - TRAIN_ARGS=( - --depth=$d - --target-flops=$flops - --target-param-data-ratio=-1 - --total-batch-size=$TOTAL_BATCH_SIZE - --device-batch-size=$DEVICE_BATCH_SIZE - --run="${WANDB_RUN}_${TAG}" - --model-tag="${TAG}" - --window-pattern=$WINDOW_PATTERN - --matrix-optimizer=$MATRIX_OPTIMIZER - --matrix-lr=$MATRIX_LR - --embedding-lr=$EMBEDDING_LR - --unembedding-lr=$UNEMBEDDING_LR - --scalar-lr=$SCALAR_LR - --norm-lr=$NORM_LR - --warmdown-ratio=$WARMDOWN_RATIO - --matrix-warmdown-ratio=$MATRIX_WARMDOWN_RATIO - --eval-tokens=$EVAL_TOKENS - --core-metric-every=999999 - --core-metric-max-per-task=-1 - --sample-every=-1 - --save-every=-1 - ) - - if [ "$NPROC_PER_NODE" -gt 1 ]; then - torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ - "${TRAIN_ARGS[@]}" $FP8_ARGS \ - 2>&1 | tee "$RESULTS_DIR/${TAG}_train.log" - else - python -m scripts.base_train \ - "${TRAIN_ARGS[@]}" $FP8_ARGS \ - 2>&1 | tee "$RESULTS_DIR/${TAG}_train.log" - fi - - END_TIME=$(date +%s) - TRAIN_TIME=$((END_TIME - START_TIME)) - - # Extract training stats from the log - LOG_FILE="$RESULTS_DIR/${TAG}_train.log" - - # Extract detailed parameter counts (handle whitespace-padded format) - PARAMS_WTE=$(grep "wte" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') - PARAMS_VE=$(grep "value_embeds" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') - PARAMS_LM=$(grep "lm_head" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') - PARAMS_TRANSFORMER=$(grep "transformer_matrices" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') - PARAMS_NORM=$(grep "norm_and_proj_scalars" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') - PARAMS_SCALARS=$(grep -w "scalars" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') - PARAMS_TOTAL=$(grep -w "total" "$LOG_FILE" | grep ":" | tail -1 | grep -oP '[\d,]+' | tr -d ',') - - NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',') - TOKENS_TRAINED=$((NUM_ITERS * TOTAL_BATCH_SIZE)) - MODEL_DIM=$((d * 64)) - VAL_BPB=$(grep "Validation bpb:" "$LOG_FILE" | tail -1 | grep -oP '[\d.]+$') - - # Extract CORE score from training log (evaluated on final step) - CORE_SCORE=$(grep "CORE metric:" "$LOG_FILE" | tail -1 | awk '{print $NF}') - if [ -z "$CORE_SCORE" ]; then - log "WARNING: Could not extract CORE score for d=$d" - CORE_SCORE="0.0" - fi - - log " Params: $PARAMS_TOTAL (transformer: $PARAMS_TRANSFORMER), Iters: $NUM_ITERS, Val BPB: $VAL_BPB, CORE: $CORE_SCORE" - - # Append to CSV - echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_NORM,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE" - done -done - -log "==============================================" -log "Scaling Laws Sweep Complete" -log "==============================================" -log "Results saved to: $RESULTS_FILE" -echo "" -echo "Results:" -column -t -s',' "$RESULTS_FILE" From 116900ac1610fb5521d7b015af0f36c54026490f Mon Sep 17 00:00:00 2001 From: Kaiyue Wen Date: Thu, 12 Feb 2026 17:51:36 -0800 Subject: [PATCH 14/14] muonh --- runs/quickrun_muonh.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runs/quickrun_muonh.sh b/runs/quickrun_muonh.sh index 5abf675..9c7a14e 100755 --- a/runs/quickrun_muonh.sh +++ b/runs/quickrun_muonh.sh @@ -16,9 +16,9 @@ set -e # ----------------------------------------------------------------------------- # Config -DEPTH="${DEPTH:-26}" +DEPTH="${DEPTH:-24}" NUM_SHARDS="${NUM_SHARDS:-370}" # default for d24 @ ratio~11 -TARGET_RATIO="${TARGET_RATIO:-10.5}" +TARGET_RATIO="${TARGET_RATIO:-12}" WINDOW_PATTERN="${WINDOW_PATTERN:-SSSL}" DEVICE_BATCH_SIZE="${DEVICE_BATCH_SIZE:-16}" TOTAL_BATCH_SIZE="${TOTAL_BATCH_SIZE:-524288}" # -1 = auto-compute optimal (Power Lines paper)