mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-19 11:23:19 +00:00
Add muonh model and quickrun
This commit is contained in:
parent
d510b1385b
commit
e28d4ead22
|
|
@ -2,14 +2,16 @@
|
|||
GPT model (rewrite, a lot simpler)
|
||||
Notable features:
|
||||
- rotary embeddings (and no positional embeddings)
|
||||
- QK norm
|
||||
- QK norm (functional RMSNorm, no learnable params)
|
||||
- untied weights for token embedding and lm_head
|
||||
- relu^2 activation in MLP
|
||||
- norm after token embedding
|
||||
- no learnable params in rmsnorm
|
||||
- norm after token embedding (functional RMSNorm)
|
||||
- parameterized RMSNorm in blocks (learnable gamma)
|
||||
- per-block projection scalars for attention/MLP
|
||||
- no bias in linear layers
|
||||
- Group-Query Attention (GQA) support for more efficient inference
|
||||
- Flash Attention 3 integration
|
||||
- optional Hyperball optimizer for matrix params
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
|
|
@ -40,7 +42,7 @@ class GPTConfig:
|
|||
|
||||
|
||||
def norm(x):
|
||||
# Purely functional rmsnorm with no learnable params
|
||||
# Purely functional RMSNorm with no learnable params
|
||||
return F.rms_norm(x, (x.size(-1),))
|
||||
|
||||
|
||||
|
|
@ -72,6 +74,7 @@ class CausalSelfAttention(nn.Module):
|
|||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.ve_gate_channels = 32
|
||||
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd))
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
B, T, C = x.size()
|
||||
|
|
@ -115,6 +118,7 @@ class CausalSelfAttention(nn.Module):
|
|||
# Re-assemble the heads and project back to residual stream
|
||||
y = y.contiguous().view(B, T, -1)
|
||||
y = self.c_proj(y)
|
||||
y = y * self.c_proj_scalar
|
||||
return y
|
||||
|
||||
|
||||
|
|
@ -123,23 +127,27 @@ class MLP(nn.Module):
|
|||
super().__init__()
|
||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||
self.c_proj_scalar = nn.Parameter(torch.zeros(config.n_embd))
|
||||
|
||||
def forward(self, x):
    """Run the MLP: up-projection, relu^2 activation, down-projection, then scaling."""
    hidden = F.relu(self.c_fc(x)).square()
    projected = self.c_proj(hidden)
    # Elementwise multiply by the learnable c_proj_scalar parameter
    # (broadcast over the leading dimensions).
    return projected * self.c_proj_scalar
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.attn_norm = nn.RMSNorm(config.n_embd)
|
||||
self.attn = CausalSelfAttention(config, layer_idx)
|
||||
self.mlp_norm = nn.RMSNorm(config.n_embd)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
|
||||
x = x + self.mlp(norm(x))
|
||||
x = x + self.attn(self.attn_norm(x), ve, cos_sin, window_size, kv_cache)
|
||||
x = x + self.mlp(self.mlp_norm(x))
|
||||
return x
|
||||
|
||||
|
||||
|
|
@ -196,15 +204,21 @@ class GPT(nn.Module):
|
|||
attn.c_q: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_k: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_v: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_proj: zeros
|
||||
attn.c_proj: uniform, std=1/sqrt(n_embd)
|
||||
mlp.c_fc: uniform, std=1/sqrt(n_embd)
|
||||
mlp.c_proj: zeros
|
||||
mlp.c_proj: uniform, std=1/sqrt(n_embd)
|
||||
nn.RMSNorm weight: ones (via explicit init below)
|
||||
"""
|
||||
|
||||
# Embedding and unembedding
|
||||
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
|
||||
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
|
||||
|
||||
# nn.RMSNorm weight parameters: init to ones (must be explicit due to meta device)
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.RMSNorm):
|
||||
module.weight.fill_(1.0)
|
||||
|
||||
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
|
||||
n_embd = self.config.n_embd
|
||||
s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
|
||||
|
|
@ -212,22 +226,38 @@ class GPT(nn.Module):
|
|||
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
|
||||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
||||
torch.nn.init.uniform_(block.attn.c_proj.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
torch.nn.init.uniform_(block.mlp.c_proj.weight, -s, s)
|
||||
|
||||
# Per-layer scalars
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
|
||||
|
||||
# Per-block projection scalars (zero-init for stable training start)
|
||||
for block in self.transformer.h:
|
||||
block.attn.c_proj_scalar.fill_(0.0)
|
||||
block.mlp.c_proj_scalar.fill_(0.0)
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
block.attn.c_proj_scalar.data = block.attn.c_proj_scalar.data.to(torch.bfloat16)
|
||||
block.mlp.c_proj_scalar.data = block.mlp.c_proj_scalar.data.to(torch.bfloat16)
|
||||
|
||||
# Block RMSNorm weights (cast to bf16 for fused kernel)
|
||||
for block in self.transformer.h:
|
||||
block.attn_norm.weight.fill_(1.0)
|
||||
block.mlp_norm.weight.fill_(1.0)
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
block.attn_norm.to(dtype=torch.bfloat16)
|
||||
block.mlp_norm.to(dtype=torch.bfloat16)
|
||||
|
||||
# Value embeddings (init like c_v: uniform with same std)
|
||||
for ve in self.value_embeds.values():
|
||||
torch.nn.init.uniform_(ve.weight, -s, s)
|
||||
|
||||
# Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
|
||||
# Gate weights init to uniform (avoid zero-norm params under Hyperball)
|
||||
for block in self.transformer.h:
|
||||
if block.attn.ve_gate is not None:
|
||||
torch.nn.init.zeros_(block.attn.ve_gate.weight)
|
||||
torch.nn.init.uniform_(block.attn.ve_gate.weight, -s, s)
|
||||
|
||||
# Rotary embeddings
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
|
|
@ -302,10 +332,11 @@ class GPT(nn.Module):
|
|||
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
|
||||
"""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
# Exclude non-matmul params: embeddings, per-layer scalars, and 1D params in blocks
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1)
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() + block_1d_params)
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
|
|
@ -332,31 +363,36 @@ class GPT(nn.Module):
|
|||
wte = sum(p.numel() for p in self.transformer.wte.parameters())
|
||||
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
||||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
block_matrix_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 2)
|
||||
block_1d_params = sum(p.numel() for p in self.transformer.h.parameters() if p.ndim == 1)
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
total = wte + value_embeds + lm_head + block_matrix_params + block_1d_params + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
return {
|
||||
'wte': wte,
|
||||
'value_embeds': value_embeds,
|
||||
'lm_head': lm_head,
|
||||
'transformer_matrices': transformer_matrices,
|
||||
'transformer_matrices': block_matrix_params,
|
||||
'norm_and_proj_scalars': block_1d_params,
|
||||
'scalars': scalars,
|
||||
'total': total,
|
||||
}
|
||||
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5, matrix_optimizer="muon"):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
|
||||
# Separate out all parameters into groups
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
block_matrix_params = [p for p in self.transformer.h.parameters() if p.ndim == 2]
|
||||
block_1d_params = [p for p in self.transformer.h.parameters() if p.ndim == 1]
|
||||
value_embeds_params = list(self.value_embeds.parameters())
|
||||
embedding_params = list(self.transformer.wte.parameters())
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
|
||||
all_params_count = (len(block_matrix_params) + len(block_1d_params) + len(embedding_params) +
|
||||
len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params))
|
||||
assert len(list(self.parameters())) == all_params_count
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
|
|
@ -370,14 +406,23 @@ class GPT(nn.Module):
|
|||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
dict(kind='adamw', params=block_1d_params, lr=scalar_lr, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
]
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
group_params = [p for p in matrix_params if p.shape == shape]
|
||||
param_groups.append(dict(
|
||||
kind='muon', params=group_params, lr=matrix_lr,
|
||||
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
|
||||
))
|
||||
# Matrix params (Muon or Hyperball), grouped by shape for stacking
|
||||
if matrix_optimizer not in {"muon", "hyperball"}:
|
||||
raise ValueError(f"Unknown matrix_optimizer: {matrix_optimizer}")
|
||||
for shape in sorted({p.shape for p in block_matrix_params}):
|
||||
group_params = [p for p in block_matrix_params if p.shape == shape]
|
||||
if matrix_optimizer == "hyperball":
|
||||
param_groups.append(dict(
|
||||
kind='hyperball', params=group_params, lr=matrix_lr,
|
||||
momentum=0.95, ns_steps=5, beta2=0.95,
|
||||
))
|
||||
else:
|
||||
param_groups.append(dict(
|
||||
kind='muon', params=group_params, lr=matrix_lr,
|
||||
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
|
||||
))
|
||||
|
||||
Factory = DistMuonAdamW if ddp else MuonAdamW
|
||||
optimizer = Factory(param_groups)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
A nice and efficient mixed AdamW/Muon Combined Optimizer.
|
||||
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
|
||||
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon/Hyperball.
|
||||
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
|
||||
|
||||
Adapted from: https://github.com/KellerJordan/modded-nanogpt
|
||||
|
|
@ -141,6 +141,80 @@ def muon_step_fused(
|
|||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
"""
|
||||
Hyperball optimizer (MuonH): Muon with scale-invariant updates.
|
||||
https://github.com/marin-community/marin/blob/main/lib/levanter/src/levanter/optim/muonh.py
|
||||
|
||||
The key difference from Muon is that weights maintain constant Frobenius norm
|
||||
throughout training via the following update rule:
|
||||
|
||||
p_new_intermediate = p - lr * u * ||p|| / ||u||
|
||||
p_new = p_new_intermediate / ||p_new_intermediate|| * ||p||
|
||||
|
||||
This projects the update onto the hypersphere of constant norm, hence "Hyperball".
|
||||
Uses variance reduction like Muon, but no cautious weight decay.
|
||||
"""
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
def hyperball_step_fused(
    stacked_grads: Tensor,           # (K, M, N) - stacked gradients
    stacked_params: Tensor,          # (K, M, N) - stacked parameters
    momentum_buffer: Tensor,         # (K, M, N) - momentum buffer
    second_momentum_buffer: Tensor,  # (K, M, 1) or (K, 1, N) - factored second moment
    p_norm: Tensor,                  # (K, 1, 1) - pre-computed Frobenius norm of params (constant)
    momentum_t: Tensor,              # () - 0-D CPU tensor, momentum coefficient
    lr_t: Tensor,                    # () - 0-D CPU tensor, learning rate
    beta2_t: Tensor,                 # () - 0-D CPU tensor, beta2 for second moment
    ns_steps: int,                   # 5 - number of Newton-Schulz/Polar Express iterations
    red_dim: int,                    # -1 or -2 - reduction dimension for variance
) -> None:
    """
    Fused Hyperball step, compiled as a single graph.

    Pipeline: momentum -> polar express orthogonalization -> variance reduction
    -> norm-matched update -> re-projection onto the sphere of radius ``p_norm``.

    Mutates ``stacked_grads``, ``stacked_params``, ``momentum_buffer`` and
    ``second_momentum_buffer`` in place and returns nothing.
    """
    # --- Nesterov momentum (both lerps are in place; stacked_grads is overwritten) ---
    mom = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - mom)
    update = stacked_grads.lerp_(momentum_buffer, mom)

    # --- Polar express orthogonalization, run in bfloat16 ---
    ortho = update.bfloat16()
    ortho = ortho / (ortho.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    # Use the smaller Gram matrix: X^T X for tall matrices, X X^T otherwise.
    is_tall = update.size(-2) > update.size(-1)
    for a, b, c in polar_express_coeffs[:ns_steps]:
        if is_tall:
            gram = ortho.mT @ ortho
            poly = b * gram + c * (gram @ gram)
            ortho = a * ortho + ortho @ poly
        else:
            gram = ortho @ ortho.mT
            poly = b * gram + c * (gram @ gram)
            ortho = a * ortho + poly @ ortho
    update = ortho

    # --- Variance reduction with a factored second moment ---
    # The second-moment EMA gives a per-row (or per-column) rsqrt scale. final_scale
    # also rescales each matrix so that the Frobenius norm of the scaled update
    # matches p_norm (up to rounding), rather than the norm of the update itself.
    beta2 = beta2_t.to(update.dtype)
    mean_sq = update.float().square().mean(dim=red_dim, keepdim=True)
    reduced_len = update.size(red_dim)
    second_momentum_buffer.lerp_(mean_sq.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    inv_rms = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    # Squared Frobenius norm of (update * inv_rms), computed from the reduced statistics.
    rescaled_sq = (mean_sq * reduced_len) * inv_rms.float().square()
    rescaled_norm = rescaled_sq.sum(dim=(-2, -1), keepdim=True).sqrt()
    target_norm = p_norm.to(rescaled_norm.dtype)
    final_scale = inv_rms * target_norm / rescaled_norm.clamp_min(1e-10)
    update = update * final_scale.to(update.dtype)
    u = update.to(stacked_params.dtype)

    # --- Apply the step ---
    lr = lr_t.to(stacked_params.dtype)
    stacked_params.sub_(lr * u)

    # --- Project back to the hypersphere: p = p * (||p_orig|| / ||p_new||) ---
    new_norm = stacked_params.norm(dim=(-2, -1), keepdim=True).clamp_min(1e-10)
    stacked_params.mul_(target_norm / new_norm)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Single GPU version of the MuonAdamW optimizer.
|
||||
# Used mostly for reference, debugging and testing.
|
||||
|
|
@ -167,9 +241,10 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
Arguments:
|
||||
param_groups: List of dicts, each containing:
|
||||
- 'params': List of parameters
|
||||
- 'kind': 'adamw' or 'muon'
|
||||
- 'kind': 'adamw', 'muon', or 'hyperball'
|
||||
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
|
||||
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
|
||||
- For Hyperball groups: 'lr', 'momentum', 'ns_steps', 'beta2'
|
||||
"""
|
||||
def __init__(self, param_groups: list[dict]):
|
||||
super().__init__(param_groups, defaults={})
|
||||
|
|
@ -186,6 +261,10 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
# Hyperball tensors
|
||||
self._hyperball_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._hyperball_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._hyperball_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
def _step_adamw(self, group: dict) -> None:
|
||||
"""
|
||||
|
|
@ -276,6 +355,64 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
# Copy back to original params
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
def _step_hyperball(self, group: dict) -> None:
    """
    Apply one Hyperball update to every parameter in ``group``.

    Parameters and gradients are stacked into single tensors (all params in the
    group must share one shape), updated by ``hyperball_step_fused``, and the
    results are copied back into the original parameters. Group-level buffers
    live in the optimizer state of the group's first parameter.
    """
    params: list[Tensor] = group['params']
    if not params:
        return

    first = params[0]
    state = self.state[first]
    count = len(params)
    shape, device, dtype = first.shape, first.device, first.dtype
    # The second moment is factored per row when rows >= cols, otherwise per column.
    rows_ge_cols = shape[-2] >= shape[-1]

    # Stack grads and params (NOTE: this assumes all params have the same shape)
    stacked_grads = torch.stack([param.grad for param in params])
    stacked_params = torch.stack(params)

    # Lazily create the group-level buffers on the first step.
    if "momentum_buffer" not in state:
        state["momentum_buffer"] = torch.zeros(count, *shape, dtype=dtype, device=device)
    if "second_momentum_buffer" not in state:
        factored_shape = (count, shape[-2], 1) if rows_ge_cols else (count, 1, shape[-1])
        state["second_momentum_buffer"] = torch.zeros(factored_shape, dtype=dtype, device=device)
    # The per-parameter Frobenius norm is captured once, on the first step, and reused after.
    if "p_norm" not in state:
        state["p_norm"] = stacked_params.norm(dim=(-2, -1), keepdim=True).clone()

    # Refresh the 0-D hyperparameter tensors with the group's current values.
    self._hyperball_momentum_t.fill_(group["momentum"])
    self._hyperball_lr_t.fill_(group["lr"])
    beta2 = group["beta2"]
    self._hyperball_beta2_t.fill_(0.0 if beta2 is None else beta2)

    # Single fused kernel: momentum -> polar_express -> variance_reduction -> scale_invariant_update
    hyperball_step_fused(
        stacked_grads,
        stacked_params,
        state["momentum_buffer"],
        state["second_momentum_buffer"],
        state["p_norm"],
        self._hyperball_momentum_t,
        self._hyperball_lr_t,
        self._hyperball_beta2_t,
        group["ns_steps"],
        -1 if rows_ge_cols else -2,
    )

    # Copy the updated slices back into the original parameters.
    torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
|
|
@ -283,6 +420,8 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
self._step_adamw(group)
|
||||
elif group['kind'] == 'muon':
|
||||
self._step_muon(group)
|
||||
elif group['kind'] == 'hyperball':
|
||||
self._step_hyperball(group)
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
||||
|
||||
|
|
@ -344,9 +483,10 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
Arguments:
|
||||
param_groups: List of dicts, each containing:
|
||||
- 'params': List of parameters
|
||||
- 'kind': 'adamw' or 'muon'
|
||||
- 'kind': 'adamw', 'muon', or 'hyperball'
|
||||
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
|
||||
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
|
||||
- For Hyperball groups: 'lr', 'momentum', 'ns_steps', 'beta2'
|
||||
"""
|
||||
def __init__(self, param_groups: list[dict]):
|
||||
super().__init__(param_groups, defaults={})
|
||||
|
|
@ -361,6 +501,10 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
# Hyperball tensors
|
||||
self._hyperball_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._hyperball_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._hyperball_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
|
||||
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
|
||||
|
|
@ -491,6 +635,85 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
|
||||
gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
|
||||
|
||||
def _reduce_hyperball(self, group: dict, world_size: int) -> dict:
    """
    Launch the async reduce-scatter of gradients for a Hyperball group.

    Returns a dict with the pending ``future``, this rank's ``grad_chunk``, the
    padded ``stacked_grads`` buffer, and the per-rank ``chunk_size``.
    """
    params = group['params']
    num_params = len(params)
    # Ceil-divide so every rank receives a chunk of the same size.
    chunk_size = (num_params + world_size - 1) // world_size
    padded_num_params = chunk_size * world_size
    ref = params[0]
    shape, device, dtype = ref.shape, ref.device, ref.dtype

    # Stack the grads into a buffer padded to padded_num_params; the padding is zeroed.
    stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
    stacked_grads[:num_params].copy_(torch.stack([param.grad for param in params]))
    if num_params < padded_num_params:
        stacked_grads[num_params:].zero_()

    # Averaging reduce-scatter: each rank receives its own chunk of the gradients.
    grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
    work = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True)

    return dict(
        future=work.get_future(),
        grad_chunk=grad_chunk,
        stacked_grads=stacked_grads,
        chunk_size=chunk_size,
    )
|
||||
|
||||
def _compute_hyperball(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
    """
    Wait for the gradient reduce, update this rank's params, and launch the all-gather.

    ``info`` is the dict returned by ``_reduce_hyperball``. An entry describing the
    pending all-gather is appended to ``gather_list``.
    """
    info['future'].wait()
    params = group['params']
    chunk_size, grad_chunk = info['chunk_size'], info['grad_chunk']
    ref = params[0]
    shape, device, dtype = ref.shape, ref.device, ref.dtype
    # The second moment is factored per row when rows >= cols, otherwise per column.
    rows_ge_cols = shape[-2] >= shape[-1]

    # This rank owns params[start_idx : start_idx + num_owned]; trailing ranks may own fewer (or none).
    start_idx = rank * chunk_size
    num_owned = min(chunk_size, max(0, len(params) - start_idx))

    # Group-level state is stored under the group's first parameter.
    state = self.state[ref]
    if "momentum_buffer" not in state:
        state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
    if "second_momentum_buffer" not in state:
        factored_shape = (chunk_size, shape[-2], 1) if rows_ge_cols else (chunk_size, 1, shape[-1])
        state["second_momentum_buffer"] = torch.zeros(factored_shape, dtype=dtype, device=device)

    # Send buffer for the all_gather
    updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)

    if num_owned > 0:
        stacked_owned = torch.stack(params[start_idx:start_idx + num_owned])

        # The Frobenius norm of the owned params is captured once, on the first step, and reused after.
        if "p_norm" not in state:
            state["p_norm"] = stacked_owned.norm(dim=(-2, -1), keepdim=True).clone()

        # Refresh the 0-D hyperparameter tensors and run the fused kernel.
        self._hyperball_momentum_t.fill_(group["momentum"])
        self._hyperball_lr_t.fill_(group["lr"])
        beta2 = group["beta2"]
        self._hyperball_beta2_t.fill_(0.0 if beta2 is None else beta2)
        hyperball_step_fused(
            grad_chunk[:num_owned],
            stacked_owned,
            state["momentum_buffer"][:num_owned],
            state["second_momentum_buffer"][:num_owned],
            state["p_norm"],
            self._hyperball_momentum_t,
            self._hyperball_lr_t,
            self._hyperball_beta2_t,
            group["ns_steps"],
            -1 if rows_ge_cols else -2,
        )
        updated_params[:num_owned].copy_(stacked_owned)

    # Zero the padding slots so the gathered buffer holds no uninitialized memory.
    if num_owned < chunk_size:
        updated_params[num_owned:].zero_()

    # Reuse the stacked_grads buffer as the all_gather output.
    stacked_params = info["stacked_grads"]
    future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
    gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
|
||||
|
||||
def _finish_gathers(self, gather_list: list) -> None:
|
||||
"""Wait for all gathers and copy Muon params back."""
|
||||
for info in gather_list:
|
||||
|
|
@ -511,6 +734,8 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
reduce_infos.append(self._reduce_adamw(group, world_size))
|
||||
elif group['kind'] == 'muon':
|
||||
reduce_infos.append(self._reduce_muon(group, world_size))
|
||||
elif group['kind'] == 'hyperball':
|
||||
reduce_infos.append(self._reduce_hyperball(group, world_size))
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
||||
|
||||
|
|
@ -521,6 +746,8 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
self._compute_adamw(group, info, gather_list, rank, world_size)
|
||||
elif group['kind'] == 'muon':
|
||||
self._compute_muon(group, info, gather_list, rank)
|
||||
elif group['kind'] == 'hyperball':
|
||||
self._compute_hyperball(group, info, gather_list, rank)
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
||||
|
||||
|
|
|
|||
156
runs/quickrun_gamma_muonh_d24.sh
Executable file
156
runs/quickrun_gamma_muonh_d24.sh
Executable file
|
|
@ -0,0 +1,156 @@
|
|||
#!/bin/bash

# Quickrun: GPT-Gamma + MuonH (Hyperball), depth=24
# - Parameterized RMSNorm (learnable gamma)
# - Per-block projection scalars
# - Hyperball or Muon for matrix params
#
# Every setting below can be overridden through an environment variable of the
# same name, except MATRIX_LR (always tied to SCALAR_LR) and WANDB_MODE (always
# forced to offline).
#
# Examples:
#   bash runs/quickrun_gamma_muonh_d24.sh
#   WANDB_RUN=exp1 bash runs/quickrun_gamma_muonh_d24.sh
#   FP8=1 FP8_RECIPE=tensorwise bash runs/quickrun_gamma_muonh_d24.sh

set -e

# -----------------------------------------------------------------------------
# Config

DEPTH="${DEPTH:-24}"
NUM_SHARDS="${NUM_SHARDS:-370}"  # ~10B tokens for d24 @ ratio~11
TARGET_RATIO="${TARGET_RATIO:-11}"
WINDOW_PATTERN="${WINDOW_PATTERN:-SSSL}"
DEVICE_BATCH_SIZE="${DEVICE_BATCH_SIZE:-16}"
TOTAL_BATCH_SIZE="${TOTAL_BATCH_SIZE:-524288}"

# Default to one process per GPU listed by nvidia-smi. When nvidia-smi is
# missing or lists nothing the count is 0, so fall back to a single process.
NPROC_PER_NODE="${NPROC_PER_NODE:-$(nvidia-smi -L 2>/dev/null | wc -l || echo 1)}"
if [ "$NPROC_PER_NODE" -eq 0 ]; then
    NPROC_PER_NODE=1
fi

# Optimizer
MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}"
SCALAR_LR="${SCALAR_LR:-0.5}"
MATRIX_LR="$SCALAR_LR"  # share with scalar LR
WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}"

# AdamW
EMBEDDING_LR="${EMBEDDING_LR:-0.3}"
UNEMBEDDING_LR="${UNEMBEDDING_LR:-0.004}"

# Wandb
export WANDB_MODE=offline
WANDB_RUN="${WANDB_RUN:-muonh_d${DEPTH}_ratio${TARGET_RATIO}}"
MODEL_TAG="${MODEL_TAG:-d${DEPTH}_gamma_muonh}"

# FP8 (kept as an array so each flag stays one word when expanded)
FP8_ARGS=()
if [ "${FP8:-0}" -eq 1 ]; then
    FP8_RECIPE="${FP8_RECIPE:-tensorwise}"
    FP8_ARGS=(--fp8 "--fp8-recipe=${FP8_RECIPE}")
fi

# NCCL (single node)
export NCCL_P2P_LEVEL=NVL
export NCCL_NVLS_ENABLE=1
export NCCL_IB_DISABLE=1

# -----------------------------------------------------------------------------
# Paths and cache

export OMP_NUM_THREADS=1
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
export NANOCHAT_BASE_DIR="$PROJECT_ROOT/cache"
export TORCHINDUCTOR_CACHE_DIR="$NANOCHAT_BASE_DIR/torch_inductor"
export TRITON_CACHE_DIR="$NANOCHAT_BASE_DIR/triton"
export TMPDIR="$NANOCHAT_BASE_DIR/tmp"
mkdir -p "$NANOCHAT_BASE_DIR" "$TORCHINDUCTOR_CACHE_DIR" "$TRITON_CACHE_DIR" "$TMPDIR"

# -----------------------------------------------------------------------------
# Print summary

echo "=============================================="
echo "Quickrun (GPT-Gamma + MuonH D24)"
echo "=============================================="
echo "Project root: $PROJECT_ROOT"
echo "Cache dir: $NANOCHAT_BASE_DIR"
echo "Depth: $DEPTH"
echo "Num shards: $NUM_SHARDS"
echo "Target ratio: $TARGET_RATIO"
echo "Window pattern: $WINDOW_PATTERN"
echo "Num GPUs: $NPROC_PER_NODE"
echo "Device batch size: $DEVICE_BATCH_SIZE"
echo "Total batch size: $TOTAL_BATCH_SIZE"
echo "Matrix optimizer: $MATRIX_OPTIMIZER"
echo "Matrix LR: $MATRIX_LR (shared with scalar)"
echo "Adam LRs: embedding=$EMBEDDING_LR, unembedding=$UNEMBEDDING_LR, scalar=$SCALAR_LR"
echo "Warmdown ratio: $WARMDOWN_RATIO"
echo "Wandb run: $WANDB_RUN"
echo "Model tag: $MODEL_TAG"
if [ "${FP8:-0}" -eq 1 ]; then
    echo "FP8: enabled ($FP8_RECIPE)"
fi
echo "=============================================="

cd "$PROJECT_ROOT"

# -----------------------------------------------------------------------------
# Python venv

# Create the environment only when .venv does not exist yet; uv is installed
# first if it is not already on PATH.
if [ ! -d ".venv" ]; then
    echo "Setting up Python environment..."
    command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
    uv venv
    uv sync --extra gpu
fi
source .venv/bin/activate

# -----------------------------------------------------------------------------
# Data + tokenizer

echo ""
echo "Downloading $NUM_SHARDS data shards..."
python -m nanochat.dataset -n "$NUM_SHARDS"

echo ""
echo "Checking tokenizer..."
python -m scripts.tok_train --max-chars=500000000 --vocab-size=32768

# -----------------------------------------------------------------------------
# Train

echo ""
echo "Starting pretraining (depth=$DEPTH)..."

# Every element is quoted so a value containing whitespace or glob characters
# is passed to the trainer as a single argument.
TRAIN_ARGS=(
    "--depth=$DEPTH"
    "--run=$WANDB_RUN"
    "--model-tag=$MODEL_TAG"
    "--window-pattern=$WINDOW_PATTERN"
    "--target-param-data-ratio=$TARGET_RATIO"
    "--device-batch-size=$DEVICE_BATCH_SIZE"
    "--total-batch-size=$TOTAL_BATCH_SIZE"
    "--matrix-optimizer=$MATRIX_OPTIMIZER"
    "--matrix-lr=$MATRIX_LR"
    "--warmdown-ratio=$WARMDOWN_RATIO"
    "--embedding-lr=$EMBEDDING_LR"
    "--unembedding-lr=$UNEMBEDDING_LR"
    "--scalar-lr=$SCALAR_LR"
    --core-metric-every=2000
    --sample-every=-1
    --save-every=-1
)

# Multi-GPU runs go through torchrun; a single process runs the module directly.
if [ "$NPROC_PER_NODE" -gt 1 ]; then
    torchrun --standalone --nproc_per_node="$NPROC_PER_NODE" -m scripts.base_train -- \
        "${TRAIN_ARGS[@]}" "${FP8_ARGS[@]}"
else
    python -m scripts.base_train \
        "${TRAIN_ARGS[@]}" "${FP8_ARGS[@]}"
fi

echo ""
echo "=============================================="
echo "Training complete!"
echo "=============================================="
echo "Checkpoint saved to: $NANOCHAT_BASE_DIR/base_checkpoints/${MODEL_TAG}/"
|
||||
|
|
@ -58,7 +58,8 @@ parser.add_argument("--total-batch-size", type=int, default=524288, help="total
|
|||
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon/Hyperball)")
|
||||
parser.add_argument("--matrix-optimizer", type=str, default="muon", choices=["muon", "hyperball"], help="optimizer for matrix parameters")
|
||||
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
|
||||
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
|
||||
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
|
||||
|
|
@ -303,6 +304,7 @@ optimizer = model.setup_optimizer(
|
|||
weight_decay=weight_decay_scaled,
|
||||
adam_betas=adam_betas,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
matrix_optimizer=args.matrix_optimizer,
|
||||
)
|
||||
|
||||
if resuming:
|
||||
|
|
@ -331,7 +333,7 @@ def get_lr_multiplier(it):
|
|||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * args.final_lr_frac
|
||||
|
||||
# Momentum scheduler for Muon optimizer
|
||||
# Momentum scheduler for matrix optimizer (Muon/Hyperball)
|
||||
def get_muon_momentum(it):
|
||||
frac = min(it / 300, 1)
|
||||
momentum = (1 - frac) * 0.85 + frac * 0.95
|
||||
|
|
@ -466,8 +468,9 @@ while True:
|
|||
muon_weight_decay = get_weight_decay(step)
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group['kind'] == 'muon':
|
||||
if group['kind'] in {'muon', 'hyperball'}:
|
||||
group["momentum"] = muon_momentum
|
||||
if group['kind'] == 'muon':
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user