This commit is contained in:
Xingyu Dang 2026-03-03 18:33:27 -03:00 committed by GitHub
commit 91af7ac7e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 618 additions and 58 deletions

1
.gitignore vendored
View File

@ -9,5 +9,6 @@ eval_bundle/
.env
# Local setup
cache
CLAUDE.md
wandb/

View File

@ -190,6 +190,82 @@ class _Float8Matmul(torch.autograd.Function):
return grad_input, grad_weight
@torch._dynamo.allow_in_graph
class _Float8MatmulND(torch.autograd.Function):
    """FP8 matmul that handles N-D input tensors.

    Same as _Float8Matmul but accepts inputs of any shape (not just 2D).
    Reshaping is done internally so torch.compile sees this as one opaque node,
    preventing the reshaping overhead that occurs when reshapes are external.
    This is specifically for reparam_linear where N-D tensors are common.
    """

    @staticmethod
    def forward(ctx, input, weight):
        """Compute input @ weight.T in FP8.

        input: (..., in_features) tensor of any rank >= 2
        weight: (out_features, in_features)
        Returns a tensor of shape (..., out_features) in input.dtype.
        """
        # Save original shape and flatten batch dimensions to 2D
        orig_shape = input.shape
        ctx.orig_shape = orig_shape
        input_2d = input.reshape(-1, orig_shape[-1])
        ctx.save_for_backward(input_2d, weight)
        # Quantize and matmul (same as _Float8Matmul.forward)
        input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
        output = torch._scaled_mm(
            input_fp8,
            weight_fp8.t(),
            scale_a=input_inv,
            scale_b=weight_inv,
            out_dtype=input.dtype,
            use_fast_accum=True,
        )
        # Reshape back to original batch dims
        return output.reshape(*orig_shape[:-1], output.shape[-1])

    @staticmethod
    def backward(ctx, grad_output):
        """Return (grad_input, grad_weight) via two FP8 GEMMs."""
        input_2d, weight = ctx.saved_tensors
        orig_shape = ctx.orig_shape
        # Flatten grad_output to match input_2d
        grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1])
        # Quantize grad_output once (e5m2 is the gradient dtype) and reuse it
        # for both GEMMs below; the original quantized it twice identically.
        go_fp8, go_inv = _to_fp8(grad_output_flat, torch.float8_e5m2)
        # === GEMM 1: grad_input = grad_output @ weight ===
        w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
        w_col = _to_col_major(w_fp8)
        grad_input_flat = torch._scaled_mm(
            go_fp8,
            w_col,
            scale_a=go_inv,
            scale_b=w_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )
        # Reshape back to original input shape
        grad_input = grad_input_flat.reshape(orig_shape)
        # === GEMM 2: grad_weight = grad_output.T @ input ===
        in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        go_T = go_fp8.t().contiguous()  # _scaled_mm wants a row-major lhs
        in_col = _to_col_major(in_fp8)
        grad_weight = torch._scaled_mm(
            go_T,
            in_col,
            scale_a=go_inv,
            scale_b=in_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )
        return grad_input, grad_weight
class Float8Linear(nn.Linear):
"""Drop-in nn.Linear replacement that does FP8 compute.

View File

@ -2,14 +2,16 @@
GPT model (rewrite, a lot simpler)
Notable features:
- rotary embeddings (and no positional embeddings)
- QK norm
- QK norm (functional RMSNorm, no learnable params)
- untied weights for token embedding and lm_head
- relu^2 activation in MLP
- norm after token embedding
- no learnable params in rmsnorm
- norm after token embedding (functional RMSNorm)
- parameterized RMSNorm in blocks (learnable gamma)
- per-block projection scalars for attention/MLP
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
- Flash Attention 3 integration
- optional Hyperball optimizer for matrix params
"""
from functools import partial
@ -25,6 +27,13 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
# FP8 imports (optional) - minimal custom implementation
try:
from nanochat.fp8 import Float8Linear, _Float8MatmulND
except ImportError:
Float8Linear = None
_Float8MatmulND = None
@dataclass
class GPTConfig:
sequence_len: int = 2048
@ -40,10 +49,37 @@ class GPTConfig:
def norm(x):
# Purely functional rmsnorm with no learnable params
# Purely functional RMSNorm with no learnable params
return F.rms_norm(x, (x.size(-1),))
def reparam_linear(module, x, gamma=None, scalar=None):
    """Linear with gamma/scalar folded into weight. Works with both nn.Linear
    and Float8Linear.

    gamma: RMSNorm learnable weight, folded into input dim of W (w = w * gamma[None, :])
    scalar: projection scalar, folded into output dim of W (w = scalar[:, None] * w)

    For FP8, uses minimal custom _Float8MatmulND which handles N-D tensors internally.
    """
    # Fold the optional per-channel factors into the weight matrix
    weight = module.weight
    if gamma is not None:
        weight = weight * gamma[None, :]
    if scalar is not None:
        weight = scalar[:, None] * weight
    # BF16 path: plain functional linear with the reparameterized weight
    use_fp8 = Float8Linear is not None and isinstance(module, Float8Linear)
    if not use_fp8:
        return F.linear(x, weight)
    # FP8 path: _Float8MatmulND reshapes N-D inputs internally, so
    # torch.compile sees one opaque node (no external reshape overhead)
    if torch.is_autocast_enabled():
        # Float8Linear expects the input already cast to the autocast dtype
        x = x.to(torch.get_autocast_gpu_dtype())
    out = _Float8MatmulND.apply(x, weight)
    if module.bias is not None:
        out = out + module.bias.to(out.dtype)
    return out
