@torch._dynamo.allow_in_graph
class _Float8MatmulND(torch.autograd.Function):
    """FP8 matmul that handles N-D input tensors.

    Same as _Float8Matmul but accepts inputs of any shape (not just 2D).
    Reshaping is done internally so torch.compile sees this as one opaque node,
    preventing the reshaping overhead that occurs when reshapes are external.

    This is specifically for reparam_linear where N-D tensors are common.

    forward(input, weight):
        input:  (..., in_features) tensor; all leading dims are flattened.
        weight: 2D tensor, used as ``input_2d @ weight.T``.
        Returns a tensor of shape (..., weight_rows) in ``input.dtype``.

    backward(grad_output):
        Returns ``(grad_input, grad_weight)``; ``grad_input`` has the original
        N-D input shape, both are produced in ``grad_output.dtype``.

    Note: the *unquantized* flattened input and weight are saved for backward
    and re-quantized there.
    """

    @staticmethod
    def forward(ctx, input, weight):
        # Remember the caller's shape, then collapse every leading (batch)
        # dimension so the scaled GEMM sees plain 2D operands.
        orig_shape = input.shape
        ctx.orig_shape = orig_shape
        input_2d = input.reshape(-1, orig_shape[-1])
        ctx.save_for_backward(input_2d, weight)

        # Quantize both operands to e4m3 and run the scaled matmul.
        input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
        output = torch._scaled_mm(
            input_fp8,
            weight_fp8.t(),
            scale_a=input_inv,
            scale_b=weight_inv,
            out_dtype=input.dtype,
            use_fast_accum=True,
        )

        # Restore the original leading dims; only the last dim changed.
        return output.reshape(*orig_shape[:-1], output.shape[-1])

    @staticmethod
    def backward(ctx, grad_output):
        input_2d, weight = ctx.saved_tensors
        orig_shape = ctx.orig_shape

        # Flatten grad_output so its rows line up with input_2d.
        grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1])

        # Quantize grad_output ONCE (e5m2) and share it between both GEMMs.
        # The previous code called _to_fp8 twice with identical arguments.
        # NOTE(review): this assumes _to_fp8 is a pure function of its inputs
        # (it is defined outside this block) -- confirm.
        go_fp8, go_inv = _to_fp8(grad_output_flat, torch.float8_e5m2)

        # === GEMM 1: grad_input = grad_output @ weight ===
        w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
        w_col = _to_col_major(w_fp8)
        grad_input_flat = torch._scaled_mm(
            go_fp8,
            w_col,
            scale_a=go_inv,
            scale_b=w_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )
        # Reshape back to original input shape
        grad_input = grad_input_flat.reshape(orig_shape)

        # === GEMM 2: grad_weight = grad_output.T @ input ===
        in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        go_T = go_fp8.t().contiguous()
        in_col = _to_col_major(in_fp8)
        grad_weight = torch._scaled_mm(
            go_T,
            in_col,
            scale_a=go_inv,
            scale_b=in_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )

        return grad_input, grad_weight
def norm(x):
    # Purely functional RMSNorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))


def reparam_linear(module, x, gamma=None, scalar=None):
    """Linear with gamma/scalar folded into weight. Works with both nn.Linear and Float8Linear.

    Args:
        module: the linear layer whose ``weight`` (and ``bias``, if any) is used.
        x: input tensor; the last dim must match the weight's input dim.
        gamma: optional per-input-channel scale (e.g. an RMSNorm weight),
            folded into the input dim of W (w = w * gamma[None, :]).
        scalar: optional per-output-channel projection scalar, folded into
            the output dim of W (w = scalar[:, None] * w).

    Returns:
        ``x @ w.T`` plus ``module.bias`` when the module has one.

    For FP8, uses minimal custom _Float8MatmulND which handles N-D tensors internally.
    """
    w = module.weight
    if gamma is not None:
        w = w * gamma[None, :]
    if scalar is not None:
        w = scalar[:, None] * w
    # FP8 path: use custom _Float8MatmulND for efficient N-D tensor handling
    # (reshaping is done internally, so torch.compile sees it as one opaque operation)
    if Float8Linear is not None and isinstance(module, Float8Linear):
        # Handle autocast (Float8Linear expects this)
        if torch.is_autocast_enabled():
            x = x.to(torch.get_autocast_gpu_dtype())
        output = _Float8MatmulND.apply(x, w)
        if module.bias is not None:
            output = output + module.bias.to(output.dtype)
        return output
    # BF16 path. Pass the bias through so both paths agree: previously it was
    # applied on the FP8 path but silently dropped here. With bias=None (the
    # only configuration this file builds) the result is unchanged.
    return F.linear(x, w, module.bias)


