From 25d2573f47d7b247f5017f90603b70e044b0f29e Mon Sep 17 00:00:00 2001 From: William Thurston Date: Tue, 11 Nov 2025 19:58:38 -0800 Subject: [PATCH] Add MoE configuration and implementation in training scripts and model architecture - Introduced parameters for Mixture of Experts (MoE) in `runmps.sh`, `base_train.py`, and `gpt.py`, allowing for dynamic configuration of experts during training. - Enhanced `gpt.py` with new classes `MoEFeedForward` and `ExpertFFN` to implement MoE functionality in the model architecture. - Updated `configurator.py` to handle type conversions for new MoE parameters. - Improved logging in `base_train.py` to include MoE-related metrics and configurations during training. - Added assertions and derived defaults for MoE parameters to ensure valid configurations. - Implemented methods to estimate and log FLOPs for both dense and MoE active configurations during training. - Enhanced gradient handling in `muon.py` to accommodate potential absence of gradients for unused experts. --- dev/runmps.sh | 34 ++++++- nanochat/configurator.py | 5 +- nanochat/gpt.py | 208 ++++++++++++++++++++++++++++++++++++++- nanochat/muon.py | 10 +- scripts/base_train.py | 111 +++++++++++++++++++-- 5 files changed, 354 insertions(+), 14 deletions(-) diff --git a/dev/runmps.sh b/dev/runmps.sh index 5a3529d..4027949 100755 --- a/dev/runmps.sh +++ b/dev/runmps.sh @@ -130,6 +130,13 @@ SEQ_LEN=${SEQ_LEN:-1024} DEVICE_BATCH=${DEVICE_BATCH:-16} TOTAL_BATCH=${TOTAL_BATCH:-$((DEVICE_BATCH * SEQ_LEN))} # tokens per optimizer step KV_HEAD_MULT=${KV_HEAD_MULT:-1} +MOE_NUM_EXPERTS=${MOE_NUM_EXPERTS:-0} +MOE_NUM_SHARED=${MOE_NUM_SHARED:--1} +MOE_EXPERTS_PER_TOKEN=${MOE_EXPERTS_PER_TOKEN:--1} +MOE_EXPERT_FFN_MULT=${MOE_EXPERT_FFN_MULT:--1} +MOE_DENSE_LAYERS=${MOE_DENSE_LAYERS:--1} +MOE_GRANULARITY_TARGET=${MOE_GRANULARITY_TARGET:-12} +MOE_ACTIVATION_DEN=${MOE_ACTIVATION_DEN:-32} EVAL_SEQUENCES=10000 EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH)) EVAL_BATCH_MULT=4 # evaluate on 4 full batches @@ -246,6 +253,29 @@ python -m scripts.tok_eval BASE_MODEL_TAG_FLAG_HYPHEN="" fi + MOE_FLAGS=() + if [ "$MOE_NUM_EXPERTS" -gt 0 ]; then + MOE_FLAGS+=("--moe_num_experts=$MOE_NUM_EXPERTS") + if [ "$MOE_NUM_SHARED" -ge 0 ]; then + MOE_FLAGS+=("--moe_num_shared_experts=$MOE_NUM_SHARED") + fi + if [ "$MOE_EXPERTS_PER_TOKEN" -ge 0 ]; then + MOE_FLAGS+=("--moe_experts_per_token=$MOE_EXPERTS_PER_TOKEN") + fi + if [ "$MOE_EXPERT_FFN_MULT" != "-1" ]; then + MOE_FLAGS+=("--moe_expert_ffn_mult=$MOE_EXPERT_FFN_MULT") + fi + if [ "$MOE_DENSE_LAYERS" -ge 0 ]; then + MOE_FLAGS+=("--dense_layers_before_moe=$MOE_DENSE_LAYERS") + fi + if [ "$MOE_GRANULARITY_TARGET" != "" ]; then + MOE_FLAGS+=("--moe_granularity_target=$MOE_GRANULARITY_TARGET") + fi + if [ "$MOE_ACTIVATION_DEN" -gt 0 ]; then + MOE_FLAGS+=("--moe_activation_denominator=$MOE_ACTIVATION_DEN") + fi + fi + python -m scripts.base_train \ --depth=$BASE_DEPTH \ --max_seq_len=$SEQ_LEN \ @@ -254,12 +284,12 @@ python -m scripts.tok_eval --kv_head_mult=$KV_HEAD_MULT \ --target_param_data_ratio=$TARGET_PARAM_DATA_RATIO \ --run="$WANDB_RUN" \ - --eval_every=$EVAL_STEPS \ --eval_tokens=$EVAL_TOKENS \ --core_metric_every=-1 \ --sample_every=-1 \ --checkpoint_every_steps=$BASE_CHECKPOINT_STEPS \ - $BASE_MODEL_TAG_FLAG + $BASE_MODEL_TAG_FLAG \ + ${MOE_FLAGS[@]} if [ "$WANDB_RUN" != "dummy" ]; then unset WANDB_RUN_ID diff --git a/nanochat/configurator.py b/nanochat/configurator.py index ec1b76d..63d197f 100644 --- a/nanochat/configurator.py +++ b/nanochat/configurator.py @@ -48,7 +48,10 @@ for arg in sys.argv[1:]: if globals()[key] is not None: attempt_type = type(attempt) default_type = type(globals()[key]) - assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}" + if default_type is float and attempt_type in (int, float): + attempt = float(attempt) + else: + assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}" # cross fingers print0(f"Overriding: {key} = {attempt}") globals()[key] = attempt diff --git a/nanochat/gpt.py b/nanochat/gpt.py index b640f1e..58c1596 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -31,6 +31,13 @@ class GPTConfig: n_head: int = 6 # number of query heads n_kv_head: int = 6 # number of key/value heads (MQA) n_embd: int = 768 + moe_num_experts: int = 0 + moe_num_shared_experts: int = 0 + moe_experts_per_token: int = 2 + moe_expert_ffn_mult: float = 4.0 + dense_layers_before_moe: int = 0 + moe_granularity_target: float = 12.0 + moe_activation_denominator: int = 32 def norm(x): @@ -123,11 +130,155 @@ class MLP(nn.Module): return x +class ExpertFFN(nn.Module): + def __init__(self, input_dim, ffn_mult): + super().__init__() + hidden_dim = max(1, int(round(ffn_mult * input_dim))) + self.c_fc = nn.Linear(input_dim, hidden_dim, bias=False) + self.c_proj = nn.Linear(hidden_dim, input_dim, bias=False) + + def forward(self, x): + x = self.c_fc(x) + x = F.relu(x).square() + x = self.c_proj(x) + return x + + +class MoEFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.model_dim = config.n_embd + self.num_routed_experts = config.moe_num_experts + self.num_shared_experts = config.moe_num_shared_experts + self.top_k = max(1, min(config.moe_experts_per_token, self.num_routed_experts)) + self.ffn_mult = config.moe_expert_ffn_mult + assert self.num_routed_experts > 0, "MoE requires at least one routed expert" + assert self.top_k <= self.num_routed_experts, "experts_per_token must be <= num_experts" + self.router = nn.Linear(self.model_dim, self.num_routed_experts, bias=False) + self.hidden_dim = max(1, int(round(self.ffn_mult * self.model_dim))) + self.routed_w1 = nn.Parameter(torch.empty(self.num_routed_experts, self.hidden_dim, self.model_dim)) + self.routed_w2 = nn.Parameter(torch.empty(self.num_routed_experts, self.model_dim, self.hidden_dim)) + self.shared_experts = nn.ModuleList( + [ExpertFFN(self.model_dim, self.ffn_mult) for _ in range(self.num_shared_experts)] + ) + self.register_buffer("router_bias", torch.zeros(self.num_routed_experts, dtype=torch.float32)) + self.register_buffer("ema_load", torch.zeros(self.num_routed_experts, dtype=torch.float32)) + self.balance_strength = 0.001 + self.balance_decay = 0.99 + self.bias_clamp = 0.2 + self._last_stats = None + self.last_entropy = None + self.last_load = None + self._init_experts() + + def _balance_router(self, assignments, probs): + if not self.training: + self.last_load = None + self.last_entropy = None + self._last_stats = None + return + with torch.no_grad(): + load = torch.bincount(assignments, minlength=self.num_routed_experts).float() + tokens = max(1, assignments.numel()) + load = load / tokens + target = load.new_full((self.num_routed_experts,), 1.0 / self.num_routed_experts) + self.ema_load.mul_(self.balance_decay).add_((1 - self.balance_decay) * load) + balance = target - self.ema_load + self.router_bias.add_(self.balance_strength * balance) + self.router_bias.clamp_(-self.bias_clamp, self.bias_clamp) + entropy = (-probs * torch.log(probs + 1e-9)).sum(dim=-1).mean() + load_cpu = load.detach().cpu() + stats = { + "tokens": probs.size(0), + "load_std": load_cpu.std(unbiased=False).item(), + "load_max": load_cpu.max().item(), + "load_min": load_cpu.min().item(), + "load_active_frac": (load_cpu > 0).float().mean().item(), + "load_imbalance": torch.abs(load_cpu - (1.0 / self.num_routed_experts)).mean().item(), + } + self._last_stats = stats + self.last_load = load_cpu + self.last_entropy = entropy.item() + + def forward(self, x): + B, T, C = x.shape + x_flat = x.view(B * T, C) + router_logits = self.router(x_flat).float() + router_logits = router_logits + self.router_bias.to(router_logits.device) + probs = torch.softmax(router_logits, dim=-1) + topk_scores, topk_idx = torch.topk(probs, k=self.top_k, dim=-1) + topk_scores = topk_scores / torch.clamp(topk_scores.sum(dim=-1, keepdim=True), min=1e-9) + flat_assignments = topk_idx.reshape(-1) + self._balance_router(flat_assignments, probs) + + mixture = self._dispatch_batched(flat_assignments, x_flat, topk_scores) + + if self._last_stats is not None: + self._last_stats["gate_mean"] = topk_scores.mean().item() + self._last_stats["gate_std"] = topk_scores.std(unbiased=False).item() + confident = probs.max(dim=-1).values + self._last_stats["router_confidence"] = confident.mean().item() + self._last_stats["tokens"] = x_flat.size(0) + self._last_stats["entropy"] = self.last_entropy if self.last_entropy is not None else 0.0 + + if self.num_shared_experts > 0: + shared_sum = None + for shared in self.shared_experts: + out = shared(x_flat) + shared_sum = out if shared_sum is None else shared_sum + out + mixture = mixture + shared_sum / self.num_shared_experts + + return mixture.view(B, T, C) + + def _dispatch_batched(self, flat_assignments, x_flat, topk_scores): + num_tokens, model_dim = x_flat.shape + total = num_tokens * self.top_k + if total == 0: + return torch.zeros_like(x_flat) + + token_idx = torch.arange(num_tokens, device=x_flat.device).repeat_interleave(self.top_k) + selected_inputs = x_flat[token_idx] + expert_ids = flat_assignments + + w1 = torch.index_select(self.routed_w1, 0, expert_ids) + hidden = torch.bmm(w1, selected_inputs.unsqueeze(-1)).squeeze(-1) + hidden = F.relu(hidden).square() + + w2 = torch.index_select(self.routed_w2, 0, expert_ids) + routed = torch.bmm(w2, hidden.unsqueeze(-1)).squeeze(-1) + routed = routed.view(num_tokens, self.top_k, model_dim) + weights = topk_scores.unsqueeze(-1).to(routed.dtype) + return torch.sum(routed * weights, dim=1) + + def _init_experts(self): + self._init_linear(self.routed_w1, fan_in=self.model_dim, fan_out=self.hidden_dim) + self._init_linear(self.routed_w2, fan_in=self.hidden_dim, fan_out=self.model_dim) + + @staticmethod + def _init_linear(weight, fan_in, fan_out): + std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in)) + torch.nn.init.normal_(weight, mean=0.0, std=std) + + def get_stats(self): + if self._last_stats is None: + return None + return dict(self._last_stats) + + def count_parameters(self): + total = self.routed_w1.numel() + self.routed_w2.numel() + total += sum(p.numel() for expert in self.shared_experts for p in expert.parameters()) + return total + + class Block(nn.Module): def __init__(self, config, layer_idx): super().__init__() self.attn = CausalSelfAttention(config, layer_idx) - self.mlp = MLP(config) + use_moe = ( + config.moe_num_experts > 0 + and layer_idx >= config.dense_layers_before_moe + ) + self.mlp = MoEFeedForward(config) if use_moe else MLP(config) def forward(self, x, cos_sin, kv_cache): x = x + self.attn(norm(x), cos_sin, kv_cache) @@ -160,7 +311,13 @@ class GPT(nn.Module): torch.nn.init.zeros_(self.lm_head.weight) # zero out c_proj weights in all blocks for block in self.transformer.h: - torch.nn.init.zeros_(block.mlp.c_proj.weight) + if hasattr(block.mlp, "c_proj"): + torch.nn.init.zeros_(block.mlp.c_proj.weight) + elif isinstance(block.mlp, MoEFeedForward): + with torch.no_grad(): + block.mlp.routed_w2.zero_() + for expert in block.mlp.shared_experts: + torch.nn.init.zeros_(expert.c_proj.weight) torch.nn.init.zeros_(block.attn.c_proj.weight) # init the rotary embeddings head_dim = self.config.n_embd // self.config.n_head @@ -210,6 +367,53 @@ class GPT(nn.Module): num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t return num_flops_per_token + def estimate_moe_active_flops(self): + if self.config.moe_num_experts <= 0: + dense_like = self.estimate_flops() + return dense_like, dense_like + total_params = sum(p.numel() for p in self.parameters()) + embed_params = self.transformer.wte.weight.numel() + linear_params = total_params - embed_params + moe_params = 0 + moe_layers = 0 + for block in self.transformer.h: + if isinstance(block.mlp, MoEFeedForward): + moe_layers += 1 + moe_params += block.mlp.count_parameters() + residual_linear = linear_params - moe_params + n = self.config.n_embd + dense_hidden = 4 * n + dense_mlp_params = 2 * n * dense_hidden + moe_hidden = max(1, int(round(self.config.moe_expert_ffn_mult * n))) + active_experts = max(1, self.config.moe_experts_per_token) + self.config.moe_num_shared_experts + active_mlp_params = 2 * n * moe_hidden * active_experts + dense_linear_equiv = residual_linear + moe_layers * dense_mlp_params + active_linear_equiv = residual_linear + moe_layers * active_mlp_params + l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len + attn_term = 12 * l * h * q * t + dense_ref = 6 * dense_linear_equiv + attn_term + active = 6 * active_linear_equiv + attn_term + return active, dense_ref + + def get_moe_stats(self): + stats = {} + aggregates = {} + counts = {} + for layer_idx, block in enumerate(self.transformer.h): + mlp = block.mlp + if isinstance(mlp, MoEFeedForward): + layer_stats = mlp.get_stats() + if not layer_stats: + continue + for key, value in layer_stats.items(): + stats[f"moe/layer{layer_idx}/{key}"] = value + aggregates.setdefault(key, 0.0) + counts[key] = counts.get(key, 0) + 1 + aggregates[key] += value + for key, total in aggregates.items(): + stats[f"moe/{key}_mean"] = total / counts[key] + return stats + def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() diff --git a/nanochat/muon.py b/nanochat/muon.py index d916103..bd4459f 100644 --- a/nanochat/muon.py +++ b/nanochat/muon.py @@ -72,7 +72,8 @@ class Muon(torch.optim.Optimizer): params: list[Tensor] = group["params"] for p in params: g = p.grad - assert g is not None + if g is None: + continue state = self.state[p] if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(g) @@ -128,8 +129,11 @@ class DistMuon(torch.optim.Optimizer): rank = dist.get_rank() world_size = dist.get_world_size() - # Ensure all grads exist - assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads" + # Ensure grads exist (experts can be unused, so synthesize zeros when needed) + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + p.grad = torch.zeros_like(p) # Kick off all the reduce scatter operations to average up the gradients across all ranks all_reduce_futures = [] diff --git a/scripts/base_train.py b/scripts/base_train.py index 914df17..2f7ccef 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -38,6 +38,13 @@ device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, i depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived max_seq_len = 2048 # max context length kv_head_mult = 1 # number of query heads that share a single key/value head (1 disables GQA) +moe_num_experts = 0 # routed experts per MoE layer (0 disables MoE) +moe_num_shared_experts = -1 # -1 => derive (defaults to 1 shared expert) +moe_experts_per_token = -1 # -1 => derive using Ling-style sparsity (≈1/32 active) +moe_expert_ffn_mult = -1.0 # -1 => derive from granularity target (defaults to 12) +dense_layers_before_moe = -1 # -1 => derive (≈10% of layers, min 1) before switching to MoE +moe_granularity_target = 12.0 # Ling guidance: target granularity per layer (2*d_model/d_expert) +moe_activation_denominator = 32 # derive top-k as num_experts / denominator (~3% activation) # Training horizon. Only one of these 3 will be used, in this order of precedence. num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable) @@ -105,6 +112,33 @@ num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here i assert kv_head_mult >= 1, "kv_head_mult must be >= 1" assert num_heads % kv_head_mult == 0, f"num_heads ({num_heads}) must be divisible by kv_head_mult ({kv_head_mult})" num_kv_heads = max(1, num_heads // kv_head_mult) +activation_denom = max(1, int(round(moe_activation_denominator))) +granularity_target = moe_granularity_target if moe_granularity_target > 0 else 12.0 +auto_dense_layers = max(1, num_layers // 10) if moe_num_experts > 0 else num_layers +if dense_layers_before_moe < 0: + dense_layers_before_moe = auto_dense_layers +dense_layers_before_moe = max(0, min(dense_layers_before_moe, num_layers)) +if moe_num_experts <= 0: + dense_layers_before_moe = num_layers + +if moe_num_experts > 0: + derived_top_k = max(1, round(moe_num_experts / activation_denom)) + moe_experts_per_token = moe_experts_per_token if moe_experts_per_token > 0 else derived_top_k + moe_num_shared_experts = moe_num_shared_experts if moe_num_shared_experts >= 0 else 1 + if moe_expert_ffn_mult <= 0: + moe_expert_ffn_mult = max(1e-6, 2.0 / granularity_target) + assert moe_experts_per_token > 0, "moe_experts_per_token must be > 0 when MoE is enabled" + assert moe_num_experts >= moe_experts_per_token, "moe_num_experts must be >= moe_experts_per_token" + assert moe_num_shared_experts >= 0, "moe_num_shared_experts must be >= 0" + assert moe_expert_ffn_mult > 0, "moe_expert_ffn_mult must be > 0" + moe_activation_ratio = moe_experts_per_token / (moe_num_experts + moe_num_shared_experts) + moe_granularity_actual = 2.0 / moe_expert_ffn_mult +else: + moe_num_shared_experts = 0 + moe_experts_per_token = 0 + moe_expert_ffn_mult = 4.0 if moe_expert_ffn_mult <= 0 else moe_expert_ffn_mult + moe_activation_ratio = 0.0 + moe_granularity_actual = 0.0 def _resolve_checkpoint_tag(tag, run_name, depth_value): if tag: return tag @@ -119,6 +153,18 @@ print0(f"model_dim: {model_dim}") print0(f"kv_head_mult: {kv_head_mult}") print0(f"num_heads: {num_heads}") print0(f"num_kv_heads: {num_kv_heads}") +if moe_num_experts > 0: + print0( + "MoE config: experts=%d shared=%d topk=%d granularity=%.1f (mult=%.3f) sparsity=%.2f%% dense_preface=%d" % ( + moe_num_experts, + moe_num_shared_experts, + moe_experts_per_token, + moe_granularity_actual, + moe_expert_ffn_mult, + moe_activation_ratio * 100, + dense_layers_before_moe, + ) + ) print0(f"Checkpoint tag: {model_tag}") # Optimizer / data / training length related hyperparameters @@ -132,18 +178,63 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") # ----------------------------------------------------------------------------- # Initialize the Model -model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim) +model_config_kwargs = dict( + sequence_len=max_seq_len, + vocab_size=vocab_size, + n_layer=num_layers, + n_head=num_heads, + n_kv_head=num_kv_heads, + n_embd=model_dim, + moe_num_experts=moe_num_experts, + moe_num_shared_experts=moe_num_shared_experts, + moe_experts_per_token=moe_experts_per_token, + moe_expert_ffn_mult=moe_expert_ffn_mult, + dense_layers_before_moe=dense_layers_before_moe, + moe_granularity_target=granularity_target, + moe_activation_denominator=activation_denom, +) with torch.device("meta"): model_config = GPTConfig(**model_config_kwargs) model = GPT(model_config) model.to_empty(device=device) model.init_weights() orig_model = model # original, uncompiled model, for saving raw model state_dict +dense_like_flops = model.estimate_flops() +active_flops_per_token, dense_ref_flops = model.estimate_moe_active_flops() +num_flops_per_token = active_flops_per_token model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through num_params = sum(p.numel() for p in model.parameters()) print0(f"Number of parameters: {num_params:,}") -num_flops_per_token = model.estimate_flops() -print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") +print0(f"Estimated FLOPs per token (dense-like): {dense_like_flops:e}") +if active_flops_per_token != dense_like_flops: + print0(f"Estimated FLOPs per token (MoE active): {active_flops_per_token:e}") + print0(f"Estimated FLOPs per token (dense reference): {dense_ref_flops:e}") + +user_config.update({ + "moe_num_experts": moe_num_experts, + "moe_num_shared_experts": moe_num_shared_experts, + "moe_experts_per_token": moe_experts_per_token, + "moe_expert_ffn_mult": moe_expert_ffn_mult, + "moe_dense_layers": dense_layers_before_moe, + "moe_activation_ratio": moe_activation_ratio, + "moe_granularity_actual": moe_granularity_actual, + "flops_per_token_dense_like": dense_like_flops, + "flops_per_token_moe_active": active_flops_per_token, + "flops_per_token_dense_reference": dense_ref_flops, +}) +if not use_dummy_wandb: + wandb_run.config.update({ + "moe_num_experts": moe_num_experts, + "moe_num_shared_experts": moe_num_shared_experts, + "moe_experts_per_token": moe_experts_per_token, + "moe_expert_ffn_mult": moe_expert_ffn_mult, + "moe_dense_layers": dense_layers_before_moe, + "moe_activation_ratio": moe_activation_ratio, + "moe_granularity_actual": moe_granularity_actual, + "flops_per_token_dense_like": dense_like_flops, + "flops_per_token_moe_active": active_flops_per_token, + "flops_per_token_dense_reference": dense_ref_flops, + }, allow_val_change=True) # Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order) assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0 @@ -164,6 +255,9 @@ total_tokens = total_batch_size * num_iterations print0(f"Total number of training tokens: {total_tokens:,}") print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") +if eval_every <= 0: + eval_every = max(1, num_iterations // 100) + print0(f"Auto-setting eval_every to {eval_every} (~1% of training)") sequences_per_step = max(1, total_batch_size // max_seq_len) checkpoint_every_steps = int(checkpoint_every_steps) @@ -357,7 +451,7 @@ for step in range(num_iterations + 1): total_training_time += dt # only count the time after the first 10 steps print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") if step % 100 == 0: - wandb_run.log({ + log_payload = { "step": step, "total_training_flops": flops_so_far, "total_training_time": total_training_time, @@ -369,7 +463,10 @@ for step in range(num_iterations + 1): "train/mfu": mfu, "train/total_tokens": total_tokens_seen, "train/total_sequences": total_sequences_seen, - }) + } + if hasattr(orig_model, "get_moe_stats"): + log_payload.update(orig_model.get_moe_stats()) + wandb_run.log(log_payload) # print a few more stats print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") @@ -382,7 +479,9 @@ get_report().log(section="Base model training", data=[ user_config, # CLI args { # stats about the training setup "Number of parameters": num_params, - "Number of FLOPs per token": f"{num_flops_per_token:e}", + "FLOPs per token (MoE active)": f"{num_flops_per_token:e}", + "FLOPs per token (dense-like)": f"{dense_like_flops:e}", + "FLOPs per token (dense reference)": f"{dense_ref_flops:e}", "Calculated number of iterations": num_iterations, "Number of training tokens": total_tokens, "Tokens : Params ratio": total_batch_size * num_iterations / num_params,