def has_ve(layer_idx, n_layer):
    """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
    last_layer = n_layer - 1
    # Same parity as the last layer => VE present (so the last layer always has one)
    return (last_layer - layer_idx) % 2 == 0
@ -72,20 +108,22 @@ class CausalSelfAttention(nn.Module):
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
self.ve_gate_channels = 32
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
self.v_proj_scalar = nn.Parameter(torch.zeros(self.n_kv_head)) if has_ve(layer_idx, config.n_layer) else None
self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd))
def forward(self, x, ve, cos_sin, window_size, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
# Project the input to get queries, keys, and values (gamma folded into weights)
# Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
q = reparam_linear(self.c_q, x).view(B, T, self.n_head, self.head_dim)
k = reparam_linear(self.c_k, x).view(B, T, self.n_kv_head, self.head_dim)
v = reparam_linear(self.c_v, x).view(B, T, self.n_kv_head, self.head_dim)
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
if ve is not None:
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
gate = 2 * torch.sigmoid(reparam_linear(self.ve_gate, x[..., :self.ve_gate_channels], scalar=self.v_proj_scalar)) # (B, T, n_kv_head), range (0, 2)
v = v + gate.unsqueeze(-1) * ve
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
@ -112,9 +150,9 @@ class CausalSelfAttention(nn.Module):
if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T)
# Re-assemble the heads and project back to residual stream
# Re-assemble the heads and project back to residual stream (scalar folded into weight)
y = y.contiguous().view(B, T, -1)
y = self.c_proj(y)
y = reparam_linear(self.c_proj, y, scalar=self.c_proj_scalar)
return y
@ -123,11 +161,12 @@ class MLP(nn.Module):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd))
def forward(self, x):
x = self.c_fc(x)
x = reparam_linear(self.c_fc, x)
x = F.relu(x).square()
x = self.c_proj(x)
x = reparam_linear(self.c_proj, x, scalar=self.c_proj_scalar)
return x
@ -196,9 +235,9 @@ class GPT(nn.Module):
attn.c_q: uniform, std=1/sqrt(n_embd)
attn.c_k: uniform, std=1/sqrt(n_embd)
attn.c_v: uniform, std=1/sqrt(n_embd)
attn.c_proj: zeros
attn.c_proj: uniform, std=1/sqrt(n_embd)
mlp.c_fc: uniform, std=1/sqrt(n_embd)
mlp.c_proj: zeros
mlp.c_proj: uniform, std=1/sqrt(n_embd)
"""
# Embedding and unembedding
@ -212,22 +251,35 @@ class GPT(nn.Module):
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
torch.nn.init.uniform_(block.attn.c_proj.weight, -s, s)
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
torch.nn.init.zeros_(block.mlp.c_proj.weight)
torch.nn.init.uniform_(block.mlp.c_proj.weight, -s, s)
# Per-layer scalars
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
# Per-block projection scalars (zero-init for stable training start)
for block in self.transformer.h:
block.attn.c_proj_scalar.fill_(0.0)
block.mlp.c_proj_scalar.fill_(0.0)
if block.attn.v_proj_scalar is not None:
block.attn.v_proj_scalar.fill_(0.0)
if self.transformer.wte.weight.device.type == "cuda":
block.attn.c_proj_scalar.data = block.attn.c_proj_scalar.data.to(torch.bfloat16)
block.mlp.c_proj_scalar.data = block.mlp.c_proj_scalar.data.to(torch.bfloat16)
if block.attn.v_proj_scalar is not None:
block.attn.v_proj_scalar.data = block.attn.v_proj_scalar.data.to(torch.bfloat16)
# Value embeddings (init like c_v: uniform with same std)
for ve in self.value_embeds.values():
torch.nn.init.uniform_(ve.weight, -s, s)
# Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
# Gate weights init to uniform (avoid zero-norm params under Hyperball, following mup)
s_ve_gate = 3 ** 0.5 * 32**-0.5
for block in self.transformer.h:
if block.attn.ve_gate is not None:
torch.nn.init.zeros_(block.attn.ve_gate.weight)
torch.nn.init.uniform_(block.attn.ve_gate.weight, -s_ve_gate, s_ve_gate)
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
@ -302,10 +354,11 @@ class GPT(nn.Module):
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
"""
nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars
# Exclude non-matmul params: embeddings, per-layer scalars, and 1D params in blocks
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1)
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel())
self.resid_lambdas.numel() + self.x0_lambdas.numel() + block_1d_params)
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
@ -332,31 +385,36 @@ class GPT(nn.Module):
wte = sum(p.numel() for p in self.transformer.wte.parameters())
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
block_matrix_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 2)
block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1)
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
total = wte + value_embeds + lm_head + transformer_matrices + scalars
total = wte + value_embeds + lm_head + block_matrix_params + block_1d_params + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
'wte': wte,
'value_embeds': value_embeds,
'lm_head': lm_head,
'transformer_matrices': transformer_matrices,
'transformer_matrices': block_matrix_params,
'norm_and_proj_scalars': block_1d_params,
'scalars': scalars,
'total': total,
}
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5, norm_lr=0.1, matrix_optimizer="muon"):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into groups
matrix_params = list(self.transformer.h.parameters())
block_matrix_params = [p for p in self.transformer.h.parameters() if p.ndim == 2]
block_1d_params = [p for p in self.transformer.h.parameters() if p.ndim == 1]
value_embeds_params = list(self.value_embeds.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
all_params_count = (len(block_matrix_params) + len(block_1d_params) + len(embedding_params) +
len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params))
assert len(list(self.parameters())) == all_params_count
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -370,14 +428,23 @@ class GPT(nn.Module):
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
dict(kind='adamw', params=block_1d_params, lr=norm_lr, betas=adam_betas, eps=1e-10, weight_decay=0.0),
]
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):
group_params = [p for p in matrix_params if p.shape == shape]
param_groups.append(dict(
kind='muon', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
))
# Matrix params (Muon or Hyperball), grouped by shape for stacking
if matrix_optimizer not in {"muon", "hyperball"}:
raise ValueError(f"Unknown matrix_optimizer: {matrix_optimizer}")
for shape in sorted({p.shape for p in block_matrix_params}):
group_params = [p for p in block_matrix_params if p.shape == shape]
if matrix_optimizer == "hyperball":
param_groups.append(dict(
kind='hyperball', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.95,
))
else:
param_groups.append(dict(
kind='muon', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
))
Factory = DistMuonAdamW if ddp else MuonAdamW
optimizer = Factory(param_groups)

View File

@ -1,6 +1,6 @@
"""
A nice and efficient mixed AdamW/Muon Combined Optimizer.
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon/Hyperball.
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
Adapted from: https://github.com/KellerJordan/modded-nanogpt
@ -145,6 +145,80 @@ def muon_step_fused(
mask = (g * stacked_params) >= 0
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
# -----------------------------------------------------------------------------
"""
Hyperball optimizer (MuonH): Muon with scale-invariant updates.
https://github.com/marin-community/marin/blob/main/lib/levanter/src/levanter/optim/muonh.py
The key difference from Muon is that weights maintain constant Frobenius norm
throughout training via the following update rule:
p_new_intermediate = p - lr * u * ||p|| / ||u||
p_new = p_new_intermediate / ||p_new_intermediate|| * ||p||
This projects the update onto the hypersphere of constant norm, hence "Hyperball".
Uses variance reduction like Muon, but no cautious weight decay.
"""
@torch.compile(dynamic=False, fullgraph=True)
def hyperball_step_fused(
    stacked_grads: Tensor,  # (K, M, N) - stacked gradients
    stacked_params: Tensor,  # (K, M, N) - stacked parameters
    momentum_buffer: Tensor,  # (K, M, N) - momentum buffer
    second_momentum_buffer: Tensor,  # (K, M, 1) or (K, 1, N) - factored second moment
    p_norm: Tensor,  # (K, 1, 1) - pre-computed Frobenius norm of params (constant)
    momentum_t: Tensor,  # () - 0-D CPU tensor, momentum coefficient
    lr_t: Tensor,  # () - 0-D CPU tensor, learning rate
    beta2_t: Tensor,  # () - 0-D CPU tensor, beta2 for second moment
    ns_steps: int,  # 5 - number of Newton-Schulz/Polar Express iterations
    red_dim: int,  # -1 or -2 - reduction dimension for variance
) -> None:
    """
    Fused Hyperball step: momentum -> polar_express -> variance_reduction -> scale_invariant_update
    All in one compiled graph. Weights maintain constant Frobenius norm.

    Mutates stacked_params, momentum_buffer, second_momentum_buffer in place.
    The 0-D CPU tensors carry hyperparameters so changing them does not
    trigger a recompile of the fullgraph-compiled kernel.
    """
    # Nesterov momentum: update the buffer, then blend grad toward the buffer
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization (approximate polar factor of g, in bf16)
    X = g.bfloat16()
    # Pre-scale so the iteration's convergence bounds hold (1.02 safety margin)
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):  # Tall matrix
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:  # Wide matrix
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # Variance reduction (note: this preserves ||g||_F, so ||u||_F == ||g||_F == v_norm)
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    # Norm of (g * step_size), computed analytically from the factored stats
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    p_norm = p_norm.to(v_norm_new.dtype)
    # Rescale update so its Frobenius norm equals the (constant) param norm
    final_scale = step_size * p_norm / v_norm_new.clamp_min(1e-10)
    g = g * final_scale.to(g.dtype)
    u = g.to(stacked_params.dtype)
    # Scale-invariant update: keeps ||p|| constant
    lr = lr_t.to(stacked_params.dtype)
    stacked_params.sub_(lr * u)
    # Project back to hypersphere: p = p * (||p_orig|| / ||p_new||)
    p_new_norm = stacked_params.norm(dim=(-2, -1), keepdim=True).clamp_min(1e-10)
    stacked_params.mul_(p_norm / p_new_norm)
