Add muonh model and quickrun

This commit is contained in:
dangxingyu 2026-02-03 20:14:51 -05:00
parent d510b1385b
commit e28d4ead22
4 changed files with 464 additions and 33 deletions

View File

@ -2,14 +2,16 @@
GPT model (rewrite, a lot simpler)
Notable features:
- rotary embeddings (and no positional embeddings)
- QK norm
- QK norm (functional RMSNorm, no learnable params)
- untied weights for token embedding and lm_head
- relu^2 activation in MLP
- norm after token embedding
- no learnable params in rmsnorm
- norm after token embedding (functional RMSNorm)
- parameterized RMSNorm in blocks (learnable gamma)
- per-block projection scalars for attention/MLP
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
- Flash Attention 3 integration
- optional Hyperball optimizer for matrix params
"""
from functools import partial
@ -40,7 +42,7 @@ class GPTConfig:
def norm(x):
    # Purely functional RMSNorm with no learnable params.
    # Normalizes over the last dimension of x only.
    return F.rms_norm(x, (x.size(-1),))
@ -72,6 +74,7 @@ class CausalSelfAttention(nn.Module):
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
self.ve_gate_channels = 32
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd))
def forward(self, x, ve, cos_sin, window_size, kv_cache):
B, T, C = x.size()
@ -115,6 +118,7 @@ class CausalSelfAttention(nn.Module):
# Re-assemble the heads and project back to residual stream
y = y.contiguous().view(B, T, -1)
y = self.c_proj(y)
y = y * self.c_proj_scalar
return y
@ -123,23 +127,27 @@ class MLP(nn.Module):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd))
def forward(self, x):
    """Expand -> ReLU^2 activation -> project back, then scale by the learnable per-channel projection scalar."""
    hidden = self.c_fc(x)
    activated = F.relu(hidden).square()
    projected = self.c_proj(activated)
    return projected * self.c_proj_scalar
class Block(nn.Module):
    """One transformer block: pre-norm attention and MLP sublayers with residual connections.

    Uses parameterized nn.RMSNorm (learnable gamma) before each sublayer.
    Fix: the diff residue contained BOTH the pre-change forward (functional
    `norm(x)`) and the post-change forward (`self.attn_norm(x)`); only the
    post-change version, matching the RMSNorm modules created in __init__,
    is kept.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn_norm = nn.RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp_norm = nn.RMSNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        # Pre-norm residual: x + attn(norm(x)), then x + mlp(norm(x))
        x = x + self.attn(self.attn_norm(x), ve, cos_sin, window_size, kv_cache)
        x = x + self.mlp(self.mlp_norm(x))
        return x