def has_ve(layer_idx, n_layer):
    """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
    return layer_idx % 2 == (n_layer - 1) % 2
# Project the input to get queries, keys, and values + # Project the input to get queries, keys, and values (gamma folded into weights) # Shape: (B, T, H, D) - FA3's native layout, no transpose needed! - q = self.c_q(x).view(B, T, self.n_head, self.head_dim) - k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) - v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) + q = reparam_linear(self.c_q, x).view(B, T, self.n_head, self.head_dim) + k = reparam_linear(self.c_k, x).view(B, T, self.n_kv_head, self.head_dim) + v = reparam_linear(self.c_v, x).view(B, T, self.n_kv_head, self.head_dim) # Value residual (ResFormer): mix in value embedding with input-dependent gate per head if ve is not None: ve = ve.view(B, T, self.n_kv_head, self.head_dim) - gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2) + gate = 2 * torch.sigmoid(reparam_linear(self.ve_gate, x[..., :self.ve_gate_channels], scalar=self.v_proj_scalar)) # (B, T, n_kv_head), range (0, 2) v = v + gate.unsqueeze(-1) * ve # Apply Rotary Embeddings to queries and keys to get relative positional encoding @@ -112,9 +150,9 @@ class CausalSelfAttention(nn.Module): if self.layer_idx == kv_cache.n_layers - 1: kv_cache.advance(T) - # Re-assemble the heads and project back to residual stream + # Re-assemble the heads and project back to residual stream (scalar folded into weight) y = y.contiguous().view(B, T, -1) - y = self.c_proj(y) + y = reparam_linear(self.c_proj, y, scalar=self.c_proj_scalar) return y @@ -123,11 +161,12 @@ class MLP(nn.Module): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) + self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd)) def forward(self, x): - x = self.c_fc(x) + x = reparam_linear(self.c_fc, x) x = F.relu(x).square() - x = self.c_proj(x) + x = reparam_linear(self.c_proj, x, scalar=self.c_proj_scalar) return x @@ 
-196,9 +235,9 @@ class GPT(nn.Module): attn.c_q: uniform, std=1/sqrt(n_embd) attn.c_k: uniform, std=1/sqrt(n_embd) attn.c_v: uniform, std=1/sqrt(n_embd) - attn.c_proj: zeros + attn.c_proj: uniform, std=1/sqrt(n_embd) mlp.c_fc: uniform, std=1/sqrt(n_embd) - mlp.c_proj: zeros + mlp.c_proj: uniform, std=1/sqrt(n_embd) """ # Embedding and unembedding @@ -212,22 +251,35 @@ class GPT(nn.Module): torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers torch.nn.init.uniform_(block.attn.c_k.weight, -s, s) torch.nn.init.uniform_(block.attn.c_v.weight, -s, s) - torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero + torch.nn.init.uniform_(block.attn.c_proj.weight, -s, s) torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s) - torch.nn.init.zeros_(block.mlp.c_proj.weight) + torch.nn.init.uniform_(block.mlp.c_proj.weight, -s, s) # Per-layer scalars self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding + # Per-block projection scalars (zero-init for stable training start) + for block in self.transformer.h: + block.attn.c_proj_scalar.fill_(0.0) + block.mlp.c_proj_scalar.fill_(0.0) + if block.attn.v_proj_scalar is not None: + block.attn.v_proj_scalar.fill_(0.0) + if self.transformer.wte.weight.device.type == "cuda": + block.attn.c_proj_scalar.data = block.attn.c_proj_scalar.data.to(torch.bfloat16) + block.mlp.c_proj_scalar.data = block.mlp.c_proj_scalar.data.to(torch.bfloat16) + if block.attn.v_proj_scalar is not None: + block.attn.v_proj_scalar.data = block.attn.v_proj_scalar.data.to(torch.bfloat16) + # Value embeddings (init like c_v: uniform with same std) for ve in self.value_embeds.values(): torch.nn.init.uniform_(ve.weight, -s, s) - # Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral) + # Gate weights init to uniform (avoid zero-norm params under Hyperball, following 
mup) + s_ve_gate = 3 ** 0.5 * 32**-0.5 for block in self.transformer.h: if block.attn.ve_gate is not None: - torch.nn.init.zeros_(block.attn.ve_gate.weight) + torch.nn.init.uniform_(block.attn.ve_gate.weight, -s_ve_gate, s_ve_gate) # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head @@ -302,10 +354,11 @@ class GPT(nn.Module): - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore) """ nparams = sum(p.numel() for p in self.parameters()) - # Exclude non-matmul params: embeddings and per-layer scalars + # Exclude non-matmul params: embeddings, per-layer scalars, and 1D params in blocks value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) + block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1) nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + - self.resid_lambdas.numel() + self.x0_lambdas.numel()) + self.resid_lambdas.numel() + self.x0_lambdas.numel() + block_1d_params) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -332,31 +385,36 @@ class GPT(nn.Module): wte = sum(p.numel() for p in self.transformer.wte.parameters()) value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters()) - transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) + block_matrix_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 2) + block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1) scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() - total = wte + value_embeds + lm_head + transformer_matrices + scalars + total = wte + value_embeds + lm_head + block_matrix_params + block_1d_params + scalars assert total == sum(p.numel() for p in 
self.parameters()), "Parameter count mismatch" return { 'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head, - 'transformer_matrices': transformer_matrices, + 'transformer_matrices': block_matrix_params, + 'norm_and_proj_scalars': block_1d_params, 'scalars': scalars, 'total': total, } - def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): + def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5, norm_lr=0.1, matrix_optimizer="muon"): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() # Separate out all parameters into groups - matrix_params = list(self.transformer.h.parameters()) + block_matrix_params = [p for p in self.transformer.h.parameters() if p.ndim == 2] + block_1d_params = [p for p in self.transformer.h.parameters() if p.ndim == 1] value_embeds_params = list(self.value_embeds.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + all_params_count = (len(block_matrix_params) + len(block_1d_params) + len(embedding_params) + + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)) + assert len(list(self.parameters())) == all_params_count # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -370,14 +428,23 @@ class GPT(nn.Module): dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, 
weight_decay=0.0), dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 + dict(kind='adamw', params=block_1d_params, lr=norm_lr, betas=adam_betas, eps=1e-10, weight_decay=0.0), ] - # Muon groups (matrix params, grouped by shape for stacking) - for shape in sorted({p.shape for p in matrix_params}): - group_params = [p for p in matrix_params if p.shape == shape] - param_groups.append(dict( - kind='muon', params=group_params, lr=matrix_lr, - momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay, - )) + # Matrix params (Muon or Hyperball), grouped by shape for stacking + if matrix_optimizer not in {"muon", "hyperball"}: + raise ValueError(f"Unknown matrix_optimizer: {matrix_optimizer}") + for shape in sorted({p.shape for p in block_matrix_params}): + group_params = [p for p in block_matrix_params if p.shape == shape] + if matrix_optimizer == "hyperball": + param_groups.append(dict( + kind='hyperball', params=group_params, lr=matrix_lr, + momentum=0.95, ns_steps=5, beta2=0.95, + )) + else: + param_groups.append(dict( + kind='muon', params=group_params, lr=matrix_lr, + momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay, + )) Factory = DistMuonAdamW if ddp else MuonAdamW optimizer = Factory(param_groups) diff --git a/nanochat/optim.py b/nanochat/optim.py index 42d862b..a44534a 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -1,6 +1,6 @@ """ A nice and efficient mixed AdamW/Muon Combined Optimizer. -Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon. +Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon/Hyperball. Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed. 
# -----------------------------------------------------------------------------
"""
Hyperball optimizer (MuonH): Muon with scale-invariant updates.
https://github.com/marin-community/marin/blob/main/lib/levanter/src/levanter/optim/muonh.py

The key difference from Muon is that weights maintain constant Frobenius norm
throughout training via the following update rule:

    p_new_intermediate = p - lr * u * ||p|| / ||u||
    p_new = p_new_intermediate / ||p_new_intermediate|| * ||p||

This projects the update onto the hypersphere of constant norm, hence "Hyperball".
Uses variance reduction like Muon, but no cautious weight decay.
"""