# -----------------------------------------------------------------------------
# Single GPU version of the MuonAdamW optimizer.
# Used mostly for reference, debugging and testing.
@ -171,9 +245,10 @@ class MuonAdamW(torch.optim.Optimizer):
Arguments:
param_groups: List of dicts, each containing:
- 'params': List of parameters
- 'kind': 'adamw' or 'muon'
- 'kind': 'adamw', 'muon', or 'hyperball'
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
- For Hyperball groups: 'lr', 'momentum', 'ns_steps', 'beta2'
"""
def __init__(self, param_groups: list[dict]):
super().__init__(param_groups, defaults={})
@ -190,6 +265,10 @@ class MuonAdamW(torch.optim.Optimizer):
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
# Hyperball tensors
self._hyperball_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._hyperball_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._hyperball_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
def _step_adamw(self, group: dict) -> None:
"""
@ -280,6 +359,64 @@ class MuonAdamW(torch.optim.Optimizer):
# Copy back to original params
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
    def _step_hyperball(self, group: dict) -> None:
        """
        Hyperball update for all params in the group (stacked for efficiency).
        Like Muon, but uses scale-invariant updates that keep weight norms constant.

        Reads group keys: 'params', 'momentum', 'lr', 'beta2', 'ns_steps'.
        """
        params: list[Tensor] = group['params']
        if not params:
            return
        # Get or create group-level buffers (stored in first param's state for convenience)
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        # Stack grads and params (NOTE: this assumes all params have the same shape)
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)
        # Momentum buffer for every individual parameter
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        momentum_buffer = state["momentum_buffer"]
        # Second momentum buffer is factored, either per-row or per-column
        # (reduce over the longer dimension, keep stats along the shorter one)
        if "second_momentum_buffer" not in state:
            state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        second_momentum_buffer = state["second_momentum_buffer"]
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        # Pre-compute and cache p_norm (Frobenius norm of each param, constant throughout training)
        if "p_norm" not in state:
            state["p_norm"] = stacked_params.norm(dim=(-2, -1), keepdim=True).clone()
        p_norm = state["p_norm"]
        # Fill all the 0-D tensors with current values (avoids recompiles of the fused kernel)
        self._hyperball_momentum_t.fill_(group["momentum"])
        self._hyperball_lr_t.fill_(group["lr"])
        self._hyperball_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Single fused kernel: momentum -> polar_express -> variance_reduction -> scale_invariant_update
        hyperball_step_fused(
            stacked_grads,
            stacked_params,
            momentum_buffer,
            second_momentum_buffer,
            p_norm,
            self._hyperball_momentum_t,
            self._hyperball_lr_t,
            self._hyperball_beta2_t,
            group["ns_steps"],
            red_dim,
        )
        # Copy back to original params (torch.stack made a copy, so the
        # in-place update above landed in stacked_params, not in params)
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))
@torch.no_grad()
def step(self):
for group in self.param_groups:
@ -287,6 +424,8 @@ class MuonAdamW(torch.optim.Optimizer):
self._step_adamw(group)
elif group['kind'] == 'muon':
self._step_muon(group)
elif group['kind'] == 'hyperball':
self._step_hyperball(group)
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
@ -348,9 +487,10 @@ class DistMuonAdamW(torch.optim.Optimizer):
Arguments:
param_groups: List of dicts, each containing:
- 'params': List of parameters
- 'kind': 'adamw' or 'muon'
- 'kind': 'adamw', 'muon', or 'hyperball'
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
- For Hyperball groups: 'lr', 'momentum', 'ns_steps', 'beta2'
"""
def __init__(self, param_groups: list[dict]):
super().__init__(param_groups, defaults={})
@ -365,6 +505,10 @@ class DistMuonAdamW(torch.optim.Optimizer):
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
# Hyperball tensors
self._hyperball_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._hyperball_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._hyperball_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
@ -496,6 +640,85 @@ class DistMuonAdamW(torch.optim.Optimizer):
future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
    def _reduce_hyperball(self, group: dict, world_size: int) -> dict:
        """Launch async reduce op for Hyperball group. Returns info dict."""
        params = group['params']
        # Each rank will own a contiguous chunk of ceil(len(params) / world_size) params
        chunk_size = (len(params) + world_size - 1) // world_size
        padded_num_params = chunk_size * world_size
        p = params[0]
        shape, device, dtype = p.shape, p.device, p.dtype
        # Stack grads and zero-pad to padded_num_params
        # (padding keeps the reduce_scatter chunk sizes uniform across ranks)
        grad_stack = torch.stack([p.grad for p in params])
        stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
        stacked_grads[:len(params)].copy_(grad_stack)
        if len(params) < padded_num_params:
            stacked_grads[len(params):].zero_()
        # Reduce_scatter to get this rank's chunk (averaged across ranks, async)
        grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
        future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
        # stacked_grads is returned so _compute_hyperball can reuse it as the all_gather output buffer
        return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
    def _compute_hyperball(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
        """Wait for reduce, compute Hyperball updates, launch gather.

        Each rank updates only the chunk of params it owns, then an async
        all_gather (appended to gather_list) redistributes all updated params.
        """
        info['future'].wait()
        params = group['params']
        chunk_size = info['chunk_size']
        grad_chunk = info['grad_chunk']
        p = params[0]
        shape, device, dtype = p.shape, p.device, p.dtype
        # How many params does this rank own? (last rank(s) may own fewer due to padding)
        start_idx = rank * chunk_size
        num_owned = min(chunk_size, max(0, len(params) - start_idx))
        # Get or create group-level state (stored in first param's state for convenience)
        state = self.state[p]
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
        if "second_momentum_buffer" not in state:
            # Factored second moment: reduce over the longer dim, keep stats on the shorter
            state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        # Build output buffer for all_gather
        updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
        if num_owned > 0:
            owned_params = [params[start_idx + i] for i in range(num_owned)]
            stacked_owned = torch.stack(owned_params)
            # Pre-compute and cache p_norm for owned params (constant throughout training)
            if "p_norm" not in state:
                state["p_norm"] = stacked_owned.norm(dim=(-2, -1), keepdim=True).clone()
            p_norm = state["p_norm"]
            # Fill 0-D tensors and run fused kernel (CPU tensors avoid recompiles)
            self._hyperball_momentum_t.fill_(group["momentum"])
            self._hyperball_lr_t.fill_(group["lr"])
            self._hyperball_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
            hyperball_step_fused(
                grad_chunk[:num_owned], stacked_owned,
                state["momentum_buffer"][:num_owned],
                state["second_momentum_buffer"][:num_owned],
                p_norm,
                self._hyperball_momentum_t, self._hyperball_lr_t,
                self._hyperball_beta2_t,
                group["ns_steps"],
                red_dim,
            )
            updated_params[:num_owned].copy_(stacked_owned)
        if num_owned < chunk_size:
            # Zero the padding slots so the all_gather output is well-defined
            updated_params[num_owned:].zero_()
        # Reuse stacked_grads buffer for all_gather output
        stacked_params = info["stacked_grads"]
        future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
        gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
def _finish_gathers(self, gather_list: list) -> None:
"""Wait for all gathers and copy Muon params back."""
for info in gather_list:
@ -516,6 +739,8 @@ class DistMuonAdamW(torch.optim.Optimizer):
reduce_infos.append(self._reduce_adamw(group, world_size))
elif group['kind'] == 'muon':
reduce_infos.append(self._reduce_muon(group, world_size))
elif group['kind'] == 'hyperball':
reduce_infos.append(self._reduce_hyperball(group, world_size))
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
@ -526,6 +751,8 @@ class DistMuonAdamW(torch.optim.Optimizer):
self._compute_adamw(group, info, gather_list, rank, world_size)
elif group['kind'] == 'muon':
self._compute_muon(group, info, gather_list, rank)
elif group['kind'] == 'hyperball':
self._compute_hyperball(group, info, gather_list, rank)
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")

164
runs/quickrun_muonh.sh Executable file
View File

@ -0,0 +1,164 @@
#!/bin/bash
# Quickrun: GPT-Gamma + MuonH (Hyperball)
# - Parameterized RMSNorm (learnable gamma)
# - Per-block projection scalars
# - Hyperball or Muon for matrix params
#
# Examples:
# bash runs/quickrun_muonh.sh
# WANDB_RUN=exp1 bash runs/quickrun_muonh.sh
# FP8=1 FP8_RECIPE=tensorwise bash runs/quickrun_muonh.sh
# DEPTH=16 bash runs/quickrun_muonh.sh
set -e
# -----------------------------------------------------------------------------
# Config
DEPTH="${DEPTH:-24}"
NUM_SHARDS="${NUM_SHARDS:-370}" # default for d24 @ ratio~11
TARGET_RATIO="${TARGET_RATIO:-12}"
WINDOW_PATTERN="${WINDOW_PATTERN:-SSSL}"
DEVICE_BATCH_SIZE="${DEVICE_BATCH_SIZE:-16}"
TOTAL_BATCH_SIZE="${TOTAL_BATCH_SIZE:-524288}" # -1 = auto-compute optimal (Power Lines paper)
NPROC_PER_NODE="${NPROC_PER_NODE:-$(nvidia-smi -L 2>/dev/null | wc -l || echo 1)}"
if [ "$NPROC_PER_NODE" -eq 0 ]; then
NPROC_PER_NODE=1
fi
# Optimizer
MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}"
SCALAR_LR="${SCALAR_LR:-0.5}"
MATRIX_LR="${MATRIX_LR:-0.02}"
WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}"
MATRIX_WARMDOWN_RATIO="${MATRIX_WARMDOWN_RATIO:-1.0}"
# AdamW
EMBEDDING_LR="${EMBEDDING_LR:-0.3}"
UNEMBEDDING_LR="${UNEMBEDDING_LR:-0.004}"
NORM_LR="${NORM_LR:-0.1}"
# Wandb
export WANDB_ENTITY="${WANDB_ENTITY:-xingyu20}"
export WANDB_PROJECT="${WANDB_PROJECT:-nanochat}"
WANDB_RUN="${WANDB_RUN:-muonh_d${DEPTH}_ratio${TARGET_RATIO}_feb_11_no_gamma}"
MODEL_TAG="${MODEL_TAG:-d${DEPTH}_gamma_muonh}"
# FP8 (default enabled)
FP8="${FP8:-1}"
FP8_ARGS=""
if [ "${FP8:-0}" -eq 1 ]; then
FP8_RECIPE="${FP8_RECIPE:-tensorwise}"
FP8_ARGS="--fp8 --fp8-recipe=${FP8_RECIPE}"
fi
# -----------------------------------------------------------------------------
# Paths and cache
export OMP_NUM_THREADS=1
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
export NANOCHAT_BASE_DIR="$PROJECT_ROOT/cache"
export TORCHINDUCTOR_CACHE_DIR="$NANOCHAT_BASE_DIR/torch_inductor"
export TRITON_CACHE_DIR="$NANOCHAT_BASE_DIR/triton"
export TMPDIR="$NANOCHAT_BASE_DIR/tmp"
mkdir -p "$NANOCHAT_BASE_DIR" "$TORCHINDUCTOR_CACHE_DIR" "$TRITON_CACHE_DIR" "$TMPDIR"
# -----------------------------------------------------------------------------
# Print summary
echo "=============================================="
echo "Quickrun (GPT-Gamma + MuonH)"
echo "=============================================="
echo "Project root: $PROJECT_ROOT"
echo "Cache dir: $NANOCHAT_BASE_DIR"
echo "Depth: $DEPTH"
echo "Num shards: $NUM_SHARDS"
echo "Target ratio: $TARGET_RATIO"
echo "Window pattern: $WINDOW_PATTERN"
echo "Num GPUs: $NPROC_PER_NODE"
echo "Device batch size: $DEVICE_BATCH_SIZE"
echo "Total batch size: $TOTAL_BATCH_SIZE"
echo "Matrix optimizer: $MATRIX_OPTIMIZER"
echo "Matrix LR: $MATRIX_LR"
echo "Norm LR: $NORM_LR"
echo "Adam LRs: embedding=$EMBEDDING_LR, unembedding=$UNEMBEDDING_LR, scalar=$SCALAR_LR"
echo "Warmdown ratio: adam=$WARMDOWN_RATIO, matrix=$MATRIX_WARMDOWN_RATIO"
echo "Wandb run: $WANDB_RUN"
echo "Model tag: $MODEL_TAG"
if [ "${FP8:-0}" -eq 1 ]; then
echo "FP8: enabled ($FP8_RECIPE)"
fi
echo "=============================================="
cd "$PROJECT_ROOT"
# -----------------------------------------------------------------------------
# Python venv
if [ ! -d ".venv" ]; then
echo "Setting up Python environment..."
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
uv venv
uv sync --extra gpu
fi
source .venv/bin/activate
# -----------------------------------------------------------------------------
# Data + tokenizer
echo ""
echo "Downloading $NUM_SHARDS data shards..."
python -m nanochat.dataset -n "$NUM_SHARDS"
echo ""
TOKENIZER_DIR="$NANOCHAT_BASE_DIR/tokenizer"
if [ -f "$TOKENIZER_DIR/token_bytes.pt" ]; then
echo "Tokenizer already exists at $TOKENIZER_DIR, skipping training."
else
echo "Training tokenizer..."
python -m scripts.tok_train --max-chars=500000000 --vocab-size=32768
fi
# -----------------------------------------------------------------------------
# Train
echo ""
echo "Starting pretraining (depth=$DEPTH)..."
TRAIN_ARGS=(
--depth=$DEPTH
--run=$WANDB_RUN
--model-tag=$MODEL_TAG
--window-pattern=$WINDOW_PATTERN
--target-param-data-ratio=$TARGET_RATIO
--device-batch-size=$DEVICE_BATCH_SIZE
--total-batch-size=$TOTAL_BATCH_SIZE
--matrix-optimizer=$MATRIX_OPTIMIZER
--matrix-lr=$MATRIX_LR
--warmdown-ratio=$WARMDOWN_RATIO
--matrix-warmdown-ratio=$MATRIX_WARMDOWN_RATIO
--embedding-lr=$EMBEDDING_LR
--unembedding-lr=$UNEMBEDDING_LR
--norm-lr=$NORM_LR
--scalar-lr=$SCALAR_LR
--core-metric-every=${CORE_METRIC_EVERY:-2000}
--sample-every=${SAMPLE_EVERY:--1}
--save-every=${SAVE_EVERY:--1}
)
if [ "$NPROC_PER_NODE" -gt 1 ]; then
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
"${TRAIN_ARGS[@]}" $FP8_ARGS
else
python -m scripts.base_train \
"${TRAIN_ARGS[@]}" $FP8_ARGS
fi
echo ""
echo "=============================================="
echo "Training complete!"
echo "=============================================="
echo "Checkpoint saved to: $NANOCHAT_BASE_DIR/base_checkpoints/${MODEL_TAG}/"

View File

@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# FP8 training
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU); uses minimal custom fp8 module")
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
@ -61,12 +61,16 @@ parser.add_argument("--total-batch-size", type=int, default=-1, help="total batc
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon/Hyperball)")
parser.add_argument("--matrix-optimizer", type=str, default="muon", choices=["muon", "hyperball"], help="optimizer for matrix parameters")
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--norm-lr", type=float, default=0.1, help="learning rate for norm/gamma parameters")
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for AdamW LR warmdown")
parser.add_argument("--matrix-warmup-ratio", type=float, default=0.0, help="ratio of iterations for Muon/Hyperball LR warmup")
parser.add_argument("--matrix-warmdown-ratio", type=float, default=1.0, help="ratio of iterations for Muon/Hyperball LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
@ -98,7 +102,8 @@ else:
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
wandb_project = os.environ.get("WANDB_PROJECT", "nanochat")
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project=wandb_project, name=args.run, config=user_config)
# Flash Attention status
if HAS_FA3:
@ -300,15 +305,26 @@ if weight_decay_scaled != args.weight_decay:
# -----------------------------------------------------------------------------
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
matrix_lr_scaled = args.matrix_lr * batch_lr_scale
# LR data scaling for Hyperball
# We keep the same D_REF here
if args.matrix_optimizer == "hyperball":
D_REF_LR = 10.5 * get_scaling_params(d12_ref)
matrix_lr_scaled = matrix_lr_scaled * (D_REF_LR / target_tokens) ** 0.35 # 0.35 is the exponent for the power law fit by ourselves
print0(f"Scaling hyperball LR from {args.matrix_lr * batch_lr_scale:.6f} to {matrix_lr_scaled:.6f} for token ratio {target_tokens / D_REF:.2f} (T_train = {target_tokens:,} tokens)")
optimizer = model.setup_optimizer(
# AdamW hyperparameters
unembedding_lr=args.unembedding_lr * batch_lr_scale,
embedding_lr=args.embedding_lr * batch_lr_scale,
scalar_lr=args.scalar_lr * batch_lr_scale,
adam_betas=(args.adam_beta1, args.adam_beta2),
# Muon hyperparameters
matrix_lr=args.matrix_lr * batch_lr_scale,
norm_lr=args.norm_lr * batch_lr_scale,
# Muon/Hyperball hyperparameters
matrix_lr=matrix_lr_scaled,
weight_decay=weight_decay_scaled,
matrix_optimizer=args.matrix_optimizer,
)
if resuming:
@ -346,19 +362,20 @@ print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# Learning rate schedule (linear warmup, constant, linear warmdown)
def get_lr_multiplier(it):
warmup_iters = round(args.warmup_ratio * num_iterations)
warmdown_iters = round(args.warmdown_ratio * num_iterations)
if it < warmup_iters:
# Learning rate scheduler (warmup + warmdown, parameterized for separate adam/matrix schedules)
def get_lr_multiplier(it, warmup_ratio, warmdown_ratio, final_lr_frac):
warmup_iters = round(warmup_ratio * num_iterations)
warmdown_iters = round(warmdown_ratio * num_iterations)
if warmup_iters > 0 and it < warmup_iters:
return (it + 1) / warmup_iters
elif it <= num_iterations - warmdown_iters:
if warmdown_iters <= 0:
return 1.0
else:
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * args.final_lr_frac
if it <= num_iterations - warmdown_iters:
return 1.0
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * final_lr_frac
# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
# Momentum scheduler for matrix optimizer (Muon/Hyperball)
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
@ -498,13 +515,18 @@ while True:
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# step the optimizer
lrm = get_lr_multiplier(step)
lrm_adam = get_lr_multiplier(step, args.warmup_ratio, args.warmdown_ratio, args.final_lr_frac)
lrm_matrix = get_lr_multiplier(step, args.matrix_warmup_ratio, args.matrix_warmdown_ratio, args.final_lr_frac)
muon_momentum = get_muon_momentum(step)
muon_weight_decay = get_weight_decay(step)
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
if group['kind'] == 'muon':
if group['kind'] in {'muon', 'hyperball'}:
group["lr"] = group["initial_lr"] * lrm_matrix
else:
group["lr"] = group["initial_lr"] * lrm_adam
if group['kind'] in {'muon', 'hyperball'}:
group["momentum"] = muon_momentum
if group['kind'] == 'muon':
group["weight_decay"] = muon_weight_decay
optimizer.step()
model.zero_grad(set_to_none=True)
@ -534,14 +556,15 @@ while True:
else:
eta_str = ""
epoch = dataloader_state_dict["epoch"]
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm(adam)={lrm_adam:.2f}, lrm(matrix)={lrm_matrix:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
if step % 100 == 0:
log_data = {
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"train/loss": debiased_smooth_loss,
"train/lrm": lrm,
"train/lrm_adam": lrm_adam,
"train/lrm_matrix": lrm_matrix,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
@ -582,6 +605,8 @@ get_report().log(section="Base model training", data=[
"DDP world size": ddp_world_size,
"warmup_ratio": args.warmup_ratio,
"warmdown_ratio": args.warmdown_ratio,
"matrix_warmup_ratio": args.matrix_warmup_ratio,
"matrix_warmdown_ratio": args.matrix_warmdown_ratio,
"final_lr_frac": args.final_lr_frac,
},
{ # stats about training outcomes