@ -196,15 +204,21 @@ class GPT(nn.Module):
attn.c_q: uniform, std=1/sqrt(n_embd)
attn.c_k: uniform, std=1/sqrt(n_embd)
attn.c_v: uniform, std=1/sqrt(n_embd)
attn.c_proj: zeros
attn.c_proj: uniform, std=1/sqrt(n_embd)
mlp.c_fc: uniform, std=1/sqrt(n_embd)
mlp.c_proj: zeros
mlp.c_proj: uniform, std=1/sqrt(n_embd)
nn.RMSNorm weight: ones (via explicit init below)
"""
# Embedding and unembedding
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
# nn.RMSNorm weight parameters: init to ones (must be explicit due to meta device)
for module in self.modules():
if isinstance(module, nn.RMSNorm):
module.weight.fill_(1.0)
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
n_embd = self.config.n_embd
s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
@ -212,22 +226,38 @@ class GPT(nn.Module):
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
torch.nn.init.uniform_(block.attn.c_proj.weight, -s, s)
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
torch.nn.init.zeros_(block.mlp.c_proj.weight)
torch.nn.init.uniform_(block.mlp.c_proj.weight, -s, s)
# Per-layer scalars
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
# Per-block projection scalars (zero-init for stable training start)
for block in self.transformer.h:
block.attn.c_proj_scalar.fill_(0.0)
block.mlp.c_proj_scalar.fill_(0.0)
if self.transformer.wte.weight.device.type == "cuda":
block.attn.c_proj_scalar.data = block.attn.c_proj_scalar.data.to(torch.bfloat16)
block.mlp.c_proj_scalar.data = block.mlp.c_proj_scalar.data.to(torch.bfloat16)
# Block RMSNorm weights (cast to bf16 for fused kernel)
for block in self.transformer.h:
block.attn_norm.weight.fill_(1.0)
block.mlp_norm.weight.fill_(1.0)
if self.transformer.wte.weight.device.type == "cuda":
block.attn_norm.to(dtype=torch.bfloat16)
block.mlp_norm.to(dtype=torch.bfloat16)
# Value embeddings (init like c_v: uniform with same std)
for ve in self.value_embeds.values():
torch.nn.init.uniform_(ve.weight, -s, s)
# Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
# Gate weights init to uniform (avoid zero-norm params under Hyperball)
for block in self.transformer.h:
if block.attn.ve_gate is not None:
torch.nn.init.zeros_(block.attn.ve_gate.weight)
torch.nn.init.uniform_(block.attn.ve_gate.weight, -s, s)
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
@ -302,10 +332,11 @@ class GPT(nn.Module):
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
"""
nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars
# Exclude non-matmul params: embeddings, per-layer scalars, and 1D params in blocks
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1)
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel())
self.resid_lambdas.numel() + self.x0_lambdas.numel() + block_1d_params)
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
@ -332,31 +363,36 @@ class GPT(nn.Module):
wte = sum(p.numel() for p in self.transformer.wte.parameters())
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
block_matrix_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 2)
block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1)
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
total = wte + value_embeds + lm_head + transformer_matrices + scalars
total = wte + value_embeds + lm_head + block_matrix_params + block_1d_params + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
'wte': wte,
'value_embeds': value_embeds,
'lm_head': lm_head,
'transformer_matrices': transformer_matrices,
'transformer_matrices': block_matrix_params,
'norm_and_proj_scalars': block_1d_params,
'scalars': scalars,
'total': total,
}
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5, matrix_optimizer="muon"):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into groups
matrix_params = list(self.transformer.h.parameters())
block_matrix_params = [p for p in self.transformer.h.parameters() if p.ndim == 2]
block_1d_params = [p for p in self.transformer.h.parameters() if p.ndim == 1]
value_embeds_params = list(self.value_embeds.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
all_params_count = (len(block_matrix_params) + len(block_1d_params) + len(embedding_params) +
len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params))
assert len(list(self.parameters())) == all_params_count
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -370,14 +406,23 @@ class GPT(nn.Module):
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
dict(kind='adamw', params=block_1d_params, lr=scalar_lr, betas=adam_betas, eps=1e-10, weight_decay=0.0),
]
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):
group_params = [p for p in matrix_params if p.shape == shape]
param_groups.append(dict(
kind='muon', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
))
# Matrix params (Muon or Hyperball), grouped by shape for stacking
if matrix_optimizer not in {"muon", "hyperball"}:
raise ValueError(f"Unknown matrix_optimizer: {matrix_optimizer}")
for shape in sorted({p.shape for p in block_matrix_params}):
group_params = [p for p in block_matrix_params if p.shape == shape]
if matrix_optimizer == "hyperball":
param_groups.append(dict(
kind='hyperball', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.95,
))
else:
param_groups.append(dict(
kind='muon', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
))
Factory = DistMuonAdamW if ddp else MuonAdamW
optimizer = Factory(param_groups)

View File

@ -1,6 +1,6 @@
"""
A nice and efficient mixed AdamW/Muon Combined Optimizer.
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon/Hyperball.
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
Adapted from: https://github.com/KellerJordan/modded-nanogpt
@ -141,6 +141,80 @@ def muon_step_fused(
mask = (g * stacked_params) >= 0
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
# -----------------------------------------------------------------------------
"""
Hyperball optimizer (MuonH): Muon with scale-invariant updates.
https://github.com/marin-community/marin/blob/main/lib/levanter/src/levanter/optim/muonh.py