@torch.compile(dynamic=False, fullgraph=True)
def hyperball_step_fused(
    stacked_grads: Tensor,          # (K, M, N) - stacked gradients
    stacked_params: Tensor,         # (K, M, N) - stacked parameters
    momentum_buffer: Tensor,        # (K, M, N) - momentum buffer
    second_momentum_buffer: Tensor, # (K, M, 1) or (K, 1, N) - factored second moment
    p_norm: Tensor,                 # (K, 1, 1) - pre-computed Frobenius norm of params (constant)
    momentum_t: Tensor,             # () - 0-D CPU tensor, momentum coefficient
    lr_t: Tensor,                   # () - 0-D CPU tensor, learning rate
    beta2_t: Tensor,                # () - 0-D CPU tensor, beta2 for second moment
    ns_steps: int,                  # 5 - number of Newton-Schulz/Polar Express iterations
    red_dim: int,                   # -1 or -2 - reduction dimension for variance
) -> None:
    """
    Fused Hyperball step: momentum -> polar_express -> variance_reduction -> scale_invariant_update
    All in one compiled graph. Weights maintain constant Frobenius norm.

    Everything is done in place and nothing is returned:
      - stacked_grads is overwritten (it becomes the Nesterov-blended gradient)
      - momentum_buffer and second_momentum_buffer are updated (EMA via lerp_)
      - stacked_params receives the update and is then rescaled to norm p_norm
    """

    # Nesterov momentum
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)  # in place: g aliases stacked_grads

    # Polar express orthogonalization
    X = g.bfloat16()
    # Normalize each matrix by (slightly more than) its Frobenius norm before iterating
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    # Iterate on the smaller Gram matrix: X^T X (N x N) when tall, X X^T (M x M) when wide
    if g.size(-2) > g.size(-1): # Tall matrix
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else: # Wide matrix
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X

    # Variance reduction: scale each row/column by rsqrt of its second-moment EMA,
    # then rescale the whole matrix so that the final update satisfies
    # ||u||_F == ||p||_F (p_norm). Unlike a norm-preserving rescale, the norm of
    # g going in is NOT kept; the target norm is the parameter norm.
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    # v_mean * red_dim_size is the sum of squares along red_dim; multiplying by
    # step_size^2 (constant along red_dim) gives the squared norm of g * step_size
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()  # ||g * step_size||_F per matrix
    p_norm = p_norm.to(v_norm_new.dtype)
    final_scale = step_size * p_norm / v_norm_new.clamp_min(1e-10)
    g = g * final_scale.to(g.dtype)
    u = g.to(stacked_params.dtype)

    # Scale-invariant update: keeps ||p|| constant
    lr = lr_t.to(stacked_params.dtype)
    stacked_params.sub_(lr * u)

    # Project back to hypersphere: p = p * (||p_orig|| / ||p_new||)
    # (a matrix whose cached p_norm is 0 is therefore multiplied by 0 here)
    p_new_norm = stacked_params.norm(dim=(-2, -1), keepdim=True).clamp_min(1e-10)
    stacked_params.mul_(p_norm / p_new_norm)

