Add MoE configuration and implementation in training scripts and model architecture

- Introduced parameters for Mixture of Experts (MoE) in `runmps.sh`, `base_train.py`, and `gpt.py`, allowing for dynamic configuration of experts during training.
- Enhanced `gpt.py` with new classes `MoEFeedForward` and `ExpertFFN` to implement MoE functionality in the model architecture.
- Updated `configurator.py` to handle type conversions for new MoE parameters.
- Improved logging in `base_train.py` to include MoE-related metrics and configurations during training.
- Added assertions and derived defaults for MoE parameters to ensure valid configurations.
- Implemented methods to estimate and log FLOPs for both dense and MoE active configurations during training.
- Enhanced gradient handling in `muon.py` to accommodate potential absence of gradients for unused experts.
This commit is contained in:
William Thurston 2025-11-11 19:58:38 -08:00
parent 9550053cc1
commit 25d2573f47
5 changed files with 354 additions and 14 deletions

View File

@ -130,6 +130,13 @@ SEQ_LEN=${SEQ_LEN:-1024}
DEVICE_BATCH=${DEVICE_BATCH:-16}
TOTAL_BATCH=${TOTAL_BATCH:-$((DEVICE_BATCH * SEQ_LEN))} # tokens per optimizer step
KV_HEAD_MULT=${KV_HEAD_MULT:-1}
MOE_NUM_EXPERTS=${MOE_NUM_EXPERTS:-0}
MOE_NUM_SHARED=${MOE_NUM_SHARED:--1}
MOE_EXPERTS_PER_TOKEN=${MOE_EXPERTS_PER_TOKEN:--1}
MOE_EXPERT_FFN_MULT=${MOE_EXPERT_FFN_MULT:--1}
MOE_DENSE_LAYERS=${MOE_DENSE_LAYERS:--1}
MOE_GRANULARITY_TARGET=${MOE_GRANULARITY_TARGET:-12}
MOE_ACTIVATION_DEN=${MOE_ACTIVATION_DEN:-32}
EVAL_SEQUENCES=10000
EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH))
EVAL_BATCH_MULT=4 # evaluate on 4 full batches
@ -246,6 +253,29 @@ python -m scripts.tok_eval
BASE_MODEL_TAG_FLAG_HYPHEN=""
fi
# Assemble the optional MoE CLI flags passed through to scripts.base_train.
# MOE_NUM_EXPERTS=0 disables MoE entirely (MOE_FLAGS stays empty). For the
# remaining knobs, the -1 sentinel means "let base_train derive a default",
# so only values the user explicitly set are forwarded.
MOE_FLAGS=()
if [ "$MOE_NUM_EXPERTS" -gt 0 ]; then
  MOE_FLAGS+=("--moe_num_experts=$MOE_NUM_EXPERTS")
  if [ "$MOE_NUM_SHARED" -ge 0 ]; then
    MOE_FLAGS+=("--moe_num_shared_experts=$MOE_NUM_SHARED")
  fi
  if [ "$MOE_EXPERTS_PER_TOKEN" -ge 0 ]; then
    MOE_FLAGS+=("--moe_experts_per_token=$MOE_EXPERTS_PER_TOKEN")
  fi
  # String comparison on purpose: the ffn mult may be fractional, so the
  # -1 sentinel is matched literally rather than numerically.
  if [ "$MOE_EXPERT_FFN_MULT" != "-1" ]; then
    MOE_FLAGS+=("--moe_expert_ffn_mult=$MOE_EXPERT_FFN_MULT")
  fi
  if [ "$MOE_DENSE_LAYERS" -ge 0 ]; then
    MOE_FLAGS+=("--dense_layers_before_moe=$MOE_DENSE_LAYERS")
  fi
  # Non-empty granularity target is always forwarded (default is 12 above).
  if [ "$MOE_GRANULARITY_TARGET" != "" ]; then
    MOE_FLAGS+=("--moe_granularity_target=$MOE_GRANULARITY_TARGET")
  fi
  if [ "$MOE_ACTIVATION_DEN" -gt 0 ]; then
    MOE_FLAGS+=("--moe_activation_denominator=$MOE_ACTIVATION_DEN")
  fi
