diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 208acd1..fc3d422 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -2,14 +2,16 @@ GPT model (rewrite, a lot simpler) Notable features: - rotary embeddings (and no positional embeddings) -- QK norm +- QK norm (functional RMSNorm, no learnable params) - untied weights for token embedding and lm_head - relu^2 activation in MLP -- norm after token embedding -- no learnable params in rmsnorm +- norm after token embedding (functional RMSNorm) +- parameterized RMSNorm in blocks (learnable gamma) +- per-block projection scalars for attention/MLP - no bias in linear layers - Group-Query Attention (GQA) support for more efficient inference - Flash Attention 3 integration +- optional Hyperball optimizer for matrix params """ from functools import partial @@ -40,7 +42,7 @@ class GPTConfig: def norm(x): - # Purely functional rmsnorm with no learnable params + # Purely functional RMSNorm with no learnable params return F.rms_norm(x, (x.size(-1),)) @@ -72,6 +74,7 @@ class CausalSelfAttention(nn.Module): self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) self.ve_gate_channels = 32 self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None + self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd)) def forward(self, x, ve, cos_sin, window_size, kv_cache): B, T, C = x.size() @@ -115,6 +118,7 @@ class CausalSelfAttention(nn.Module): # Re-assemble the heads and project back to residual stream y = y.contiguous().view(B, T, -1) y = self.c_proj(y) + y = y * self.c_proj_scalar return y @@ -123,23 +127,27 @@ class MLP(nn.Module): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) + self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd)) def forward(self, x): x = self.c_fc(x) x = F.relu(x).square() x = self.c_proj(x) + x = x * self.c_proj_scalar return x class Block(nn.Module): def __init__(self, config, layer_idx): super().__init__() + self.attn_norm = nn.RMSNorm(config.n_embd) self.attn = CausalSelfAttention(config, layer_idx) + self.mlp_norm = nn.RMSNorm(config.n_embd) self.mlp = MLP(config) def forward(self, x, ve, cos_sin, window_size, kv_cache): - x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache) - x = x + self.mlp(norm(x)) + x = x + self.attn(self.attn_norm(x), ve, cos_sin, window_size, kv_cache) + x = x + self.mlp(self.mlp_norm(x)) return x @@ -196,15 +204,21 @@ class GPT(nn.Module): attn.c_q: uniform, std=1/sqrt(n_embd) attn.c_k: uniform, std=1/sqrt(n_embd) attn.c_v: uniform, std=1/sqrt(n_embd) - attn.c_proj: zeros + attn.c_proj: uniform, std=1/sqrt(n_embd) mlp.c_fc: uniform, std=1/sqrt(n_embd) - mlp.c_proj: zeros + mlp.c_proj: uniform, std=1/sqrt(n_embd) + nn.RMSNorm weight: ones (via explicit init below) """ # Embedding and unembedding torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0) torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001) + # nn.RMSNorm weight parameters: init to ones (must be explicit due to meta device) + for module in self.modules(): + if isinstance(module, nn.RMSNorm): + module.weight.fill_(1.0) + # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal) n_embd = self.config.n_embd s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal @@ -212,22 +226,38 @@ class GPT(nn.Module): torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers torch.nn.init.uniform_(block.attn.c_k.weight, -s, s) torch.nn.init.uniform_(block.attn.c_v.weight, -s, s) - torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero + torch.nn.init.uniform_(block.attn.c_proj.weight, -s, s) torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s) - torch.nn.init.zeros_(block.mlp.c_proj.weight) + torch.nn.init.uniform_(block.mlp.c_proj.weight, -s, s) # Per-layer scalars self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding + # Per-block projection scalars (zero-init for stable training start) + for block in self.transformer.h: + block.attn.c_proj_scalar.fill_(0.0) + block.mlp.c_proj_scalar.fill_(0.0) + if self.transformer.wte.weight.device.type == "cuda": + block.attn.c_proj_scalar.data = block.attn.c_proj_scalar.data.to(torch.bfloat16) + block.mlp.c_proj_scalar.data = block.mlp.c_proj_scalar.data.to(torch.bfloat16) + + # Block RMSNorm weights (cast to bf16 for fused kernel) + for block in self.transformer.h: + block.attn_norm.weight.fill_(1.0) + block.mlp_norm.weight.fill_(1.0) + if self.transformer.wte.weight.device.type == "cuda": + block.attn_norm.to(dtype=torch.bfloat16) + block.mlp_norm.to(dtype=torch.bfloat16) + # Value embeddings (init like c_v: uniform with same std) for ve in self.value_embeds.values(): torch.nn.init.uniform_(ve.weight, -s, s) - # Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral) + # Gate weights init to uniform (avoid zero-norm params under Hyperball) for block in self.transformer.h: if block.attn.ve_gate is not None: - torch.nn.init.zeros_(block.attn.ve_gate.weight) + torch.nn.init.uniform_(block.attn.ve_gate.weight, -s, s) # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head @@ -302,10 +332,11 @@ class GPT(nn.Module): - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore) """ nparams = sum(p.numel() for p in self.parameters()) - # Exclude non-matmul params: embeddings and per-layer scalars + # Exclude non-matmul params: embeddings, per-layer scalars, and 1D params in blocks value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) + block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1) nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + - self.resid_lambdas.numel() + self.x0_lambdas.numel()) + self.resid_lambdas.numel() + self.x0_lambdas.numel() + block_1d_params) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -332,31 +363,36 @@ class GPT(nn.Module): wte = sum(p.numel() for p in self.transformer.wte.parameters()) value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters()) - transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) + block_matrix_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 2) + block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1) scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() - total = wte + value_embeds + lm_head + transformer_matrices + scalars + total = wte + value_embeds + lm_head + block_matrix_params + block_1d_params + scalars assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" return { 'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head, - 'transformer_matrices': transformer_matrices, + 'transformer_matrices': block_matrix_params, + 'norm_and_proj_scalars': block_1d_params, 'scalars': scalars, 'total': total, } - def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): + def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5, matrix_optimizer="muon"): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() # Separate out all parameters into groups - matrix_params = list(self.transformer.h.parameters()) + block_matrix_params = [p for p in self.transformer.h.parameters() if p.ndim == 2] + block_1d_params = [p for p in self.transformer.h.parameters() if p.ndim == 1] value_embeds_params = list(self.value_embeds.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + all_params_count = (len(block_matrix_params) + len(block_1d_params) + len(embedding_params) + + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)) + assert len(list(self.parameters())) == all_params_count # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -370,14 +406,23 @@ class GPT(nn.Module): dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0), dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 + dict(kind='adamw', params=block_1d_params, lr=scalar_lr, betas=adam_betas, eps=1e-10, weight_decay=0.0), ] - # Muon groups (matrix params, grouped by shape for stacking) - for shape in sorted({p.shape for p in matrix_params}): - group_params = [p for p in matrix_params if p.shape == shape] - param_groups.append(dict( - kind='muon', params=group_params, lr=matrix_lr, - momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay, - )) + # Matrix params (Muon or Hyperball), grouped by shape for stacking + if matrix_optimizer not in {"muon", "hyperball"}: + raise ValueError(f"Unknown matrix_optimizer: {matrix_optimizer}") + for shape in sorted({p.shape for p in block_matrix_params}): + group_params = [p for p in block_matrix_params if p.shape == shape] + if matrix_optimizer == "hyperball": + param_groups.append(dict( + kind='hyperball', params=group_params, lr=matrix_lr, + momentum=0.95, ns_steps=5, beta2=0.95, + )) + else: + param_groups.append(dict( + kind='muon', params=group_params, lr=matrix_lr, + momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay, + )) Factory = DistMuonAdamW if ddp else MuonAdamW optimizer = Factory(param_groups) diff --git a/nanochat/optim.py b/nanochat/optim.py index 190a1ed..f93ba38 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -1,6 +1,6 @@ """ A nice and efficient mixed AdamW/Muon Combined Optimizer. -Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon. +Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon/Hyperball. Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed. Addapted from: https://github.com/KellerJordan/modded-nanogpt @@ -141,6 +141,80 @@ def muon_step_fused( mask = (g * stacked_params) >= 0 stacked_params.sub_(lr * g + lr * wd * stacked_params * mask) +# ----------------------------------------------------------------------------- +""" +Hyperball optimizer (MuonH): Muon with scale-invariant updates. +https://github.com/marin-community/marin/blob/main/lib/levanter/src/levanter/optim/muonh.py + +The key difference from Muon is that weights maintain constant Frobenius norm +throughout training via the following update rule: + + p_new_intermediate = p - lr * u * ||p|| / ||u|| + p_new = p_new_intermediate / ||p_new_intermediate|| * ||p|| + +This projects the update onto the hypersphere of constant norm, hence "Hyperball". +Uses variance reduction like Muon, but no cautious weight decay. +""" + +@torch.compile(dynamic=False, fullgraph=True) +def hyperball_step_fused( + stacked_grads: Tensor, # (K, M, N) - stacked gradients + stacked_params: Tensor, # (K, M, N) - stacked parameters + momentum_buffer: Tensor, # (K, M, N) - momentum buffer + second_momentum_buffer: Tensor, # (K, M, 1) or (K, 1, N) - factored second moment + p_norm: Tensor, # (K, 1, 1) - pre-computed Frobenius norm of params (constant) + momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient + lr_t: Tensor, # () - 0-D CPU tensor, learning rate + beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment + ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations + red_dim: int, # -1 or -2 - reduction dimension for variance +) -> None: + """ + Fused Hyperball step: momentum -> polar_express -> variance_reduction -> scale_invariant_update + All in one compiled graph. Weights maintain constant Frobenius norm. + """ + + # Nesterov momentum + momentum = momentum_t.to(stacked_grads.dtype) + momentum_buffer.lerp_(stacked_grads, 1 - momentum) + g = stacked_grads.lerp_(momentum_buffer, momentum) + + # Polar express orthogonalization + X = g.bfloat16() + X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6) + if g.size(-2) > g.size(-1): # Tall matrix + for a, b, c in polar_express_coeffs[:ns_steps]: + A = X.mT @ X + B = b * A + c * (A @ A) + X = a * X + X @ B + else: # Wide matrix + for a, b, c in polar_express_coeffs[:ns_steps]: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + g = X + + # Variance reduction (note: this preserves ||g||_F, so ||u||_F == ||g||_F == v_norm) + beta2 = beta2_t.to(g.dtype) + v_mean = g.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = g.size(red_dim) + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt() + scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() + v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt() + p_norm = p_norm.to(v_norm_new.dtype) + final_scale = step_size * p_norm / v_norm_new.clamp_min(1e-10) + g = g * final_scale.to(g.dtype) + u = g.to(stacked_params.dtype) + + # Scale-invariant update: keeps ||p|| constant + lr = lr_t.to(stacked_params.dtype) + stacked_params.sub_(lr * u) + + # Project back to hypersphere: p = p * (||p_orig|| / ||p_new||) + p_new_norm = stacked_params.norm(dim=(-2, -1), keepdim=True).clamp_min(1e-10) + stacked_params.mul_(p_norm / p_new_norm) + # ----------------------------------------------------------------------------- # Single GPU version of the MuonAdamW optimizer. # Used mostly for reference, debugging and testing. @@ -167,9 +241,10 @@ class MuonAdamW(torch.optim.Optimizer): Arguments: param_groups: List of dicts, each containing: - 'params': List of parameters - - 'kind': 'adamw' or 'muon' + - 'kind': 'adamw', 'muon', or 'hyperball' - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay' - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay' + - For Hyperball groups: 'lr', 'momentum', 'ns_steps', 'beta2' """ def __init__(self, param_groups: list[dict]): super().__init__(param_groups, defaults={}) @@ -186,6 +261,10 @@ class MuonAdamW(torch.optim.Optimizer): self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + # Hyperball tensors + self._hyperball_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._hyperball_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._hyperball_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") def _step_adamw(self, group: dict) -> None: """ @@ -276,6 +355,64 @@ class MuonAdamW(torch.optim.Optimizer): # Copy back to original params torch._foreach_copy_(params, list(stacked_params.unbind(0))) + def _step_hyperball(self, group: dict) -> None: + """ + Hyperball update for all params in the group (stacked for efficiency). + Like Muon, but uses scale-invariant updates that keep weight norms constant. + """ + params: list[Tensor] = group['params'] + if not params: + return + + # Get or create group-level buffers (stored in first param's state for convenience) + p = params[0] + state = self.state[p] + num_params = len(params) + shape, device, dtype = p.shape, p.device, p.dtype + + # Stack grads and params (NOTE: this assumes all params have the same shape) + stacked_grads = torch.stack([p.grad for p in params]) + stacked_params = torch.stack(params) + + # Momentum buffer for every individual parameter + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device) + momentum_buffer = state["momentum_buffer"] + + # Second momentum buffer is factored, either per-row or per-column + if "second_momentum_buffer" not in state: + state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1]) + state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device) + second_momentum_buffer = state["second_momentum_buffer"] + red_dim = -1 if shape[-2] >= shape[-1] else -2 + + # Pre-compute and cache p_norm (Frobenius norm of each param, constant throughout training) + if "p_norm" not in state: + state["p_norm"] = stacked_params.norm(dim=(-2, -1), keepdim=True).clone() + p_norm = state["p_norm"] + + # Fill all the 0-D tensors with current values + self._hyperball_momentum_t.fill_(group["momentum"]) + self._hyperball_lr_t.fill_(group["lr"]) + self._hyperball_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) + + # Single fused kernel: momentum -> polar_express -> variance_reduction -> scale_invariant_update + hyperball_step_fused( + stacked_grads, + stacked_params, + momentum_buffer, + second_momentum_buffer, + p_norm, + self._hyperball_momentum_t, + self._hyperball_lr_t, + self._hyperball_beta2_t, + group["ns_steps"], + red_dim, + ) + + # Copy back to original params + torch._foreach_copy_(params, list(stacked_params.unbind(0))) + @torch.no_grad() def step(self): for group in self.param_groups: @@ -283,6 +420,8 @@ class MuonAdamW(torch.optim.Optimizer): self._step_adamw(group) elif group['kind'] == 'muon': self._step_muon(group) + elif group['kind'] == 'hyperball': + self._step_hyperball(group) else: raise ValueError(f"Unknown optimizer kind: {group['kind']}") @@ -344,9 +483,10 @@ class DistMuonAdamW(torch.optim.Optimizer): Arguments: param_groups: List of dicts, each containing: - 'params': List of parameters - - 'kind': 'adamw' or 'muon' + - 'kind': 'adamw', 'muon', or 'hyperball' - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay' - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay' + - For Hyperball groups: 'lr', 'momentum', 'ns_steps', 'beta2' """ def __init__(self, param_groups: list[dict]): super().__init__(param_groups, defaults={}) @@ -361,6 +501,10 @@ class DistMuonAdamW(torch.optim.Optimizer): self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + # Hyperball tensors + self._hyperball_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._hyperball_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._hyperball_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") def _reduce_adamw(self, group: dict, world_size: int) -> dict: """Launch async reduce ops for AdamW group. Returns info dict with per-param infos.""" @@ -491,6 +635,85 @@ class DistMuonAdamW(torch.optim.Optimizer): future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future() gather_list.append(dict(future=future, stacked_params=stacked_params, params=params)) + def _reduce_hyperball(self, group: dict, world_size: int) -> dict: + """Launch async reduce op for Hyperball group. Returns info dict.""" + params = group['params'] + chunk_size = (len(params) + world_size - 1) // world_size + padded_num_params = chunk_size * world_size + p = params[0] + shape, device, dtype = p.shape, p.device, p.dtype + + # Stack grads and zero-pad to padded_num_params + grad_stack = torch.stack([p.grad for p in params]) + stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device) + stacked_grads[:len(params)].copy_(grad_stack) + if len(params) < padded_num_params: + stacked_grads[len(params):].zero_() + + # Reduce_scatter to get this rank's chunk + grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device) + future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future() + + return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size) + + def _compute_hyperball(self, group: dict, info: dict, gather_list: list, rank: int) -> None: + """Wait for reduce, compute Hyperball updates, launch gather.""" + info['future'].wait() + params = group['params'] + chunk_size = info['chunk_size'] + grad_chunk = info['grad_chunk'] + p = params[0] + shape, device, dtype = p.shape, p.device, p.dtype + + # How many params does this rank own? + start_idx = rank * chunk_size + num_owned = min(chunk_size, max(0, len(params) - start_idx)) + + # Get or create group-level state + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device) + if "second_momentum_buffer" not in state: + state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1]) + state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device) + red_dim = -1 if shape[-2] >= shape[-1] else -2 + + # Build output buffer for all_gather + updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device) + + if num_owned > 0: + owned_params = [params[start_idx + i] for i in range(num_owned)] + stacked_owned = torch.stack(owned_params) + + # Pre-compute and cache p_norm for owned params (constant throughout training) + if "p_norm" not in state: + state["p_norm"] = stacked_owned.norm(dim=(-2, -1), keepdim=True).clone() + p_norm = state["p_norm"] + + # Fill 0-D tensors and run fused kernel + self._hyperball_momentum_t.fill_(group["momentum"]) + self._hyperball_lr_t.fill_(group["lr"]) + self._hyperball_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) + hyperball_step_fused( + grad_chunk[:num_owned], stacked_owned, + state["momentum_buffer"][:num_owned], + state["second_momentum_buffer"][:num_owned], + p_norm, + self._hyperball_momentum_t, self._hyperball_lr_t, + self._hyperball_beta2_t, + group["ns_steps"], + red_dim, + ) + updated_params[:num_owned].copy_(stacked_owned) + + if num_owned < chunk_size: + updated_params[num_owned:].zero_() + + # Reuse stacked_grads buffer for all_gather output + stacked_params = info["stacked_grads"] + future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future() + gather_list.append(dict(future=future, stacked_params=stacked_params, params=params)) + def _finish_gathers(self, gather_list: list) -> None: """Wait for all gathers and copy Muon params back.""" for info in gather_list: @@ -511,6 +734,8 @@ class DistMuonAdamW(torch.optim.Optimizer): reduce_infos.append(self._reduce_adamw(group, world_size)) elif group['kind'] == 'muon': reduce_infos.append(self._reduce_muon(group, world_size)) + elif group['kind'] == 'hyperball': + reduce_infos.append(self._reduce_hyperball(group, world_size)) else: raise ValueError(f"Unknown optimizer kind: {group['kind']}") @@ -521,6 +746,8 @@ class DistMuonAdamW(torch.optim.Optimizer): self._compute_adamw(group, info, gather_list, rank, world_size) elif group['kind'] == 'muon': self._compute_muon(group, info, gather_list, rank) + elif group['kind'] == 'hyperball': + self._compute_hyperball(group, info, gather_list, rank) else: raise ValueError(f"Unknown optimizer kind: {group['kind']}") diff --git a/runs/quickrun_gamma_muonh_d24.sh b/runs/quickrun_gamma_muonh_d24.sh new file mode 100755 index 0000000..63ba8cf --- /dev/null +++ b/runs/quickrun_gamma_muonh_d24.sh @@ -0,0 +1,156 @@ +#!/bin/bash + +# Quickrun: GPT-Gamma + MuonH (Hyperball), depth=24 +# - Parameterized RMSNorm (learnable gamma) +# - Per-block projection scalars +# - Hyperball or Muon for matrix params +# +# Examples: +# bash runs/quickrun_gamma_muonh_d24.sh +# WANDB_RUN=exp1 bash runs/quickrun_gamma_muonh_d24.sh +# FP8=1 FP8_RECIPE=tensorwise bash runs/quickrun_gamma_muonh_d24.sh + +set -e + +# ----------------------------------------------------------------------------- +# Config + +DEPTH="${DEPTH:-24}" +NUM_SHARDS="${NUM_SHARDS:-370}" # ~10B tokens for d24 @ ratio~11 +TARGET_RATIO="${TARGET_RATIO:-11}" +WINDOW_PATTERN="${WINDOW_PATTERN:-SSSL}" +DEVICE_BATCH_SIZE="${DEVICE_BATCH_SIZE:-16}" +TOTAL_BATCH_SIZE="${TOTAL_BATCH_SIZE:-524288}" + +NPROC_PER_NODE="${NPROC_PER_NODE:-$(nvidia-smi -L 2>/dev/null | wc -l || echo 1)}" +if [ "$NPROC_PER_NODE" -eq 0 ]; then + NPROC_PER_NODE=1 +fi + +# Optimizer +MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}" +SCALAR_LR="${SCALAR_LR:-0.5}" +MATRIX_LR="$SCALAR_LR" # share with scalar LR +WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}" + +# AdamW +EMBEDDING_LR="${EMBEDDING_LR:-0.3}" +UNEMBEDDING_LR="${UNEMBEDDING_LR:-0.004}" + +# Wandb +export WANDB_MODE=offline +WANDB_RUN="${WANDB_RUN:-muonh_d${DEPTH}_ratio${TARGET_RATIO}}" +MODEL_TAG="${MODEL_TAG:-d${DEPTH}_gamma_muonh}" + +# FP8 +FP8_ARGS="" +if [ "${FP8:-0}" -eq 1 ]; then + FP8_RECIPE="${FP8_RECIPE:-tensorwise}" + FP8_ARGS="--fp8 --fp8-recipe=${FP8_RECIPE}" +fi + +# NCCL (single node) +export NCCL_P2P_LEVEL=NVL +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_DISABLE=1 + +# ----------------------------------------------------------------------------- +# Paths and cache + +export OMP_NUM_THREADS=1 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" +export NANOCHAT_BASE_DIR="$PROJECT_ROOT/cache" +export TORCHINDUCTOR_CACHE_DIR="$NANOCHAT_BASE_DIR/torch_inductor" +export TRITON_CACHE_DIR="$NANOCHAT_BASE_DIR/triton" +export TMPDIR="$NANOCHAT_BASE_DIR/tmp" +mkdir -p "$NANOCHAT_BASE_DIR" "$TORCHINDUCTOR_CACHE_DIR" "$TRITON_CACHE_DIR" "$TMPDIR" + +# ----------------------------------------------------------------------------- +# Print summary + +echo "==============================================" +echo "Quickrun (GPT-Gamma + MuonH D24)" +echo "==============================================" +echo "Project root: $PROJECT_ROOT" +echo "Cache dir: $NANOCHAT_BASE_DIR" +echo "Depth: $DEPTH" +echo "Num shards: $NUM_SHARDS" +echo "Target ratio: $TARGET_RATIO" +echo "Window pattern: $WINDOW_PATTERN" +echo "Num GPUs: $NPROC_PER_NODE" +echo "Device batch size: $DEVICE_BATCH_SIZE" +echo "Total batch size: $TOTAL_BATCH_SIZE" +echo "Matrix optimizer: $MATRIX_OPTIMIZER" +echo "Matrix LR: $MATRIX_LR (shared with scalar)" +echo "Adam LRs: embedding=$EMBEDDING_LR, unembedding=$UNEMBEDDING_LR, scalar=$SCALAR_LR" +echo "Warmdown ratio: $WARMDOWN_RATIO" +echo "Wandb run: $WANDB_RUN" +echo "Model tag: $MODEL_TAG" +if [ "${FP8:-0}" -eq 1 ]; then + echo "FP8: enabled ($FP8_RECIPE)" +fi +echo "==============================================" + +cd "$PROJECT_ROOT" + +# ----------------------------------------------------------------------------- +# Python venv + +if [ ! -d ".venv" ]; then + echo "Setting up Python environment..." + command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh + uv venv + uv sync --extra gpu +fi +source .venv/bin/activate + +# ----------------------------------------------------------------------------- +# Data + tokenizer + +echo "" +echo "Downloading $NUM_SHARDS data shards..." +python -m nanochat.dataset -n "$NUM_SHARDS" + +echo "" +echo "Checking tokenizer..." +python -m scripts.tok_train --max-chars=500000000 --vocab-size=32768 + +# ----------------------------------------------------------------------------- +# Train + +echo "" +echo "Starting pretraining (depth=$DEPTH)..." + +TRAIN_ARGS=( + --depth=$DEPTH + --run=$WANDB_RUN + --model-tag=$MODEL_TAG + --window-pattern=$WINDOW_PATTERN + --target-param-data-ratio=$TARGET_RATIO + --device-batch-size=$DEVICE_BATCH_SIZE + --total-batch-size=$TOTAL_BATCH_SIZE + --matrix-optimizer=$MATRIX_OPTIMIZER + --matrix-lr=$MATRIX_LR + --warmdown-ratio=$WARMDOWN_RATIO + --embedding-lr=$EMBEDDING_LR + --unembedding-lr=$UNEMBEDDING_LR + --scalar-lr=$SCALAR_LR + --core-metric-every=2000 + --sample-every=-1 + --save-every=-1 +) + +if [ "$NPROC_PER_NODE" -gt 1 ]; then + torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ + "${TRAIN_ARGS[@]}" $FP8_ARGS +else + python -m scripts.base_train \ + "${TRAIN_ARGS[@]}" $FP8_ARGS +fi + +echo "" +echo "==============================================" +echo "Training complete!" +echo "==============================================" +echo "Checkpoint saved to: $NANOCHAT_BASE_DIR/base_checkpoints/${MODEL_TAG}/" diff --git a/scripts/base_train.py b/scripts/base_train.py index fa05b60..a4de906 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -58,7 +58,8 @@ parser.add_argument("--total-batch-size", type=int, default=524288, help="total parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)") -parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") +parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon/Hyperball)") +parser.add_argument("--matrix-optimizer", type=str, default="muon", choices=["muon", "hyperball"], help="optimizer for matrix parameters") parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding") parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding") @@ -303,6 +304,7 @@ optimizer = model.setup_optimizer( weight_decay=weight_decay_scaled, adam_betas=adam_betas, scalar_lr=args.scalar_lr * batch_lr_scale, + matrix_optimizer=args.matrix_optimizer, ) if resuming: @@ -331,7 +333,7 @@ def get_lr_multiplier(it): progress = (num_iterations - it) / warmdown_iters return progress * 1.0 + (1 - progress) * args.final_lr_frac -# Momentum scheduler for Muon optimizer +# Momentum scheduler for matrix optimizer (Muon/Hyperball) def get_muon_momentum(it): frac = min(it / 300, 1) momentum = (1 - frac) * 0.85 + frac * 0.95 @@ -466,8 +468,9 @@ while True: muon_weight_decay = get_weight_decay(step) for group in optimizer.param_groups: group["lr"] = group["initial_lr"] * lrm - if group['kind'] == 'muon': + if group['kind'] in {'muon', 'hyperball'}: group["momentum"] = muon_momentum + if group['kind'] == 'muon': group["weight_decay"] = muon_weight_decay optimizer.step() model.zero_grad(set_to_none=True)