# -----------------------------------------------------------------------------
# Single GPU version of the MuonAdamW optimizer.
# Used mostly for reference, debugging and testing.
@@ -171,9 +245,10 @@ class MuonAdamW(torch.optim.Optimizer): Arguments: param_groups: List of dicts, each containing: - 'params': List of parameters - - 'kind': 'adamw' or 'muon' + - 'kind': 'adamw', 'muon', or 'hyperball' - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay' - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay' + - For Hyperball groups: 'lr', 'momentum', 'ns_steps', 'beta2' """ def __init__(self, param_groups: list[dict]): super().__init__(param_groups, defaults={}) @@ -190,6 +265,10 @@ class MuonAdamW(torch.optim.Optimizer): self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + # Hyperball tensors + self._hyperball_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._hyperball_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._hyperball_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") def _step_adamw(self, group: dict) -> None: """ @@ -280,6 +359,64 @@ class MuonAdamW(torch.optim.Optimizer): # Copy back to original params torch._foreach_copy_(params, list(stacked_params.unbind(0))) + def _step_hyperball(self, group: dict) -> None: + """ + Hyperball update for all params in the group (stacked for efficiency). + Like Muon, but uses scale-invariant updates that keep weight norms constant. 
+ """ + params: list[Tensor] = group['params'] + if not params: + return + + # Get or create group-level buffers (stored in first param's state for convenience) + p = params[0] + state = self.state[p] + num_params = len(params) + shape, device, dtype = p.shape, p.device, p.dtype + + # Stack grads and params (NOTE: this assumes all params have the same shape) + stacked_grads = torch.stack([p.grad for p in params]) + stacked_params = torch.stack(params) + + # Momentum buffer for every individual parameter + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device) + momentum_buffer = state["momentum_buffer"] + + # Second momentum buffer is factored, either per-row or per-column + if "second_momentum_buffer" not in state: + state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1]) + state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device) + second_momentum_buffer = state["second_momentum_buffer"] + red_dim = -1 if shape[-2] >= shape[-1] else -2 + + # Pre-compute and cache p_norm (Frobenius norm of each param, constant throughout training) + if "p_norm" not in state: + state["p_norm"] = stacked_params.norm(dim=(-2, -1), keepdim=True).clone() + p_norm = state["p_norm"] + + # Fill all the 0-D tensors with current values + self._hyperball_momentum_t.fill_(group["momentum"]) + self._hyperball_lr_t.fill_(group["lr"]) + self._hyperball_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) + + # Single fused kernel: momentum -> polar_express -> variance_reduction -> scale_invariant_update + hyperball_step_fused( + stacked_grads, + stacked_params, + momentum_buffer, + second_momentum_buffer, + p_norm, + self._hyperball_momentum_t, + self._hyperball_lr_t, + self._hyperball_beta2_t, + group["ns_steps"], + red_dim, + ) + + # Copy back to original params + torch._foreach_copy_(params, list(stacked_params.unbind(0))) + 
def _reduce_hyperball(self, group: dict, world_size: int) -> dict:
    """Launch async reduce op for Hyperball group. Returns info dict.

    The group's grads are stacked, zero-padded so the count is a multiple of
    `world_size`, and reduce-scattered (averaged) so each rank receives one
    chunk of `chunk_size` gradients.

    Returns a dict with:
      - 'future': future of the async reduce_scatter
      - 'grad_chunk': buffer that will hold this rank's averaged gradients
      - 'stacked_grads': the padded input buffer (kept so it can be reused later)
      - 'chunk_size': number of parameter slots per rank
    """
    params = group['params']
    # Ceil-divide so every rank gets the same number of slots
    chunk_size = (len(params) + world_size - 1) // world_size
    padded_num_params = chunk_size * world_size
    # params[0] is read unconditionally, so the group is assumed to be non-empty
    p = params[0]
    shape, device, dtype = p.shape, p.device, p.dtype

    # Stack grads and zero-pad to padded_num_params
    grad_stack = torch.stack([p.grad for p in params])
    stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
    stacked_grads[:len(params)].copy_(grad_stack)
    if len(params) < padded_num_params:
        stacked_grads[len(params):].zero_()

    # Reduce_scatter to get this rank's chunk
    grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
    future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()

    return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)

def _compute_hyperball(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
    """Wait for reduce, compute Hyperball updates, launch gather.

    Args:
        group: the Hyperball param group.
        info: the dict returned by `_reduce_hyperball` for this group.
        gather_list: list to which the launched all_gather is appended as a
            dict(future, stacked_params, params).
        rank: this process's rank; it owns params
            [rank * chunk_size, rank * chunk_size + num_owned).
    """
    info['future'].wait()
    params = group['params']
    chunk_size = info['chunk_size']
    grad_chunk = info['grad_chunk']
    p = params[0]
    shape, device, dtype = p.shape, p.device, p.dtype

    # How many params does this rank own?
    # (the last ranks may own fewer than chunk_size, or none, because of padding)
    start_idx = rank * chunk_size
    num_owned = min(chunk_size, max(0, len(params) - start_idx))

    # Get or create group-level state
    # (keyed on the group's first param on every rank, whether or not this rank owns it)
    state = self.state[p]
    if "momentum_buffer" not in state:
        state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
    if "second_momentum_buffer" not in state:
        # Factored: per-row for tall/square matrices, per-column for wide ones
        state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
        state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
    red_dim = -1 if shape[-2] >= shape[-1] else -2

    # Build output buffer for all_gather
    updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)

    if num_owned > 0:
        owned_params = [params[start_idx + i] for i in range(num_owned)]
        stacked_owned = torch.stack(owned_params)

        # Pre-compute and cache p_norm for owned params (constant throughout training)
        if "p_norm" not in state:
            state["p_norm"] = stacked_owned.norm(dim=(-2, -1), keepdim=True).clone()
        p_norm = state["p_norm"]

        # Fill 0-D tensors and run fused kernel
        self._hyperball_momentum_t.fill_(group["momentum"])
        self._hyperball_lr_t.fill_(group["lr"])
        self._hyperball_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Only the first num_owned slots of the chunk-sized buffers are real;
        # the slices are views, so the in-place updates land in the state buffers
        hyperball_step_fused(
            grad_chunk[:num_owned], stacked_owned,
            state["momentum_buffer"][:num_owned],
            state["second_momentum_buffer"][:num_owned],
            p_norm,
            self._hyperball_momentum_t, self._hyperball_lr_t,
            self._hyperball_beta2_t,
            group["ns_steps"],
            red_dim,
        )
        updated_params[:num_owned].copy_(stacked_owned)

    # Padding slots carry zeros into the gather
    if num_owned < chunk_size:
        updated_params[num_owned:].zero_()

    # Reuse stacked_grads buffer for all_gather output
    stacked_params = info["stacked_grads"]
    future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
    gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
#!/bin/bash

# Quickrun: GPT-Gamma + MuonH (Hyperball)
# - Parameterized RMSNorm (learnable gamma)
# - Per-block projection scalars
# - Hyperball or Muon for matrix params
#
# Every setting below can be overridden from the environment.
#
# Examples:
#   bash runs/quickrun_muonh.sh
#   WANDB_RUN=exp1 bash runs/quickrun_muonh.sh
#   FP8=1 FP8_RECIPE=tensorwise bash runs/quickrun_muonh.sh
#   DEPTH=16 bash runs/quickrun_muonh.sh

set -e

# -----------------------------------------------------------------------------
# Config

DEPTH="${DEPTH:-24}"
NUM_SHARDS="${NUM_SHARDS:-370}" # default for d24 @ ratio~11
TARGET_RATIO="${TARGET_RATIO:-12}"
WINDOW_PATTERN="${WINDOW_PATTERN:-SSSL}"
DEVICE_BATCH_SIZE="${DEVICE_BATCH_SIZE:-16}"
TOTAL_BATCH_SIZE="${TOTAL_BATCH_SIZE:-524288}" # -1 = auto-compute optimal (Power Lines paper)

# Default to one process per visible GPU. When nvidia-smi is missing or lists
# nothing the count is 0, and we fall back to a single process.
if [ -z "${NPROC_PER_NODE:-}" ]; then
    NPROC_PER_NODE="$(nvidia-smi -L 2>/dev/null | wc -l | tr -d '[:space:]')"
fi
if [ "$NPROC_PER_NODE" -eq 0 ]; then
    NPROC_PER_NODE=1
fi

# Optimizer
MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}"
SCALAR_LR="${SCALAR_LR:-0.5}"
MATRIX_LR="${MATRIX_LR:-0.02}"
WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}"
MATRIX_WARMDOWN_RATIO="${MATRIX_WARMDOWN_RATIO:-1.0}"

# AdamW
EMBEDDING_LR="${EMBEDDING_LR:-0.3}"
UNEMBEDDING_LR="${UNEMBEDDING_LR:-0.004}"
NORM_LR="${NORM_LR:-0.1}"

# Wandb
# NOTE(review): the default entity is a personal account; set WANDB_ENTITY
# when running under a different account.
export WANDB_ENTITY="${WANDB_ENTITY:-xingyu20}"
export WANDB_PROJECT="${WANDB_PROJECT:-nanochat}"
WANDB_RUN="${WANDB_RUN:-muonh_d${DEPTH}_ratio${TARGET_RATIO}_feb_11_no_gamma}"
MODEL_TAG="${MODEL_TAG:-d${DEPTH}_gamma_muonh}"

# FP8 (default enabled)
FP8="${FP8:-1}"
FP8_ARGS=()
if [ "$FP8" -eq 1 ]; then
    FP8_RECIPE="${FP8_RECIPE:-tensorwise}"
    FP8_ARGS=(--fp8 "--fp8-recipe=${FP8_RECIPE}")
fi

# -----------------------------------------------------------------------------
# Paths and cache

export OMP_NUM_THREADS=1
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
export NANOCHAT_BASE_DIR="$PROJECT_ROOT/cache"
export TORCHINDUCTOR_CACHE_DIR="$NANOCHAT_BASE_DIR/torch_inductor"
export TRITON_CACHE_DIR="$NANOCHAT_BASE_DIR/triton"
export TMPDIR="$NANOCHAT_BASE_DIR/tmp"
mkdir -p "$NANOCHAT_BASE_DIR" "$TORCHINDUCTOR_CACHE_DIR" "$TRITON_CACHE_DIR" "$TMPDIR"

# -----------------------------------------------------------------------------
# Print summary

echo "=============================================="
echo "Quickrun (GPT-Gamma + MuonH)"
echo "=============================================="
echo "Project root: $PROJECT_ROOT"
echo "Cache dir: $NANOCHAT_BASE_DIR"
echo "Depth: $DEPTH"
echo "Num shards: $NUM_SHARDS"
echo "Target ratio: $TARGET_RATIO"
echo "Window pattern: $WINDOW_PATTERN"
echo "Num GPUs: $NPROC_PER_NODE"
echo "Device batch size: $DEVICE_BATCH_SIZE"
echo "Total batch size: $TOTAL_BATCH_SIZE"
echo "Matrix optimizer: $MATRIX_OPTIMIZER"
echo "Matrix LR: $MATRIX_LR"
echo "Norm LR: $NORM_LR"
echo "Adam LRs: embedding=$EMBEDDING_LR, unembedding=$UNEMBEDDING_LR, scalar=$SCALAR_LR"
echo "Warmdown ratio: adam=$WARMDOWN_RATIO, matrix=$MATRIX_WARMDOWN_RATIO"
echo "Wandb run: $WANDB_RUN"
echo "Model tag: $MODEL_TAG"
if [ "$FP8" -eq 1 ]; then
    echo "FP8: enabled ($FP8_RECIPE)"
fi
echo "=============================================="

cd "$PROJECT_ROOT"

# -----------------------------------------------------------------------------
# Python venv

if [ ! -d ".venv" ]; then
    echo "Setting up Python environment..."
    if ! command -v uv &> /dev/null; then
        curl -LsSf https://astral.sh/uv/install.sh | sh
        # The installer places uv in ~/.local/bin by default, which may not
        # be on PATH for this already-running shell.
        export PATH="$HOME/.local/bin:$PATH"
    fi
    uv venv
    uv sync --extra gpu
fi
source .venv/bin/activate

# -----------------------------------------------------------------------------
# Data + tokenizer

echo ""
echo "Downloading $NUM_SHARDS data shards..."
python -m nanochat.dataset -n "$NUM_SHARDS"

echo ""
TOKENIZER_DIR="$NANOCHAT_BASE_DIR/tokenizer"
if [ -f "$TOKENIZER_DIR/token_bytes.pt" ]; then
    echo "Tokenizer already exists at $TOKENIZER_DIR, skipping training."
else
    echo "Training tokenizer..."
    python -m scripts.tok_train --max-chars=500000000 --vocab-size=32768
fi

# -----------------------------------------------------------------------------
# Train

echo ""
echo "Starting pretraining (depth=$DEPTH)..."

# Each element is quoted so that a value containing whitespace or glob
# characters stays a single argument.
TRAIN_ARGS=(
    "--depth=$DEPTH"
    "--run=$WANDB_RUN"
    "--model-tag=$MODEL_TAG"
    "--window-pattern=$WINDOW_PATTERN"
    "--target-param-data-ratio=$TARGET_RATIO"
    "--device-batch-size=$DEVICE_BATCH_SIZE"
    "--total-batch-size=$TOTAL_BATCH_SIZE"
    "--matrix-optimizer=$MATRIX_OPTIMIZER"
    "--matrix-lr=$MATRIX_LR"
    "--warmdown-ratio=$WARMDOWN_RATIO"
    "--matrix-warmdown-ratio=$MATRIX_WARMDOWN_RATIO"
    "--embedding-lr=$EMBEDDING_LR"
    "--unembedding-lr=$UNEMBEDDING_LR"
    "--norm-lr=$NORM_LR"
    "--scalar-lr=$SCALAR_LR"
    "--core-metric-every=${CORE_METRIC_EVERY:-2000}"
    "--sample-every=${SAMPLE_EVERY:--1}"
    "--save-every=${SAVE_EVERY:--1}"
)

if [ "$NPROC_PER_NODE" -gt 1 ]; then
    torchrun --standalone --nproc_per_node="$NPROC_PER_NODE" -m scripts.base_train -- \
        "${TRAIN_ARGS[@]}" "${FP8_ARGS[@]}"
else
    python -m scripts.base_train \
        "${TRAIN_ARGS[@]}" "${FP8_ARGS[@]}"
fi

echo ""
echo "=============================================="
echo "Training complete!"
echo "=============================================="
echo "Checkpoint saved to: $NANOCHAT_BASE_DIR/base_checkpoints/${MODEL_TAG}/"
help="depth of the Transformer model") @@ -61,12 +61,16 @@ parser.add_argument("--total-batch-size", type=int, default=-1, help="total batc parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)") -parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") +parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon/Hyperball)") +parser.add_argument("--matrix-optimizer", type=str, default="muon", choices=["muon", "hyperball"], help="optimizer for matrix parameters") parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") +parser.add_argument("--norm-lr", type=float, default=0.1, help="learning rate for norm/gamma parameters") parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding") parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding") parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") -parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") +parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for AdamW LR warmdown") +parser.add_argument("--matrix-warmup-ratio", type=float, default=0.0, help="ratio of iterations for Muon/Hyperball LR warmup") +parser.add_argument("--matrix-warmdown-ratio", type=float, default=1.0, help="ratio of iterations for Muon/Hyperball LR warmdown") parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR 
as fraction of initial LR") parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") # Evaluation @@ -98,7 +102,8 @@ else: # wandb logging init use_dummy_wandb = args.run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config) +wandb_project = os.environ.get("WANDB_PROJECT", "nanochat") +wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project=wandb_project, name=args.run, config=user_config) # Flash Attention status if HAS_FA3: @@ -300,15 +305,26 @@ if weight_decay_scaled != args.weight_decay: # ----------------------------------------------------------------------------- # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) +matrix_lr_scaled = args.matrix_lr * batch_lr_scale + +# LR data scaling for Hyperball +# We keep the same D_REF here +if args.matrix_optimizer == "hyperball": + D_REF_LR = 10.5 * get_scaling_params(d12_ref) + matrix_lr_scaled = matrix_lr_scaled * (D_REF_LR / target_tokens) ** 0.35 # 0.35 is the exponent of our own empirical power-law fit + print0(f"Scaling hyperball LR from {args.matrix_lr * batch_lr_scale:.6f} to {matrix_lr_scaled:.6f} for token ratio {target_tokens / D_REF:.2f} (T_train = {target_tokens:,} tokens)") + optimizer = model.setup_optimizer( # AdamW hyperparameters unembedding_lr=args.unembedding_lr * batch_lr_scale, embedding_lr=args.embedding_lr * batch_lr_scale, scalar_lr=args.scalar_lr * batch_lr_scale, adam_betas=(args.adam_beta1, args.adam_beta2), - # Muon hyperparameters - matrix_lr=args.matrix_lr * batch_lr_scale, + norm_lr=args.norm_lr * batch_lr_scale, + # Muon/Hyperball hyperparameters + matrix_lr=matrix_lr_scaled, weight_decay=weight_decay_scaled, + matrix_optimizer=args.matrix_optimizer, ) if resuming: @@ -346,19 +362,20 @@ print0(f"Total number of training tokens: {total_tokens:,}") print0(f"Tokens : Scaling params 
ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") -# Learning rate schedule (linear warmup, constant, linear warmdown) -def get_lr_multiplier(it): - warmup_iters = round(args.warmup_ratio * num_iterations) - warmdown_iters = round(args.warmdown_ratio * num_iterations) - if it < warmup_iters: +# Learning rate scheduler (warmup + warmdown, parameterized for separate adam/matrix schedules) +def get_lr_multiplier(it, warmup_ratio, warmdown_ratio, final_lr_frac): + warmup_iters = round(warmup_ratio * num_iterations) + warmdown_iters = round(warmdown_ratio * num_iterations) + if warmup_iters > 0 and it < warmup_iters: return (it + 1) / warmup_iters - elif it <= num_iterations - warmdown_iters: + if warmdown_iters <= 0: return 1.0 - else: - progress = (num_iterations - it) / warmdown_iters - return progress * 1.0 + (1 - progress) * args.final_lr_frac + if it <= num_iterations - warmdown_iters: + return 1.0 + progress = (num_iterations - it) / warmdown_iters + return progress * 1.0 + (1 - progress) * final_lr_frac -# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps) +# Momentum scheduler for matrix optimizer (Muon/Hyperball) def get_muon_momentum(it): frac = min(it / 300, 1) momentum = (1 - frac) * 0.85 + frac * 0.95 @@ -498,13 +515,18 @@ while True: loss.backward() x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward # step the optimizer - lrm = get_lr_multiplier(step) + lrm_adam = get_lr_multiplier(step, args.warmup_ratio, args.warmdown_ratio, args.final_lr_frac) + lrm_matrix = get_lr_multiplier(step, args.matrix_warmup_ratio, args.matrix_warmdown_ratio, args.final_lr_frac) muon_momentum = get_muon_momentum(step) muon_weight_decay = get_weight_decay(step) for group in optimizer.param_groups: - group["lr"] = group["initial_lr"] * lrm - if 
group['kind'] == 'muon': + if group['kind'] in {'muon', 'hyperball'}: + group["lr"] = group["initial_lr"] * lrm_matrix + else: + group["lr"] = group["initial_lr"] * lrm_adam + if group['kind'] in {'muon', 'hyperball'}: group["momentum"] = muon_momentum + if group['kind'] == 'muon': group["weight_decay"] = muon_weight_decay optimizer.step() model.zero_grad(set_to_none=True) @@ -534,14 +556,15 @@ while True: else: eta_str = "" epoch = dataloader_state_dict["epoch"] - print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm(adam)={lrm_adam:.2f}, lrm(matrix)={lrm_matrix:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") if step % 100 == 0: log_data = { "step": step, "total_training_flops": flops_so_far, "total_training_time": total_training_time, "train/loss": debiased_smooth_loss, - "train/lrm": lrm, + "train/lrm_adam": lrm_adam, + "train/lrm_matrix": lrm_matrix, "train/dt": dt, "train/tok_per_sec": tok_per_sec, "train/mfu": mfu, @@ -582,6 +605,8 @@ get_report().log(section="Base model training", data=[ "DDP world size": ddp_world_size, "warmup_ratio": args.warmup_ratio, "warmdown_ratio": args.warmdown_ratio, + "matrix_warmup_ratio": args.matrix_warmup_ratio, + "matrix_warmdown_ratio": args.matrix_warmdown_ratio, "final_lr_frac": args.final_lr_frac, }, { # stats about training outcomes