The key difference from Muon is that weights maintain constant Frobenius norm
throughout training via the following update rule:

    p_new_intermediate = p - lr * u * ||p|| / ||u||
    p_new = p_new_intermediate / ||p_new_intermediate|| * ||p||

This projects the update onto the hypersphere of constant norm, hence "Hyperball".
Uses variance reduction like Muon, but no cautious weight decay.
"""

@torch.compile(dynamic=False, fullgraph=True)
def hyperball_step_fused(
    stacked_grads: Tensor,           # (K, M, N) - stacked gradients
    stacked_params: Tensor,          # (K, M, N) - stacked parameters
    momentum_buffer: Tensor,         # (K, M, N) - momentum buffer
    second_momentum_buffer: Tensor,  # (K, M, 1) or (K, 1, N) - factored second moment
    p_norm: Tensor,                  # (K, 1, 1) - pre-computed Frobenius norm of params (constant)
    momentum_t: Tensor,              # () - 0-D CPU tensor, momentum coefficient
    lr_t: Tensor,                    # () - 0-D CPU tensor, learning rate
    beta2_t: Tensor,                 # () - 0-D CPU tensor, beta2 for second moment
    ns_steps: int,                   # 5 - number of Newton-Schulz/Polar Express iterations
    red_dim: int,                    # -1 or -2 - reduction dimension for variance
) -> None:
    """
    Fused Hyperball step: momentum -> polar_express -> variance_reduction -> scale_invariant_update
    All in one compiled graph. Weights maintain constant Frobenius norm.
    """
    # Nesterov-style momentum: the buffer tracks an EMA of grads, then the
    # effective gradient g blends the raw grad toward the buffer.
    # NOTE: lerp_ mutates stacked_grads in place (it is a scratch stack).
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization (iterative approximation of the polar factor of g)
    X = g.bfloat16()
    # Pre-scale so the iteration starts inside its convergence region
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):  # Tall matrix: iterate on the smaller (N x N) Gram matrix
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:  # Wide matrix: iterate on the (M x M) Gram matrix instead
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # Variance reduction (note: this preserves ||g||_F, so ||u||_F == ||g||_F == v_norm)
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    # Frobenius norm of the variance-scaled update, used to normalize its magnitude
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    p_norm = p_norm.to(v_norm_new.dtype)
    # Rescale so the update's Frobenius norm equals ||p|| (scale-invariant step)
    final_scale = step_size * p_norm / v_norm_new.clamp_min(1e-10)
    g = g * final_scale.to(g.dtype)
    u = g.to(stacked_params.dtype)
    # Scale-invariant update: keeps ||p|| constant
    lr = lr_t.to(stacked_params.dtype)
    stacked_params.sub_(lr * u)
    # Project back to hypersphere: p = p * (||p_orig|| / ||p_new||)
    p_new_norm = stacked_params.norm(dim=(-2, -1), keepdim=True).clamp_min(1e-10)
    stacked_params.mul_(p_norm / p_new_norm)
# -----------------------------------------------------------------------------
# Single GPU version of the MuonAdamW optimizer.
# Used mostly for reference, debugging and testing.
@ -167,9 +241,10 @@ class MuonAdamW(torch.optim.Optimizer):
Arguments:
param_groups: List of dicts, each containing:
- 'params': List of parameters
- 'kind': 'adamw' or 'muon'
- 'kind': 'adamw', 'muon', or 'hyperball'
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
- For Hyperball groups: 'lr', 'momentum', 'ns_steps', 'beta2'
"""
def __init__(self, param_groups: list[dict]):
super().__init__(param_groups, defaults={})
@ -186,6 +261,10 @@ class MuonAdamW(torch.optim.Optimizer):
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
# Hyperball tensors
self._hyperball_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._hyperball_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._hyperball_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
def _step_adamw(self, group: dict) -> None:
"""
@ -276,6 +355,64 @@ class MuonAdamW(torch.optim.Optimizer):
# Copy back to original params
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
def _step_hyperball(self, group: dict) -> None:
    """
    Hyperball update for all params in the group (stacked for efficiency).
    Like Muon, but uses scale-invariant updates that keep weight norms constant.

    Expects group keys: 'params', 'lr', 'momentum', 'ns_steps', 'beta2'.
    """
    params: list[Tensor] = group['params']
    if not params:
        return
    # Get or create group-level buffers (stored in first param's state for convenience)
    p = params[0]
    state = self.state[p]
    num_params = len(params)
    shape, device, dtype = p.shape, p.device, p.dtype
    # Stack grads and params (NOTE: this assumes all params have the same shape)
    stacked_grads = torch.stack([p.grad for p in params])
    stacked_params = torch.stack(params)
    # Momentum buffer for every individual parameter
    if "momentum_buffer" not in state:
        state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
    momentum_buffer = state["momentum_buffer"]
    # Second momentum buffer is factored, either per-row or per-column
    # (factor along the longer matrix dimension; reduce over the other)
    if "second_momentum_buffer" not in state:
        state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
        state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
    second_momentum_buffer = state["second_momentum_buffer"]
    red_dim = -1 if shape[-2] >= shape[-1] else -2
    # Pre-compute and cache p_norm (Frobenius norm of each param, constant throughout training
    # because the fused step projects params back onto the same-norm hypersphere)
    if "p_norm" not in state:
        state["p_norm"] = stacked_params.norm(dim=(-2, -1), keepdim=True).clone()
    p_norm = state["p_norm"]
    # Fill the shared 0-D CPU tensors with this group's current hyperparameters
    self._hyperball_momentum_t.fill_(group["momentum"])
    self._hyperball_lr_t.fill_(group["lr"])
    self._hyperball_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
    # Single fused kernel: momentum -> polar_express -> variance_reduction -> scale_invariant_update
    hyperball_step_fused(
        stacked_grads,
        stacked_params,
        momentum_buffer,
        second_momentum_buffer,
        p_norm,
        self._hyperball_momentum_t,
        self._hyperball_lr_t,
        self._hyperball_beta2_t,
        group["ns_steps"],
        red_dim,
    )
    # Copy back to original params (the kernel updated stacked_params in place)
    torch._foreach_copy_(params, list(stacked_params.unbind(0)))
@torch.no_grad()
def step(self):
for group in self.param_groups:
@ -283,6 +420,8 @@ class MuonAdamW(torch.optim.Optimizer):
self._step_adamw(group)
elif group['kind'] == 'muon':
self._step_muon(group)
elif group['kind'] == 'hyperball':
self._step_hyperball(group)
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
@ -344,9 +483,10 @@ class DistMuonAdamW(torch.optim.Optimizer):
Arguments:
param_groups: List of dicts, each containing:
- 'params': List of parameters
- 'kind': 'adamw' or 'muon'
- 'kind': 'adamw', 'muon', or 'hyperball'
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
- For Hyperball groups: 'lr', 'momentum', 'ns_steps', 'beta2'
"""
def __init__(self, param_groups: list[dict]):
super().__init__(param_groups, defaults={})
@ -361,6 +501,10 @@ class DistMuonAdamW(torch.optim.Optimizer):
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
# Hyperball tensors
self._hyperball_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._hyperball_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._hyperball_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
@ -491,6 +635,85 @@ class DistMuonAdamW(torch.optim.Optimizer):
future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
def _reduce_hyperball(self, group: dict, world_size: int) -> dict:
    """Launch async reduce op for Hyperball group. Returns info dict."""
    params = group['params']
    # Pad the param count so it splits evenly into world_size chunks
    chunk_size = (len(params) + world_size - 1) // world_size
    padded_num_params = chunk_size * world_size
    p = params[0]
    shape, device, dtype = p.shape, p.device, p.dtype
    # Stack grads and zero-pad to padded_num_params
    grad_stack = torch.stack([p.grad for p in params])
    stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
    stacked_grads[:len(params)].copy_(grad_stack)
    if len(params) < padded_num_params:
        stacked_grads[len(params):].zero_()
    # Reduce_scatter to get this rank's chunk (gradients averaged across ranks)
    grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
    future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
    # stacked_grads is kept so _compute_hyperball can reuse it as the all_gather output buffer
    return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