fi
python -m scripts.base_train \
--depth=$BASE_DEPTH \
--max_seq_len=$SEQ_LEN \
@ -254,12 +284,12 @@ python -m scripts.tok_eval
--kv_head_mult=$KV_HEAD_MULT \
--target_param_data_ratio=$TARGET_PARAM_DATA_RATIO \
--run="$WANDB_RUN" \
--eval_every=$EVAL_STEPS \
--eval_tokens=$EVAL_TOKENS \
--core_metric_every=-1 \
--sample_every=-1 \
--checkpoint_every_steps=$BASE_CHECKPOINT_STEPS \
$BASE_MODEL_TAG_FLAG
$BASE_MODEL_TAG_FLAG \
${MOE_FLAGS[@]}
if [ "$WANDB_RUN" != "dummy" ]; then
unset WANDB_RUN_ID

View File

@ -48,7 +48,10 @@ for arg in sys.argv[1:]:
if globals()[key] is not None:
attempt_type = type(attempt)
default_type = type(globals()[key])
assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
if default_type is float and attempt_type in (int, float):
attempt = float(attempt)
else:
assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
# cross fingers
print0(f"Overriding: {key} = {attempt}")
globals()[key] = attempt

View File

@ -31,6 +31,13 @@ class GPTConfig:
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (MQA)
n_embd: int = 768
moe_num_experts: int = 0
moe_num_shared_experts: int = 0
moe_experts_per_token: int = 2
moe_expert_ffn_mult: float = 4.0
dense_layers_before_moe: int = 0
moe_granularity_target: float = 12.0
moe_activation_denominator: int = 32
def norm(x):
@ -123,11 +130,155 @@ class MLP(nn.Module):
return x
class ExpertFFN(nn.Module):
    """A single feed-forward expert: linear -> squared-ReLU -> linear, no biases.

    `ffn_mult` scales the hidden width relative to `input_dim`; the hidden
    width is rounded and floored at 1 so tiny multipliers stay valid.
    """

    def __init__(self, input_dim, ffn_mult):
        super().__init__()
        width = max(1, int(round(ffn_mult * input_dim)))
        self.c_fc = nn.Linear(input_dim, width, bias=False)
        self.c_proj = nn.Linear(width, input_dim, bias=False)

    def forward(self, x):
        pre_act = self.c_fc(x)
        return self.c_proj(F.relu(pre_act).square())