def _compute_hyperball(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
    """Wait for reduce, compute Hyperball updates, launch gather."""
    info['future'].wait()
    params = group['params']
    chunk_size = info['chunk_size']
    grad_chunk = info['grad_chunk']
    p = params[0]
    shape, device, dtype = p.shape, p.device, p.dtype
    # How many params does this rank own? (trailing ranks may own fewer due to padding)
    start_idx = rank * chunk_size
    num_owned = min(chunk_size, max(0, len(params) - start_idx))
    # Get or create group-level state (keyed by the first param for convenience)
    state = self.state[p]
    if "momentum_buffer" not in state:
        state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
    if "second_momentum_buffer" not in state:
        # Factored second moment: factor along the longer matrix dimension
        state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
        state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
    red_dim = -1 if shape[-2] >= shape[-1] else -2
    # Build output buffer for all_gather
    updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
    if num_owned > 0:
        owned_params = [params[start_idx + i] for i in range(num_owned)]
        stacked_owned = torch.stack(owned_params)
        # Pre-compute and cache p_norm for owned params (constant throughout training;
        # each rank only ever needs the norms of the params it owns)
        if "p_norm" not in state:
            state["p_norm"] = stacked_owned.norm(dim=(-2, -1), keepdim=True).clone()
        p_norm = state["p_norm"]
        # Fill 0-D tensors and run fused kernel
        self._hyperball_momentum_t.fill_(group["momentum"])
        self._hyperball_lr_t.fill_(group["lr"])
        self._hyperball_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        hyperball_step_fused(
            grad_chunk[:num_owned], stacked_owned,
            state["momentum_buffer"][:num_owned],
            state["second_momentum_buffer"][:num_owned],
            p_norm,
            self._hyperball_momentum_t, self._hyperball_lr_t,
            self._hyperball_beta2_t,
            group["ns_steps"],
            red_dim,
        )
        updated_params[:num_owned].copy_(stacked_owned)
    if num_owned < chunk_size:
        # Zero the padding rows so the gathered tensor has no garbage
        updated_params[num_owned:].zero_()
    # Reuse stacked_grads buffer for all_gather output (it already has padded_num_params rows)
    stacked_params = info["stacked_grads"]
    future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
    gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
def _finish_gathers(self, gather_list: list) -> None:
"""Wait for all gathers and copy Muon params back."""
for info in gather_list:
@ -511,6 +734,8 @@ class DistMuonAdamW(torch.optim.Optimizer):
reduce_infos.append(self._reduce_adamw(group, world_size))
elif group['kind'] == 'muon':
reduce_infos.append(self._reduce_muon(group, world_size))
elif group['kind'] == 'hyperball':
reduce_infos.append(self._reduce_hyperball(group, world_size))
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
@ -521,6 +746,8 @@ class DistMuonAdamW(torch.optim.Optimizer):
self._compute_adamw(group, info, gather_list, rank, world_size)
elif group['kind'] == 'muon':
self._compute_muon(group, info, gather_list, rank)
elif group['kind'] == 'hyperball':
self._compute_hyperball(group, info, gather_list, rank)
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")

156
runs/quickrun_gamma_muonh_d24.sh Executable file
View File

@ -0,0 +1,156 @@
#!/bin/bash
# Quickrun: GPT-Gamma + MuonH (Hyperball), depth=24
# - Parameterized RMSNorm (learnable gamma)
# - Per-block projection scalars
# - Hyperball or Muon for matrix params
#
# Every knob below is overridable from the environment, e.g.:
#   bash runs/quickrun_gamma_muonh_d24.sh
#   WANDB_RUN=exp1 bash runs/quickrun_gamma_muonh_d24.sh
#   FP8=1 FP8_RECIPE=tensorwise bash runs/quickrun_gamma_muonh_d24.sh
set -e

# -----------------------------------------------------------------------------
# Config
DEPTH="${DEPTH:-24}"
NUM_SHARDS="${NUM_SHARDS:-370}" # ~10B tokens for d24 @ ratio~11
TARGET_RATIO="${TARGET_RATIO:-11}"
WINDOW_PATTERN="${WINDOW_PATTERN:-SSSL}"
DEVICE_BATCH_SIZE="${DEVICE_BATCH_SIZE:-16}"
TOTAL_BATCH_SIZE="${TOTAL_BATCH_SIZE:-524288}"
NPROC_PER_NODE="${NPROC_PER_NODE:-$(nvidia-smi -L 2>/dev/null | wc -l || echo 1)}"
if [ "$NPROC_PER_NODE" -eq 0 ]; then
    # nvidia-smi absent or no GPUs visible: fall back to a single process
    NPROC_PER_NODE=1
fi

# Optimizer
MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}"
SCALAR_LR="${SCALAR_LR:-0.5}"
# Fix: MATRIX_LR was hard-assigned to SCALAR_LR and could not be overridden from
# the environment, unlike every other knob. Default remains the scalar LR.
MATRIX_LR="${MATRIX_LR:-$SCALAR_LR}"
WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}"

# AdamW
EMBEDDING_LR="${EMBEDDING_LR:-0.3}"
UNEMBEDDING_LR="${UNEMBEDDING_LR:-0.004}"

# Wandb
export WANDB_MODE=offline
WANDB_RUN="${WANDB_RUN:-muonh_d${DEPTH}_ratio${TARGET_RATIO}}"
MODEL_TAG="${MODEL_TAG:-d${DEPTH}_gamma_muonh}"

# FP8
FP8_ARGS=""
if [ "${FP8:-0}" -eq 1 ]; then
    FP8_RECIPE="${FP8_RECIPE:-tensorwise}"
    FP8_ARGS="--fp8 --fp8-recipe=${FP8_RECIPE}"
fi

# NCCL (single node)
export NCCL_P2P_LEVEL=NVL
export NCCL_NVLS_ENABLE=1
export NCCL_IB_DISABLE=1

# -----------------------------------------------------------------------------
# Paths and cache
export OMP_NUM_THREADS=1
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
export NANOCHAT_BASE_DIR="$PROJECT_ROOT/cache"
export TORCHINDUCTOR_CACHE_DIR="$NANOCHAT_BASE_DIR/torch_inductor"
export TRITON_CACHE_DIR="$NANOCHAT_BASE_DIR/triton"
export TMPDIR="$NANOCHAT_BASE_DIR/tmp"
mkdir -p "$NANOCHAT_BASE_DIR" "$TORCHINDUCTOR_CACHE_DIR" "$TRITON_CACHE_DIR" "$TMPDIR"