class MoEFeedForward(nn.Module):
    """Mixture-of-Experts feed-forward layer with top-k token routing.

    Each token is routed to `top_k` of `num_routed_experts` experts whose
    weights are stored as batched 3-D parameters (so dispatch can use bmm),
    plus an optional set of always-active shared experts whose mean output
    is added. Load balancing uses no auxiliary loss: a non-learned per-expert
    bias is nudged toward uniform expert load via an EMA and added to the
    router logits before softmax.
    """

    def __init__(self, config):
        super().__init__()
        self.model_dim = config.n_embd
        self.num_routed_experts = config.moe_num_experts
        self.num_shared_experts = config.moe_num_shared_experts
        # clamp requested top-k into [1, num_routed_experts]
        self.top_k = max(1, min(config.moe_experts_per_token, self.num_routed_experts))
        self.ffn_mult = config.moe_expert_ffn_mult
        assert self.num_routed_experts > 0, "MoE requires at least one routed expert"
        assert self.top_k <= self.num_routed_experts, "experts_per_token must be <= num_experts"
        self.router = nn.Linear(self.model_dim, self.num_routed_experts, bias=False)
        self.hidden_dim = max(1, int(round(self.ffn_mult * self.model_dim)))
        # routed expert weights batched on dim 0:
        # w1: (E, hidden, model), w2: (E, model, hidden)
        self.routed_w1 = nn.Parameter(torch.empty(self.num_routed_experts, self.hidden_dim, self.model_dim))
        self.routed_w2 = nn.Parameter(torch.empty(self.num_routed_experts, self.model_dim, self.hidden_dim))
        self.shared_experts = nn.ModuleList(
            [ExpertFFN(self.model_dim, self.ffn_mult) for _ in range(self.num_shared_experts)]
        )
        # non-learned routing state; buffers so they follow .to()/state_dict
        self.register_buffer("router_bias", torch.zeros(self.num_routed_experts, dtype=torch.float32))
        self.register_buffer("ema_load", torch.zeros(self.num_routed_experts, dtype=torch.float32))
        self.balance_strength = 0.001  # step size of the bias correction
        self.balance_decay = 0.99  # EMA decay for observed expert load
        self.bias_clamp = 0.2  # keeps the bias from dominating the logits
        self._last_stats = None  # latest routing-stats snapshot (training only)
        self.last_entropy = None
        self.last_load = None
        self._init_experts()

    def _balance_router(self, assignments, probs):
        """Update router bias and load EMA from this batch's assignments.

        Args:
            assignments: 1-D tensor of chosen expert ids, length tokens*top_k.
            probs: router softmax probabilities, shape (tokens, num_experts).
        Also refreshes the stats snapshot consumed by `get_stats()`. In eval
        mode it only clears stale stats and returns.
        """
        if not self.training:
            self.last_load = None
            self.last_entropy = None
            self._last_stats = None
            return
        with torch.no_grad():
            # fraction of (token, slot) assignments landing on each expert;
            # sums to 1 since we divide by the total number of assignments
            load = torch.bincount(assignments, minlength=self.num_routed_experts).float()
            tokens = max(1, assignments.numel())
            load = load / tokens
            target = load.new_full((self.num_routed_experts,), 1.0 / self.num_routed_experts)
            # EMA of observed load, then push the bias toward uniform load
            self.ema_load.mul_(self.balance_decay).add_((1 - self.balance_decay) * load)
            balance = target - self.ema_load
            self.router_bias.add_(self.balance_strength * balance)
            self.router_bias.clamp_(-self.bias_clamp, self.bias_clamp)
            # mean per-token routing entropy (natural log, eps for stability)
            entropy = (-probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
            load_cpu = load.detach().cpu()
            stats = {
                "tokens": probs.size(0),
                "load_std": load_cpu.std(unbiased=False).item(),
                "load_max": load_cpu.max().item(),
                "load_min": load_cpu.min().item(),
                "load_active_frac": (load_cpu > 0).float().mean().item(),
                "load_imbalance": torch.abs(load_cpu - (1.0 / self.num_routed_experts)).mean().item(),
            }
            self._last_stats = stats
            self.last_load = load_cpu
            self.last_entropy = entropy.item()

    def forward(self, x):
        """Route tokens to experts and return the mixed output, shape (B, T, C)."""
        B, T, C = x.shape
        x_flat = x.view(B * T, C)
        # router math in float32 for a numerically stable softmax
        router_logits = self.router(x_flat).float()
        router_logits = router_logits + self.router_bias.to(router_logits.device)
        probs = torch.softmax(router_logits, dim=-1)
        topk_scores, topk_idx = torch.topk(probs, k=self.top_k, dim=-1)
        # renormalize the selected gates so they sum to 1 per token
        topk_scores = topk_scores / torch.clamp(topk_scores.sum(dim=-1, keepdim=True), min=1e-9)
        flat_assignments = topk_idx.reshape(-1)
        self._balance_router(flat_assignments, probs)
        mixture = self._dispatch_batched(flat_assignments, x_flat, topk_scores)
        if self._last_stats is not None:
            # enrich the training-time snapshot with gate/confidence metrics
            self._last_stats["gate_mean"] = topk_scores.mean().item()
            self._last_stats["gate_std"] = topk_scores.std(unbiased=False).item()
            confident = probs.max(dim=-1).values
            self._last_stats["router_confidence"] = confident.mean().item()
            self._last_stats["tokens"] = x_flat.size(0)
            self._last_stats["entropy"] = self.last_entropy if self.last_entropy is not None else 0.0
        if self.num_shared_experts > 0:
            # shared experts are always active; their *mean* output is added
            shared_sum = None
            for shared in self.shared_experts:
                out = shared(x_flat)
                shared_sum = out if shared_sum is None else shared_sum + out
            mixture = mixture + shared_sum / self.num_shared_experts
        return mixture.view(B, T, C)

    def _dispatch_batched(self, flat_assignments, x_flat, topk_scores):
        """Run routed experts for every (token, slot) pair via batched matmuls.

        Gathers the selected expert weight matrices with index_select and
        applies them with bmm, replicating each token top_k times — a
        memory-for-simplicity trade. Returns the gate-weighted mixture of
        shape (num_tokens, model_dim).
        """
        num_tokens, model_dim = x_flat.shape
        total = num_tokens * self.top_k
        if total == 0:
            # empty input: nothing to dispatch
            return torch.zeros_like(x_flat)
        token_idx = torch.arange(num_tokens, device=x_flat.device).repeat_interleave(self.top_k)
        selected_inputs = x_flat[token_idx]
        expert_ids = flat_assignments
        w1 = torch.index_select(self.routed_w1, 0, expert_ids)
        hidden = torch.bmm(w1, selected_inputs.unsqueeze(-1)).squeeze(-1)
        hidden = F.relu(hidden).square()  # same squared-ReLU as ExpertFFN
        w2 = torch.index_select(self.routed_w2, 0, expert_ids)
        routed = torch.bmm(w2, hidden.unsqueeze(-1)).squeeze(-1)
        routed = routed.view(num_tokens, self.top_k, model_dim)
        weights = topk_scores.unsqueeze(-1).to(routed.dtype)
        return torch.sum(routed * weights, dim=1)

    def _init_experts(self):
        # batched expert weights bypass nn.Linear's built-in init, so do it here
        self._init_linear(self.routed_w1, fan_in=self.model_dim, fan_out=self.hidden_dim)
        self._init_linear(self.routed_w2, fan_in=self.hidden_dim, fan_out=self.model_dim)

    @staticmethod
    def _init_linear(weight, fan_in, fan_out):
        # 1/sqrt(fan_in) scaling, further damped when fan_out < fan_in
        std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
        torch.nn.init.normal_(weight, mean=0.0, std=std)

    def get_stats(self):
        """Return a copy of the latest routing stats, or None if none recorded
        (e.g. the last forward ran in eval mode)."""
        if self._last_stats is None:
            return None
        return dict(self._last_stats)

    def count_parameters(self):
        """Parameter count of the expert FFNs (routed + shared).

        NOTE: router weights are deliberately excluded.
        """
        total = self.routed_w1.numel() + self.routed_w2.numel()
        total += sum(p.numel() for expert in self.shared_experts for p in expert.parameters())
        return total
class Block(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
use_moe = (
config.moe_num_experts > 0
and layer_idx >= config.dense_layers_before_moe
)
self.mlp = MoEFeedForward(config) if use_moe else MLP(config)
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache)
@ -160,7 +311,13 @@ class GPT(nn.Module):
torch.nn.init.zeros_(self.lm_head.weight)
# zero out c_proj weights in all blocks
for block in self.transformer.h:
torch.nn.init.zeros_(block.mlp.c_proj.weight)
if hasattr(block.mlp, "c_proj"):
torch.nn.init.zeros_(block.mlp.c_proj.weight)
elif isinstance(block.mlp, MoEFeedForward):
with torch.no_grad():
block.mlp.routed_w2.zero_()
for expert in block.mlp.shared_experts:
torch.nn.init.zeros_(expert.c_proj.weight)
torch.nn.init.zeros_(block.attn.c_proj.weight)
# init the rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
@ -210,6 +367,53 @@ class GPT(nn.Module):
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
return num_flops_per_token
    def estimate_moe_active_flops(self):
        """Estimate FLOPs per token, returning (active, dense_reference).

        `active` counts only parameters actually exercised per token
        (top-k routed experts plus all shared experts); `dense_reference`
        replaces each MoE layer with a hypothetical dense 4*n_embd MLP of
        the same model width. With MoE disabled, both values equal
        `estimate_flops()`.
        """
        if self.config.moe_num_experts <= 0:
            dense_like = self.estimate_flops()
            return dense_like, dense_like
        total_params = sum(p.numel() for p in self.parameters())
        embed_params = self.transformer.wte.weight.numel()
        # all non-embedding parameters
        linear_params = total_params - embed_params
        moe_params = 0
        moe_layers = 0
        for block in self.transformer.h:
            if isinstance(block.mlp, MoEFeedForward):
                moe_layers += 1
                moe_params += block.mlp.count_parameters()
        # params outside the MoE expert stacks (attention, routers, norms, ...);
        # note count_parameters() excludes the router, so it lands here
        residual_linear = linear_params - moe_params
        n = self.config.n_embd
        dense_hidden = 4 * n
        dense_mlp_params = 2 * n * dense_hidden  # c_fc + c_proj of a dense MLP
        moe_hidden = max(1, int(round(self.config.moe_expert_ffn_mult * n)))
        # experts touched per token: top-k routed + every shared expert
        active_experts = max(1, self.config.moe_experts_per_token) + self.config.moe_num_shared_experts
        active_mlp_params = 2 * n * moe_hidden * active_experts
        dense_linear_equiv = residual_linear + moe_layers * dense_mlp_params
        active_linear_equiv = residual_linear + moe_layers * active_mlp_params
        l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
        # same 6*params + 12*l*h*q*t formula used by estimate_flops()
        attn_term = 12 * l * h * q * t
        dense_ref = 6 * dense_linear_equiv + attn_term
        active = 6 * active_linear_equiv + attn_term
        return active, dense_ref
def get_moe_stats(self):
stats = {}
aggregates = {}
counts = {}
for layer_idx, block in enumerate(self.transformer.h):
mlp = block.mlp
if isinstance(mlp, MoEFeedForward):
layer_stats = mlp.get_stats()
if not layer_stats:
continue
for key, value in layer_stats.items():
stats[f"moe/layer{layer_idx}/{key}"] = value
aggregates.setdefault(key, 0.0)
counts[key] = counts.get(key, 0) + 1
aggregates[key] += value
for key, total in aggregates.items():
stats[f"moe/{key}_mean"] = total / counts[key]
return stats
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()

View File

@ -72,7 +72,8 @@ class Muon(torch.optim.Optimizer):
params: list[Tensor] = group["params"]
for p in params:
g = p.grad
assert g is not None
if g is None:
continue
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
@ -128,8 +129,11 @@ class DistMuon(torch.optim.Optimizer):
rank = dist.get_rank()
world_size = dist.get_world_size()
# Ensure all grads exist
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
# Ensure grads exist (experts can be unused, so synthesize zeros when needed)
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
p.grad = torch.zeros_like(p)
# Kick off all the reduce scatter operations to average up the gradients across all ranks
all_reduce_futures = []

View File

@ -38,6 +38,13 @@ device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, i
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
kv_head_mult = 1 # number of query heads that share a single key/value head (1 disables GQA)
moe_num_experts = 0 # routed experts per MoE layer (0 disables MoE)
moe_num_shared_experts = -1 # -1 => derive (defaults to 1 shared expert)
moe_experts_per_token = -1 # -1 => derive using Ling-style sparsity (≈1/32 active)
moe_expert_ffn_mult = -1.0 # -1 => derive from granularity target (defaults to 12)
dense_layers_before_moe = -1 # -1 => derive (≈10% of layers, min 1) before switching to MoE
moe_granularity_target = 12.0 # Ling guidance: target granularity per layer (2*d_model/d_expert)
moe_activation_denominator = 32 # derive top-k as num_experts / denominator (~3% activation)
# Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
@ -105,6 +112,33 @@ num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here i
assert kv_head_mult >= 1, "kv_head_mult must be >= 1"
assert num_heads % kv_head_mult == 0, f"num_heads ({num_heads}) must be divisible by kv_head_mult ({kv_head_mult})"
num_kv_heads = max(1, num_heads // kv_head_mult)
# ----- Derive MoE defaults from the -1 sentinels (Ling-style heuristics) -----
activation_denom = max(1, int(round(moe_activation_denominator)))
# granularity here means 2*d_model/d_expert; fall back to 12 when unset/invalid
granularity_target = moe_granularity_target if moe_granularity_target > 0 else 12.0
# default dense preface: ~10% of layers (at least 1) stay dense before the MoE
# stack begins; with MoE disabled, every layer counts as dense
auto_dense_layers = max(1, num_layers // 10) if moe_num_experts > 0 else num_layers
if dense_layers_before_moe < 0:
    dense_layers_before_moe = auto_dense_layers
dense_layers_before_moe = max(0, min(dense_layers_before_moe, num_layers))
if moe_num_experts <= 0:
    dense_layers_before_moe = num_layers
if moe_num_experts > 0:
    # derive top-k so roughly 1/activation_denom of the experts fire per token
    derived_top_k = max(1, round(moe_num_experts / activation_denom))
    moe_experts_per_token = moe_experts_per_token if moe_experts_per_token > 0 else derived_top_k
    moe_num_shared_experts = moe_num_shared_experts if moe_num_shared_experts >= 0 else 1
    if moe_expert_ffn_mult <= 0:
        # expert width from the granularity target: mult = 2 / granularity
        moe_expert_ffn_mult = max(1e-6, 2.0 / granularity_target)
    assert moe_experts_per_token > 0, "moe_experts_per_token must be > 0 when MoE is enabled"
    assert moe_num_experts >= moe_experts_per_token, "moe_num_experts must be >= moe_experts_per_token"
    assert moe_num_shared_experts >= 0, "moe_num_shared_experts must be >= 0"
    assert moe_expert_ffn_mult > 0, "moe_expert_ffn_mult must be > 0"
    # reporting values: fraction of experts active per token, actual granularity
    moe_activation_ratio = moe_experts_per_token / (moe_num_experts + moe_num_shared_experts)
    moe_granularity_actual = 2.0 / moe_expert_ffn_mult
else:
    # MoE disabled: zero the reporting values, keep a sane dense-equivalent mult
    moe_num_shared_experts = 0
    moe_experts_per_token = 0
    moe_expert_ffn_mult = 4.0 if moe_expert_ffn_mult <= 0 else moe_expert_ffn_mult
    moe_activation_ratio = 0.0
    moe_granularity_actual = 0.0
def _resolve_checkpoint_tag(tag, run_name, depth_value):
if tag:
return tag
@ -119,6 +153,18 @@ print0(f"model_dim: {model_dim}")
print0(f"kv_head_mult: {kv_head_mult}")
print0(f"num_heads: {num_heads}")
print0(f"num_kv_heads: {num_kv_heads}")
if moe_num_experts > 0:
print0(
"MoE config: experts=%d shared=%d topk=%d granularity=%.1f (mult=%.3f) sparsity=%.2f%% dense_preface=%d" % (
moe_num_experts,
moe_num_shared_experts,
moe_experts_per_token,
moe_granularity_actual,
moe_expert_ffn_mult,
moe_activation_ratio * 100,
dense_layers_before_moe,
)
)
print0(f"Checkpoint tag: {model_tag}")
# Optimizer / data / training length related hyperparameters
@ -132,18 +178,63 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# -----------------------------------------------------------------------------
# Initialize the Model
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
model_config_kwargs = dict(
sequence_len=max_seq_len,
vocab_size=vocab_size,
n_layer=num_layers,
n_head=num_heads,
n_kv_head=num_kv_heads,
n_embd=model_dim,
moe_num_experts=moe_num_experts,
moe_num_shared_experts=moe_num_shared_experts,
moe_experts_per_token=moe_experts_per_token,
moe_expert_ffn_mult=moe_expert_ffn_mult,
dense_layers_before_moe=dense_layers_before_moe,
moe_granularity_target=granularity_target,
moe_activation_denominator=activation_denom,
)
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
dense_like_flops = model.estimate_flops()
active_flops_per_token, dense_ref_flops = model.estimate_moe_active_flops()
num_flops_per_token = active_flops_per_token
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
print0(f"Estimated FLOPs per token (dense-like): {dense_like_flops:e}")
if active_flops_per_token != dense_like_flops:
print0(f"Estimated FLOPs per token (MoE active): {active_flops_per_token:e}")
print0(f"Estimated FLOPs per token (dense reference): {dense_ref_flops:e}")
user_config.update({
"moe_num_experts": moe_num_experts,
"moe_num_shared_experts": moe_num_shared_experts,
"moe_experts_per_token": moe_experts_per_token,
"moe_expert_ffn_mult": moe_expert_ffn_mult,
"moe_dense_layers": dense_layers_before_moe,
"moe_activation_ratio": moe_activation_ratio,
"moe_granularity_actual": moe_granularity_actual,
"flops_per_token_dense_like": dense_like_flops,
"flops_per_token_moe_active": active_flops_per_token,
"flops_per_token_dense_reference": dense_ref_flops,
})
if not use_dummy_wandb:
wandb_run.config.update({
"moe_num_experts": moe_num_experts,
"moe_num_shared_experts": moe_num_shared_experts,
"moe_experts_per_token": moe_experts_per_token,
"moe_expert_ffn_mult": moe_expert_ffn_mult,
"moe_dense_layers": dense_layers_before_moe,
"moe_activation_ratio": moe_activation_ratio,
"moe_granularity_actual": moe_granularity_actual,
"flops_per_token_dense_like": dense_like_flops,
"flops_per_token_moe_active": active_flops_per_token,
"flops_per_token_dense_reference": dense_ref_flops,
}, allow_val_change=True)
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
@ -164,6 +255,9 @@ total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
if eval_every <= 0:
eval_every = max(1, num_iterations // 100)
print0(f"Auto-setting eval_every to {eval_every} (~1% of training)")
sequences_per_step = max(1, total_batch_size // max_seq_len)
checkpoint_every_steps = int(checkpoint_every_steps)
@ -357,7 +451,7 @@ for step in range(num_iterations + 1):
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
if step % 100 == 0:
wandb_run.log({
log_payload = {
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
@ -369,7 +463,10 @@ for step in range(num_iterations + 1):
"train/mfu": mfu,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
})
}
if hasattr(orig_model, "get_moe_stats"):
log_payload.update(orig_model.get_moe_stats())
wandb_run.log(log_payload)
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
@ -382,7 +479,9 @@ get_report().log(section="Base model training", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of parameters": num_params,
"Number of FLOPs per token": f"{num_flops_per_token:e}",
"FLOPs per token (MoE active)": f"{num_flops_per_token:e}",
"FLOPs per token (dense-like)": f"{dense_like_flops:e}",
"FLOPs per token (dense reference)": f"{dense_ref_flops:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,