# -----------------------------------------------------------------------------
# Print summary
echo "=============================================="
echo "Quickrun (GPT-Gamma + MuonH D24)"
echo "=============================================="
echo "Project root: $PROJECT_ROOT"
echo "Cache dir: $NANOCHAT_BASE_DIR"
echo "Depth: $DEPTH"
echo "Num shards: $NUM_SHARDS"
echo "Target ratio: $TARGET_RATIO"
echo "Window pattern: $WINDOW_PATTERN"
echo "Num GPUs: $NPROC_PER_NODE"
echo "Device batch size: $DEVICE_BATCH_SIZE"
echo "Total batch size: $TOTAL_BATCH_SIZE"
echo "Matrix optimizer: $MATRIX_OPTIMIZER"
echo "Matrix LR: $MATRIX_LR (defaults to scalar LR)"
echo "Adam LRs: embedding=$EMBEDDING_LR, unembedding=$UNEMBEDDING_LR, scalar=$SCALAR_LR"
echo "Warmdown ratio: $WARMDOWN_RATIO"
echo "Wandb run: $WANDB_RUN"
echo "Model tag: $MODEL_TAG"
if [ "${FP8:-0}" -eq 1 ]; then
    echo "FP8: enabled ($FP8_RECIPE)"
fi
echo "=============================================="
cd "$PROJECT_ROOT"

# -----------------------------------------------------------------------------
# Python venv
if [ ! -d ".venv" ]; then
    echo "Setting up Python environment..."
    command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
    uv venv
    uv sync --extra gpu
fi
source .venv/bin/activate

# -----------------------------------------------------------------------------
# Data + tokenizer
echo ""
echo "Downloading $NUM_SHARDS data shards..."
python -m nanochat.dataset -n "$NUM_SHARDS"
echo ""
echo "Checking tokenizer..."
python -m scripts.tok_train --max-chars=500000000 --vocab-size=32768

# -----------------------------------------------------------------------------
# Train
echo ""
echo "Starting pretraining (depth=$DEPTH)..."
# Fix: quote all expansions inside the array so values containing spaces
# (e.g. a custom WANDB_RUN) do not word-split into multiple arguments.
TRAIN_ARGS=(
    --depth="$DEPTH"
    --run="$WANDB_RUN"
    --model-tag="$MODEL_TAG"
    --window-pattern="$WINDOW_PATTERN"
    --target-param-data-ratio="$TARGET_RATIO"
    --device-batch-size="$DEVICE_BATCH_SIZE"
    --total-batch-size="$TOTAL_BATCH_SIZE"
    --matrix-optimizer="$MATRIX_OPTIMIZER"
    --matrix-lr="$MATRIX_LR"
    --warmdown-ratio="$WARMDOWN_RATIO"
    --embedding-lr="$EMBEDDING_LR"
    --unembedding-lr="$UNEMBEDDING_LR"
    --scalar-lr="$SCALAR_LR"
    --core-metric-every=2000
    --sample-every=-1
    --save-every=-1
)
# NOTE: $FP8_ARGS is deliberately unquoted — it holds zero or more flags that
# must word-split into separate arguments.
if [ "$NPROC_PER_NODE" -gt 1 ]; then
    torchrun --standalone --nproc_per_node="$NPROC_PER_NODE" -m scripts.base_train -- \
        "${TRAIN_ARGS[@]}" $FP8_ARGS
else
    python -m scripts.base_train \
        "${TRAIN_ARGS[@]}" $FP8_ARGS
fi

echo ""
echo "=============================================="
echo "Training complete!"
echo "=============================================="
echo "Checkpoint saved to: $NANOCHAT_BASE_DIR/base_checkpoints/${MODEL_TAG}/"

View File

@ -58,7 +58,8 @@ parser.add_argument("--total-batch-size", type=int, default=524288, help="total
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon/Hyperball)")
parser.add_argument("--matrix-optimizer", type=str, default="muon", choices=["muon", "hyperball"], help="optimizer for matrix parameters")
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
@ -303,6 +304,7 @@ optimizer = model.setup_optimizer(
weight_decay=weight_decay_scaled,
adam_betas=adam_betas,
scalar_lr=args.scalar_lr * batch_lr_scale,
matrix_optimizer=args.matrix_optimizer,
)
if resuming:
@ -331,7 +333,7 @@ def get_lr_multiplier(it):
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * args.final_lr_frac
# Momentum scheduler for Muon optimizer
# Momentum scheduler for matrix optimizer (Muon/Hyperball)
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
@ -466,8 +468,9 @@ while True:
muon_weight_decay = get_weight_decay(step)
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
if group['kind'] == 'muon':
if group['kind'] in {'muon', 'hyperball'}:
group["momentum"] = muon_momentum
if group['kind'] == 'muon':
group["weight_decay"] = muon_weight_decay
optimizer.step()
model.zero_grad(set_to_none=True)