diff --git a/.gitignore b/.gitignore index 49f0686..e217989 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,6 @@ hf-export/**/* lm_eval_pretrain/ benchmark.md d*.png -loss*.png \ No newline at end of file +loss*.png + +wandb/ \ No newline at end of file diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 780f212..3cf000c 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -75,6 +75,11 @@ def build_model(checkpoint_dir, step, device, phase): - meta data saved during base model training """ assert phase in ["train", "eval"], f"Invalid phase: {phase}" + + # Allow callers to pass strings like "cuda" / "cuda:0". + if isinstance(device, str): + device = torch.device(device) + model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) if device.type in {"cpu", "mps"}: # Convert bfloat16 tensors to float for CPU inference @@ -85,6 +90,25 @@ def build_model(checkpoint_dir, step, device, phase): # Hack: fix torch compile issue, which prepends all keys with _orig_mod. model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} model_config_kwargs = meta_data["model_config"] + + # Backward-compat: older checkpoints may not store MoE fields (e.g. n_exp) in meta.json. + # If missing, infer from checkpoint tensor shapes to avoid load_state_dict size mismatches. + if "n_exp" not in model_config_kwargs: + inferred_n_exp = None + for k, v in model_data.items(): + if not isinstance(v, torch.Tensor): + continue + # Router gate weight: [n_exp, n_embd] + if k.endswith("mlp.router.w_g.weight") and v.ndim == 2: + inferred_n_exp = int(v.shape[0]) + break + # Expert FC: [n_exp, n_embd, 4*n_embd] + if k.endswith("mlp.experts.c_fc") and v.ndim == 3: + inferred_n_exp = int(v.shape[0]) + break + if inferred_n_exp is not None: + model_config_kwargs["n_exp"] = inferred_n_exp + log0(f"Inferred missing model_config.n_exp={inferred_n_exp} from checkpoint weights") log0(f"Building model with config: {model_config_kwargs}") model_config = GPTConfig(**model_config_kwargs) with torch.device("meta"): @@ -137,6 +161,9 @@ def find_last_step(checkpoint_dir): # convenience functions that take into account nanochat's directory structure def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None): + # Normalize for callers that pass strings (e.g. "cuda"). + if isinstance(device, str): + device = torch.device(device) if model_tag is None: # guess the model tag by defaulting to the largest model model_tag = find_largest_model(checkpoints_dir) diff --git a/nanochat/engine.py b/nanochat/engine.py index d749d94..39a6fdf 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -190,6 +190,13 @@ class Engine: self.model = model self.tokenizer = tokenizer # needed for tool use + @staticmethod + def _unwrap_logits(model_out): + """Some model variants return (logits, loss). 
Engine only needs logits.""" + if isinstance(model_out, (tuple, list)): + return model_out[0] + return model_out + @torch.inference_mode() def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42): """Same as generate, but does single prefill and then clones the KV cache.""" @@ -216,7 +223,7 @@ class Engine: **kv_model_kwargs, ) ids = torch.tensor([tokens], dtype=torch.long, device=device) - logits = self.model.forward(ids, kv_cache=kv_cache_prefill) + logits = self._unwrap_logits(self.model.forward(ids, kv_cache=kv_cache_prefill)) logits = logits[:, -1, :] next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) sampled_tokens = next_ids[:, 0].tolist() @@ -253,7 +260,7 @@ class Engine: first_iteration = False else: # Forward the model and get the next token for each row - logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size) + logits = self._unwrap_logits(self.model.forward(ids, kv_cache=kv_cache_decode)) # (B, T, vocab_size) logits = logits[:, -1, :] # (B, vocab_size) at last time step next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) sampled_tokens = next_ids[:, 0].tolist() diff --git a/nanochat/gpt.py b/nanochat/gpt.py index ee9e77e..97b5803 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -93,7 +93,7 @@ class CausalSelfAttention(nn.Module): self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size)) - def forward(self, x): + def forward(self, x, kv_cache=None, layer_idx: int | None = None): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim @@ -102,6 +102,30 @@ class CausalSelfAttention(nn.Module): q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + # Optional KV-cache path for fast autoregressive decoding in Engine. + # We only use the cache for decode-time single-token forward (T==1). + # For prefill (T>1), we still populate the cache but use normal causal attention. + if kv_cache is not None: + assert layer_idx is not None, "layer_idx is required when using kv_cache" + # Insert the new keys/values and get a view of the full cache so far. + full_k, full_v = kv_cache.insert_kv(layer_idx, k, v) # (B, nh, T_total, hs) + if T == 1: + # Query is length-1; no causal mask needed (no future keys are present). 
+ if self.flash: + y = torch.nn.functional.scaled_dot_product_attention( + q, full_k, full_v, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + ) + else: + att = (q @ full_k.transpose(-2, -1)) * (1.0 / math.sqrt(full_k.size(-1))) + att = F.softmax(att, dim=-1) + y = att @ full_v + y = y.transpose(1, 2).contiguous().view(B, T, C) + y = self.resid_dropout(self.c_proj(y)) + return y + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels @@ -378,8 +402,8 @@ class Block(nn.Module): else: self.mlp = MLP(config) - def forward(self, x): - x = x + self.attn(self.ln_1(x)) + def forward(self, x, kv_cache=None, layer_idx: int | None = None): + x = x + self.attn(self.ln_1(x), kv_cache=kv_cache, layer_idx=layer_idx) x = x + self.mlp(self.ln_2(x)) return x @@ -443,6 +467,9 @@ class GPT(nn.Module): n_params -= self.transformer.wpe.weight.numel() return n_params + def get_device(self): + return self.transformer.wte.weight.device + @torch.no_grad() def _init_weights(self, module): # optionally use switch transformer-style initialization @@ -585,18 +612,32 @@ class GPT(nn.Module): return optimizer - def forward(self, idx, targets=None, return_full_logits: bool = False): + def forward( + self, + idx, + targets=None, + return_full_logits: bool = False, + loss_reduction: str = 'mean', + return_moe_losses: bool = False, + kv_cache=None, + ): device = idx.device b, t = idx.size() assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" - pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) + + # When decoding with a KV cache, the position indices must continue from kv_cache.pos. + if kv_cache is not None and t == 1: + pos0 = int(kv_cache.get_pos()) + pos = torch.tensor([pos0], dtype=torch.long, device=device) + else: + pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) # forward the GPT model itself tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) x = self.transformer.drop(tok_emb + pos_emb) - for block in self.transformer.h: - x = block(x) + for layer_idx, block in enumerate(self.transformer.h): + x = block(x, kv_cache=kv_cache, layer_idx=layer_idx) x = self.transformer.ln_f(x) if targets is not None or return_full_logits: @@ -604,15 +645,47 @@ class GPT(nn.Module): logits = self.lm_head(x) if targets is not None: # if we are given some desired targets also calculate the loss - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + if loss_reduction not in {'mean', 'none'}: + raise ValueError(f"Unsupported loss_reduction: {loss_reduction}") + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + ignore_index=-1, + reduction=loss_reduction, + ) + if loss_reduction == 'none': + # Return token-level NLL (B, T). Do NOT add MoE auxiliary losses here. 
+ loss = loss.view(b, t) - # add the auxiliary load balancing loss and router z loss to the main loss - if self.config.n_exp > 1 and self.config.use_aux_loss: - loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss() - MANAGER.reset_aux_loss() - if self.config.n_exp > 1 and self.config.use_router_z_loss: - loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss() - MANAGER.reset_router_z_loss() + moe_aux = None + moe_z = None + if return_moe_losses and self.config.n_exp > 1: + # Return the *weighted* contributions that would normally be added + # to the scalar training loss. Caller decides how to scale/accumulate. + if self.config.use_aux_loss: + moe_aux = self.config.aux_loss_weight * MANAGER.aggregate_aux_loss() + else: + moe_aux = torch.zeros((), device=loss.device, dtype=loss.dtype) + if self.config.use_router_z_loss: + moe_z = self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss() + else: + moe_z = torch.zeros((), device=loss.device, dtype=loss.dtype) + + # Always reset to avoid cross-step accumulation. + if self.config.n_exp > 1: + MANAGER.reset_aux_loss() + MANAGER.reset_router_z_loss() + + if return_moe_losses: + return logits, loss, moe_aux, moe_z + else: + # add the auxiliary load balancing loss and router z loss to the main loss + if self.config.n_exp > 1 and self.config.use_aux_loss: + loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss() + MANAGER.reset_aux_loss() + if self.config.n_exp > 1 and self.config.use_router_z_loss: + loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss() + MANAGER.reset_router_z_loss() else: loss = None if self.config.n_exp > 1: diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index b0fec1e..ca4c42d 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -5,6 +5,15 @@ import math import torch import torch.distributed as dist + +def _unpack_xy(batch): + # Supports loaders that yield (x, y) or (x, y, *extras) + if isinstance(batch, (tuple, list)): + if len(batch) < 2: + raise ValueError(f"Expected batch to have at least 2 items (x, y), got {len(batch)}") + return batch[0], batch[1] + raise ValueError(f"Expected batch to be a tuple/list (x, y, ...), got {type(batch)}") + @torch.no_grad() def evaluate_bpb(model, batches, steps, token_bytes=None): """ @@ -37,7 +46,7 @@ def evaluate_bpb(model, batches, steps, token_bytes=None): batch_iter = iter(batches) for _ in range(steps): - x, y = next(batch_iter) + x, y = _unpack_xy(next(batch_iter)) # Model returns (logits, loss) tuple logits, loss = model(x, y) # Calculate per-token loss from logits for bpb calculation diff --git a/nanochat/manager.py b/nanochat/manager.py index 9dc7896..e28109b 100644 --- a/nanochat/manager.py +++ b/nanochat/manager.py @@ -7,6 +7,10 @@ class MOEManager: def __init__(self): self.aux_loss = [] self.router_z_loss = [] + # Cache the most recently aggregated sums for logging/debugging. + # These values persist across reset_* calls. 
+ self.last_aux_loss_sum = 0.0 + self.last_router_z_loss_sum = 0.0 def reset_aux_loss(self): self.aux_loss = [] @@ -21,10 +25,14 @@ class MOEManager: self.router_z_loss.append(loss) def aggregate_aux_loss(self): - return sum(self.aux_loss) + s = sum(self.aux_loss) + self.last_aux_loss_sum = s + return s def aggregate_router_z_loss(self): - return sum(self.router_z_loss) + s = sum(self.router_z_loss) + self.last_router_z_loss_sum = s + return s MANAGER = MOEManager() diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index cae2f0f..c0100a7 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -113,8 +113,9 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device) # Get the logits for the whole batch of conversations in parallel (efficiency win here) + # GPT.forward returns (logits, loss); we need full logits across the sequence for categorical eval. with torch.no_grad(): - logits = model(prompt_ids) # (B, T, V) + logits, _ = model(prompt_ids, return_full_logits=True) # (B, T, V) # Focus on the available answer on just the letters corresponding to choices # Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index 8ebbe82..b6f90f9 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -23,7 +23,7 @@ import wandb import torch import torch.distributed as dist -from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb +from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type from nanochat.checkpoint_manager import save_checkpoint, load_model from nanochat.engine import Engine from tasks.gsm8k import GSM8K @@ -32,6 +32,8 @@ from tasks.gsm8k import GSM8K run = "dummy" # wandb run name source = "sft" # mid|sft dtype = "bfloat16" +model_tag = None +device_type = "" # cuda|cpu|mps (empty => autodetect) device_batch_size = 8 # no forward pass will go above this to not OOM examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!) num_samples = 16 # number of samples per example (/question) @@ -41,12 +43,21 @@ top_k = 50 # TODO: try None? 
unembedding_lr = 0.004 embedding_lr = 0.2 matrix_lr = 0.02 +learning_rate = 9e-5 +betas = (0.9, 0.95) weight_decay = 0.0 init_lr_frac = 0.05 num_epochs = 1 # how many epochs of gsm8k to train on save_every = 60 # every how many steps to save the model eval_every = 60 # every how many steps to evaluate the model for val pass@k eval_examples = 400 # number of examples used for evaluating pass@k + +# Debug knobs for MoE loss components (defaults preserve existing behavior) +disable_aux_loss = False +disable_router_z_loss = False +override_aux_loss_weight = -1.0 # <0 means do not override +override_router_z_loss_weight = -1.0 # <0 means do not override + # now allow CLI to override the settings via the configurator lol config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file @@ -54,6 +65,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin # ----------------------------------------------------------------------------- # Init compute/precision +device_type = autodetect_device_type() if device_type == "" else device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 @@ -64,9 +76,26 @@ use_dummy_wandb = run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config) # Init model and tokenizer -model, tokenizer, meta = load_model(source, device, phase="eval") +model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag) engine = Engine(model, tokenizer) # for sampling rollouts +# Optional overrides for MoE auxiliary losses (useful when total loss plateaus) +if hasattr(model, "config"): + if disable_aux_loss and getattr(model.config, "n_exp", 1) > 1: + print0("Disabling MoE aux loss for this midtraining run") + model.config.use_aux_loss = False + if disable_router_z_loss and getattr(model.config, "n_exp", 1) > 1: + print0("Disabling MoE router z loss for this midtraining run") + model.config.use_router_z_loss = False + if override_aux_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1: + print0(f"Overriding MoE aux_loss_weight to {override_aux_loss_weight}") + model.config.aux_loss_weight = float(override_aux_loss_weight) + if override_router_z_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1: + print0(f"Overriding MoE router_z_loss_weight to {override_router_z_loss_weight}") + model.config.router_z_loss_weight = float(override_router_z_loss_weight) + +print0(f"MoE training loss is configured to use aux_loss: {getattr(model.config, 'use_aux_loss', False)} with weight {getattr(model.config, 'aux_loss_weight', 0.0)}, router_z_loss: {getattr(model.config, 'use_router_z_loss', False)} with weight {getattr(model.config, 'router_z_loss_weight', 0.0)}") + # ----------------------------------------------------------------------------- # Rollout / sampling generator loop that yields batches of examples for training @@ -188,18 +217,29 @@ def run_gsm8k_eval(task, tokenizer, engine, # Training loop # Init the optimizer -optimizers = model.setup_optimizers( - unembedding_lr=unembedding_lr, - embedding_lr=embedding_lr, - matrix_lr=matrix_lr, +adamw_optimizer = model.configure_optimizers( weight_decay=weight_decay, + 
learning_rate=learning_rate, + betas=betas, + device_type=device_type, ) - -# Set the initial learning rate as a fraction of the base learning rate +optimizers = [adamw_optimizer] for opt in optimizers: for group in opt.param_groups: group["lr"] = group["lr"] * init_lr_frac - group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later + group["initial_lr"] = group["lr"] +# optimizers = model.setup_optimizers( +# unembedding_lr=unembedding_lr, +# embedding_lr=embedding_lr, +# matrix_lr=matrix_lr, +# weight_decay=weight_decay, +# ) + +# # Set the initial learning rate as a fraction of the base learning rate +# for opt in optimizers: +# for group in opt.param_groups: +# group["lr"] = group["lr"] * init_lr_frac +# group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later # Learning rate scheduler: simple rampdown to zero over num_steps def get_lr_multiplier(it): @@ -241,6 +281,13 @@ for step in range(num_steps): # Forward/Backward on rollouts over multiple examples in the dataset rewards_list = [] sequence_lengths = [] + + # Track loss components for logging + pg_loss_list = [] + aux_contrib_list = [] + z_contrib_list = [] + total_loss_list = [] + for example_step in range(examples_per_rank): # Get one batch corresponding to one example in the training dataset sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator) @@ -258,7 +305,12 @@ for step in range(num_steps): advantages = advantages_all[b0:b1] # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate with autocast_ctx: - _, loss2d = model(inputs, targets, loss_reduction='none') + _, loss2d, aux_contrib, z_contrib = model( + inputs, + targets, + loss_reduction='none', + return_moe_losses=True, + ) logp = -loss2d.view_as(inputs) # (B, T) # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0. pg_obj = (logp * advantages.unsqueeze(-1)).sum() @@ -267,9 +319,28 @@ for step in range(num_steps): pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank) # Note, there is no need to add PPO ratio+clip because we are on policy # Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize) - loss = -pg_obj + pg_loss = -pg_obj + + # Add MoE routing regularizers as separate scalar terms. + # We scale by num_passes/examples_per_rank to match the pg_loss normalization. 
+ moe_scale = 1.0 / (num_passes * examples_per_rank) + aux_term = (aux_contrib if aux_contrib is not None else 0.0) * moe_scale + z_term = (z_contrib if z_contrib is not None else 0.0) * moe_scale + loss = pg_loss + aux_term + z_term + loss.backward() - print0(f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}") + + # For logging (detach to avoid autograd sync issues) + pg_loss_list.append(float(pg_loss.detach().item())) + aux_contrib_list.append(float((aux_term.detach().item()) if torch.is_tensor(aux_term) else aux_term)) + z_contrib_list.append(float((z_term.detach().item()) if torch.is_tensor(z_term) else z_term)) + total_loss_list.append(float(loss.detach().item())) + + print0( + f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | " + f"loss: {loss.item():.6f} | pg: {pg_loss.item():.6f} | aux: {float(aux_term):.6f} | z: {float(z_term):.6f} | " + f"Average reward: {rewards.mean().item()}" + ) # For logging rewards_list.append(rewards_all.mean().item()) sequence_lengths.extend(len(seq) for seq in sequences_all) @@ -277,6 +348,12 @@ for step in range(num_steps): # A bunch of logging for how the rollouts went this step mean_reward = sum(rewards_list) / len(rewards_list) mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths) + + mean_pg_loss = sum(pg_loss_list) / max(len(pg_loss_list), 1) + mean_aux = sum(aux_contrib_list) / max(len(aux_contrib_list), 1) + mean_z = sum(z_contrib_list) / max(len(z_contrib_list), 1) + mean_total_loss = sum(total_loss_list) / max(len(total_loss_list), 1) + if ddp: # aggregate across ranks mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device) mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device) @@ -284,11 +361,28 @@ for step in range(num_steps): dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG) mean_reward = mean_reward_tensor.item() mean_sequence_length = mean_sequence_length_tensor.item() + + mean_pg_loss_tensor = torch.tensor(mean_pg_loss, dtype=torch.float, device=device) + mean_aux_tensor = torch.tensor(mean_aux, dtype=torch.float, device=device) + mean_z_tensor = torch.tensor(mean_z, dtype=torch.float, device=device) + mean_total_loss_tensor = torch.tensor(mean_total_loss, dtype=torch.float, device=device) + dist.all_reduce(mean_pg_loss_tensor, op=dist.ReduceOp.AVG) + dist.all_reduce(mean_aux_tensor, op=dist.ReduceOp.AVG) + dist.all_reduce(mean_z_tensor, op=dist.ReduceOp.AVG) + dist.all_reduce(mean_total_loss_tensor, op=dist.ReduceOp.AVG) + mean_pg_loss = mean_pg_loss_tensor.item() + mean_aux = mean_aux_tensor.item() + mean_z = mean_z_tensor.item() + mean_total_loss = mean_total_loss_tensor.item() print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}") wandb_run.log({ "step": step, "reward": mean_reward, "sequence_length": mean_sequence_length, + "train/pg_loss": mean_pg_loss, + "train/aux_loss_contrib": mean_aux, + "train/router_z_loss_contrib": mean_z, + "train/total_loss": mean_total_loss, }) # Update the model parameters @@ -308,8 +402,17 @@ for step in range(num_steps): if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1): base_dir = get_base_dir() depth = model.config.n_layer - model_tag = f"d{depth}" # base the model tag on the depth of the base model - checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag) + # 
model_tag = f"d{depth}" # base the model tag on the depth of the base model + if disable_aux_loss: + aux_tag = "noaux" + else: + aux_tag = "aux" + if disable_router_z_loss: + z_tag = "noz" + else: + z_tag = "z" + output_dirname = f"d{depth}_{aux_tag}_{z_tag}_lr{learning_rate}_model{model_tag}" + checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname) model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer save_checkpoint( checkpoint_dir, diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index bbeb1f9..dc3a354 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -50,11 +50,20 @@ embedding_lr = 0.2 matrix_lr = 0.02 weight_decay = 0.0 init_lr_frac = 0.02 +learning_rate = 9e-5 # 5e-5 for d6 model, 9e-5 for d12 model +betas = (0.9, 0.95) # evaluation and logging there of eval_every = 100 eval_steps = 100 eval_metrics_every = 200 eval_metrics_max_problems = 1024 + +# Debug knobs for MoE loss components (defaults preserve existing behavior) +disable_aux_loss = False +disable_router_z_loss = False +override_aux_loss_weight = -1.0 # <0 means do not override +override_router_z_loss_weight = -1.0 # <0 means do not override + # now allow CLI to override the settings via the configurator lol config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file @@ -78,6 +87,23 @@ orig_model = model # original, uncompiled model # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs engine = Engine(model, tokenizer) # will be used for inline model evaluation only +# Optional overrides for MoE auxiliary losses (useful when total loss plateaus) +if hasattr(model, "config"): + if disable_aux_loss and getattr(model.config, "n_exp", 1) > 1: + print0("Disabling MoE aux loss for this midtraining run") + model.config.use_aux_loss = False + if disable_router_z_loss and getattr(model.config, "n_exp", 1) > 1: + print0("Disabling MoE router z loss for this midtraining run") + model.config.use_router_z_loss = False + if override_aux_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1: + print0(f"Overriding MoE aux_loss_weight to {override_aux_loss_weight}") + model.config.aux_loss_weight = float(override_aux_loss_weight) + if override_router_z_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1: + print0(f"Overriding MoE router_z_loss_weight to {override_router_z_loss_weight}") + model.config.router_z_loss_weight = float(override_router_z_loss_weight) + +print0(f"MoE training loss is configured to use aux_loss: {getattr(model.config, 'use_aux_loss', False)} with weight {getattr(model.config, 'aux_loss_weight', 0.0)}, router_z_loss: {getattr(model.config, 'use_router_z_loss', False)} with weight {getattr(model.config, 'router_z_loss_weight', 0.0)}") + # ----------------------------------------------------------------------------- # Task data mixture we'll train on identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl") @@ -97,6 +123,10 @@ val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don def sft_data_generator(dataset, batch_size): pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss + # Ensure we never feed sequences longer than the model block 
size. + # render_conversation returns a sequence that includes a BOS token, and we then create + # inputs/targets by shifting by 1, so cap to block_size+1 tokens. + max_tokens = int(getattr(model.config, "block_size", getattr(model.config, "sequence_len", 1024))) + 1 # prepares a list of tokenized conversations into a batch and yields def collate_and_yield(batch): nrows = len(batch) @@ -121,7 +151,7 @@ def sft_data_generator(dataset, batch_size): while True: for i in range(ddp_rank, len(dataset), ddp_world_size): doc = dataset[i] - ids, mask = tokenizer.render_conversation(doc) + ids, mask = tokenizer.render_conversation(doc, max_tokens=max_tokens) batch.append((ids, mask)) if len(batch) == batch_size: yield collate_and_yield(batch) @@ -144,18 +174,30 @@ build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_si # ----------------------------------------------------------------------------- # Initialize the Optimizer - -optimizers = model.setup_optimizers( - unembedding_lr=unembedding_lr, - embedding_lr=embedding_lr, - matrix_lr=matrix_lr, +# Initialize the Optimizer (AdamW for all parameters) - BEFORE DDP wrapping (matching nanoMoE) +adamw_optimizer = model.configure_optimizers( weight_decay=weight_decay, + learning_rate=learning_rate, + betas=betas, + device_type=device_type, ) -# Set the initial learning rate as a fraction of the base learning rate +optimizers = [adamw_optimizer] for opt in optimizers: for group in opt.param_groups: group["lr"] = group["lr"] * init_lr_frac - group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later + group["initial_lr"] = group["lr"] + +# optimizers = model.setup_optimizers( +# unembedding_lr=unembedding_lr, +# embedding_lr=embedding_lr, +# matrix_lr=matrix_lr, +# weight_decay=weight_decay, +# ) +# # Set the initial learning rate as a fraction of the base learning rate +# for opt in optimizers: +# for group in opt.param_groups: +# group["lr"] = group["lr"] * init_lr_frac +# group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later # ----------------------------------------------------------------------------- # Training loop @@ -179,7 +221,7 @@ for step in range(num_iterations): for _ in range(eval_steps): val_inputs, val_targets = next(val_iter) with torch.no_grad(), autocast_ctx: - loss = model(val_inputs, val_targets) + _, loss = model(val_inputs, val_targets) losses.append(loss) val_loss = torch.stack(losses).mean() # average over eval_steps if ddp: @@ -216,7 +258,7 @@ for step in range(num_iterations): for micro_step in range(grad_accum_steps): train_inputs, train_targets = next(train_iter) with autocast_ctx: - loss = model(train_inputs, train_targets) + _, loss = model(train_inputs, train_targets) train_loss = loss.detach() # for logging loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss.backward() # accumulate the gradient @@ -251,8 +293,18 @@ for step in range(num_iterations): if master_process: base_dir = get_base_dir() depth = model.config.n_layer - model_tag = f"d{depth}" # base the model tag on the depth of the base model - checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag) + # output-dirname = f"d{depth}" # base the model tag on the depth of the base model + if disable_aux_loss: + aux_tag = "noaux" + else: + aux_tag = "aux" + if disable_router_z_loss: + z_tag = "noz" + else: + z_tag = "z" + # output_dirname = f"d{depth}_{aux_tag}_{z_tag}_lr{learning_rate}_model{model_tag}" + output_dirname = 
f"d{depth}_lr{learning_rate}_init{init_lr_frac}_model{model_tag}" + checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname) model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer save_checkpoint( checkpoint_dir, diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 6c2b82f..3107970 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -15,12 +15,14 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import time import wandb import torch +import torch.nn.functional as F from contextlib import nullcontext from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type from nanochat.tokenizer import get_token_bytes from nanochat.checkpoint_manager import save_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.checkpoint_manager import load_model +from nanochat.manager import MANAGER import torch.distributed as dist from tasks.common import TaskMixture @@ -37,13 +39,24 @@ model_tag = None # model tag to load the model from (base model or midtrained mo step = None # step to load the model from (base model or midtrained model) dtype = "bfloat16" num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) +num_epochs = 1 # number of full passes over the midtraining dataset (only used if num_iterations < 0) max_seq_len = 2048 device_batch_size = 32 unembedding_lr = 0.004 embedding_lr = 0.2 matrix_lr = 0.02 init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate +learning_rate = 3e-4 +betas = (0.9, 0.95) weight_decay = 0.0 +warmup_ratio = 0.0 # LR warmup (ratio of total training progress in [0, 1]). 0 disables warmup. + +# Debug knobs for MoE loss components (defaults preserve existing behavior) +disable_aux_loss = False +disable_router_z_loss = False +override_aux_loss_weight = -1.0 # <0 means do not override +override_router_z_loss_weight = -1.0 # <0 means do not override + eval_every = 150 # -1 = disable eval_tokens = 20*524288 total_batch_size = 524288 @@ -67,13 +80,30 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi # Load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step) + +# Optional overrides for MoE auxiliary losses (useful when total loss plateaus) +if hasattr(model, "config"): + if disable_aux_loss and getattr(model.config, "n_exp", 1) > 1: + print0("Disabling MoE aux loss for this midtraining run") + model.config.use_aux_loss = False + if disable_router_z_loss and getattr(model.config, "n_exp", 1) > 1: + print0("Disabling MoE router z loss for this midtraining run") + model.config.use_router_z_loss = False + if override_aux_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1: + print0(f"Overriding MoE aux_loss_weight to {override_aux_loss_weight}") + model.config.aux_loss_weight = float(override_aux_loss_weight) + if override_router_z_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1: + print0(f"Overriding MoE router_z_loss_weight to {override_router_z_loss_weight}") + model.config.router_z_loss_weight = float(override_router_z_loss_weight) + +print0(f"MoE training loss is configured to use aux_loss: {getattr(model.config, 'use_aux_loss', False)} with weight {getattr(model.config, 'aux_loss_weight', 0.0)}, router_z_loss: {getattr(model.config, 'use_router_z_loss', False)} with weight {getattr(model.config, 'router_z_loss_weight', 
0.0)}") pretrain_batch_size = meta.get("device_batch_size", None) if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size: print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?") orig_model = model model = torch.compile(model, dynamic=False) depth = model.config.n_layer -num_flops_per_token = model.estimate_flops() +# num_flops_per_token = model.estimate_flops(max_seq_len) tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks assert total_batch_size % world_tokens_per_fwdbwd == 0 @@ -83,14 +113,26 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") token_bytes = get_token_bytes(device=device) -# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) -optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) -adamw_optimizer, muon_optimizer = optimizers -# Override the initial learning rate as a fraction of the base learning rate -for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["lr"] * init_lr_frac - group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later +# Sanity print: tokenizer ids must fit inside model vocab (esp. when vocab_size=50304 padded GPT-2) +print0(f"Model vocab_size: {model.config.vocab_size}") +print0(f"Tokenizer vocab_size: {tokenizer.get_vocab_size()}") + +# Initialize the Optimizer (AdamW for all parameters) - BEFORE DDP wrapping (matching nanoMoE) +adamw_optimizer = model.configure_optimizers( + weight_decay=weight_decay, + learning_rate=learning_rate, + betas=betas, + device_type=device_type, +) +optimizers = [adamw_optimizer] +# # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) +# optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) +# adamw_optimizer, muon_optimizer = optimizers +# # Override the initial learning rate as a fraction of the base learning rate +# for opt in optimizers: +# for group in opt.param_groups: +# group["lr"] = group["lr"] * init_lr_frac +# group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later # Midtraining data mixture and DataLoader base_dir = get_base_dir() @@ -114,14 +156,17 @@ val_dataset = TaskMixture([ # these two global variables and update them from within the data generator. 
last_step = False # we will toggle this to True when we reach the end of the dataset approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch +current_epoch = 1 # will go from 1 to num_epochs def mid_data_generator(split): - global last_step, approx_progress + global last_step, approx_progress, current_epoch assert split in {"train", "val"}, "split must be 'train' or 'val'" dataset = train_dataset if split == "train" else val_dataset dataset_size = len(dataset) assert dataset_size > 0 needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets token_buffer = deque() + # A lightweight resumable state dict (similar spirit to base_train.py) + dataloader_state_dict = {"split": split} # CUDA supports memory pinning for faster transfers between CPU and GPU: scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda")) cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents @@ -136,7 +181,13 @@ def mid_data_generator(split): if cursor >= dataset_size: cursor -= dataset_size # wrap around for another epoch if split == "train": - last_step = True # toggle last_step to True, which will terminate the training loop + # Track epochs (unless num_iterations explicitly caps steps) + if num_iterations < 0: + current_epoch += 1 + if current_epoch > num_epochs: + last_step = True # terminate after requested epochs + else: + last_step = True # legacy behavior when num_iterations is set elsewhere # Stopping condition to respect num_iterations, if given it += 1 if num_iterations > 0 and it >= num_iterations: @@ -146,14 +197,39 @@ def mid_data_generator(split): scratch[i] = token_buffer.popleft() inputs_cpu = scratch[:-1].to(dtype=torch.int32) targets_cpu = scratch[1:] + + # Early token-id range check on CPU to avoid opaque torch.compile CUDA OOB asserts. + # Only do this for a few batches to keep overhead minimal. + if it <= 5: + min_id = int(inputs_cpu.min().item()) + max_id = int(inputs_cpu.max().item()) + vocab_limit = int(model.config.vocab_size) + if not (0 <= min_id and max_id < vocab_limit): + raise ValueError( + f"Token id out of range: min={min_id}, max={max_id}, expected within [0, {vocab_limit}). " + f"Tokenizer vocab_size={int(tokenizer.get_vocab_size())}. " + "This usually means the tokenizer used for midtraining doesn't match the model vocab." + ) + inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True) targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True) if split == "train": if num_iterations > 0: approx_progress = it / num_iterations # calculate progress from the max number of iterations else: - approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset - yield inputs, targets + # progress across epochs, in [0, 1] + denom = max(float(num_epochs), 1.0) + approx_progress = min(((current_epoch - 1) + (cursor / dataset_size)) / denom, 1.0) + dataloader_state_dict.update({ + "cursor": int(cursor), + "it": int(it), + "current_epoch": int(current_epoch), + "last_step": bool(last_step), + "approx_progress": float(approx_progress), + # Keep the remaining buffered tokens for exact resume semantics. 
+ "token_buffer": list(token_buffer), + }) + yield inputs, targets, dataloader_state_dict train_loader = mid_data_generator("train") build_val_loader = lambda: mid_data_generator("val") @@ -161,8 +237,15 @@ progress = 0 # will go from 0 to 1 over the course of the epoch # Learning rate scheduler def get_lr_multiplier(progress): - # first 80% of training: no decay, then linearly ramp down to 0. - return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2 + # Warmup: linearly ramp from 0 -> 1 over the first `warmup_ratio` portion of training. + if warmup_ratio and warmup_ratio > 0: + warmup_mult = min(max(progress / warmup_ratio, 0.0), 1.0) + else: + warmup_mult = 1.0 + + # Decay: first 80% of training no decay, then linearly ramp down to 0. + decay_mult = 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2 + return warmup_mult * decay_mult # Momentum scheduler for Muon optimizer def get_muon_momentum(it): @@ -172,14 +255,16 @@ def get_muon_momentum(it): # ----------------------------------------------------------------------------- # Training loop -x, y = next(train_loader) # prefetch the very first batch of data +x, y, dataloader_state_dict = next(train_loader) # prefetch the very first batch of data min_val_bpb = float("inf") smooth_train_loss = 0 # EMA of training loss +smooth_train_ce_loss = 0 # EMA of CE loss ema_beta = 0.9 # EMA decay factor total_training_time = 0 # total wall-clock time of training +val_bpb = None # populated during evaluation (keep defined for checkpoint metadata) step = 0 while True: - flops_so_far = num_flops_per_token * total_batch_size * step + # flops_so_far = num_flops_per_token * total_batch_size * step # Synchronize last_step across all ranks to avoid hangs in the distributed setting if ddp: @@ -199,7 +284,7 @@ while True: min_val_bpb = val_bpb wandb_run.log({ "step": step, - "total_training_flops": flops_so_far, + # "total_training_flops": flops_so_far, "total_training_time": total_training_time, "val/bpb": val_bpb, }) @@ -207,25 +292,72 @@ while True: # save checkpoint at the end of the run (only on master process) if master_process and last_step and not dry_run: - output_dirname = f"d{depth}" # e.g. d12 + # output_dirname = f"d{depth}" # e.g. d12 + if disable_aux_loss: + aux_tag = "noaux" + else: + aux_tag = "aux" + if disable_router_z_loss: + z_tag = "noz" + else: + z_tag = "z" + # output_dirname = f"d{depth}_{aux_tag}_{z_tag}_lr{learning_rate}_model{model_tag}" + output_dirname = f"d{depth}_lr{learning_rate}_model{model_tag}" checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname) + + # Save metadata in the same shape as base_train.py for consistency. 
+ model_config_for_save = {} + for k in [ + # Core GPT config + "block_size", + "vocab_size", + "n_layer", + "n_head", + "n_kv_head", + "n_embd", + "dropout", + "bias", + # MoE config (if present) + "n_exp", + "top_k", + "use_aux_loss", + "use_router_z_loss", + "use_noisy_top_k", + "aux_loss_weight", + "router_z_loss_weight", + "train_capacity", + "eval_capacity", + "min_capacity", + "stride", + "use_switch_tfm_init", + "switch_tfm_init_scale", + "router_use_full_prec", + ]: + if hasattr(orig_model.config, k): + v = getattr(orig_model.config, k) + if isinstance(v, (int, float, bool, str)): + model_config_for_save[k] = v + save_checkpoint( checkpoint_dir, step, orig_model.state_dict(), - [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly + adamw_optimizer.state_dict(), # TODO: make sure saving across ranks is done correctly { "step": step, - "val_bpb": val_bpb, # loss at last step - "model_config": { - "sequence_len": max_seq_len, - "vocab_size": tokenizer.get_vocab_size(), - "n_layer": depth, - "n_head": model.config.n_head, - "n_kv_head": model.config.n_kv_head, - "n_embd": model.config.n_embd, - }, + "model_config": model_config_for_save, "user_config": user_config, # inputs to the training script + "device_batch_size": device_batch_size, + "max_seq_len": max_seq_len, + "loop_state": { + "min_val_bpb": min_val_bpb, + "smooth_train_loss": smooth_train_loss, + "smooth_train_ce_loss": smooth_train_ce_loss, + "total_training_time": total_training_time, + "progress": progress, + "current_epoch": int(current_epoch), + }, + "dataloader_state_dict": dataloader_state_dict, } ) @@ -237,24 +369,52 @@ while True: # evaluate the gradient synchronize() t0 = time.time() + total_loss_accum = 0.0 + ce_loss_accum = 0.0 + aux_loss_contrib_accum = 0.0 + router_z_loss_contrib_accum = 0.0 for micro_step in range(grad_accum_steps): with autocast_ctx: - loss = model(x, y) - train_loss = loss.detach() # for logging - loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here + logits, total_loss = model(x, y) # returns (logits, loss) + # Cross-entropy only (language modeling objective) + ce_loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + y.view(-1), + ignore_index=-1, + ) + # Cache logging values (average across micro-steps) + total_loss_accum += float(total_loss.detach().item()) + ce_loss_accum += float(ce_loss.detach().item()) + aux_sum = getattr(MANAGER, "last_aux_loss_sum", 0.0) + z_sum = getattr(MANAGER, "last_router_z_loss_sum", 0.0) + # Convert sums into the *weighted* contribution that is actually added to total_loss + if getattr(model.config, "n_exp", 1) > 1 and getattr(model.config, "use_aux_loss", False): + if torch.is_tensor(aux_sum): + aux_loss_contrib_accum += float(getattr(model.config, "aux_loss_weight", 0.0)) * aux_sum.detach().item() + else: + aux_loss_contrib_accum += float(getattr(model.config, "aux_loss_weight", 0.0)) * float(aux_sum) + if getattr(model.config, "n_exp", 1) > 1 and getattr(model.config, "use_router_z_loss", False): + if torch.is_tensor(z_sum): + router_z_loss_contrib_accum += float(getattr(model.config, "router_z_loss_weight", 0.0)) * z_sum.detach().item() + else: + router_z_loss_contrib_accum += float(getattr(model.config, "router_z_loss_weight", 0.0)) * float(z_sum) + + loss = total_loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss.backward() - x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward + x, y, 
dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward progress = max(progress, approx_progress) # only increase progress monotonically - # step the optimizers + + # micro-step averages for logging + train_total_loss = total_loss_accum / grad_accum_steps + train_ce_loss = ce_loss_accum / grad_accum_steps + train_aux_loss_contrib = aux_loss_contrib_accum / grad_accum_steps + train_router_z_loss_contrib = router_z_loss_contrib_accum / grad_accum_steps + # step the optimizer(s) lrm = get_lr_multiplier(progress) - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lrm - muon_momentum = get_muon_momentum(step) - for group in muon_optimizer.param_groups: - group["momentum"] = muon_momentum - for opt in optimizers: - opt.step() + current_lr = learning_rate * init_lr_frac * lrm + for group in adamw_optimizer.param_groups: + group["lr"] = current_lr + adamw_optimizer.step() model.zero_grad(set_to_none=True) synchronize() t1 = time.time() @@ -265,26 +425,37 @@ while True: step += 1 # logging - smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_total_loss # EMA the total loss + smooth_train_ce_loss = ema_beta * smooth_train_ce_loss + (1 - ema_beta) * train_ce_loss # EMA the CE loss debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA + debiased_smooth_ce_loss = smooth_train_ce_loss / (1 - ema_beta**(step + 1)) # debias the EMA pct_done = 100 * progress tok_per_sec = int(total_batch_size / dt) - flops_per_sec = num_flops_per_token * total_batch_size / dt + # flops_per_sec = num_flops_per_token * total_batch_size / dt promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity - mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + # mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % if step > 10: total_training_time += dt # only count the time after the first 10 steps - print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") + print0( + f"step {step:05d} ({pct_done:.2f}%) | " + f"loss: {debiased_smooth_loss:.6f} | ce: {debiased_smooth_ce_loss:.6f} | " + f"aux: {train_aux_loss_contrib:.6f} | z: {train_router_z_loss_contrib:.6f} | " + f"lr: {current_lr:.6g} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | total time: {total_training_time/60:.2f}m" + ) if step % 10 == 0: wandb_run.log({ "step": step, - "total_training_flops": flops_so_far, + # "total_training_flops": flops_so_far, "total_training_time": total_training_time, "train/loss": debiased_smooth_loss, + "train/ce_loss": debiased_smooth_ce_loss, + "train/aux_loss_contrib": train_aux_loss_contrib, + "train/router_z_loss_contrib": train_router_z_loss_contrib, + "train/lr": current_lr, "train/lrm": lrm, "train/dt": dt, "train/tok_per_sec": tok_per_sec, - "train/mfu": mfu, + # "train/mfu": mfu, }) # print a few more stats diff --git a/speedrun_moe.sh b/speedrun_moe.sh index baa4b72..4de7227 100644 --- a/speedrun_moe.sh +++ b/speedrun_moe.sh @@ -11,10 +11,10 @@ # WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh # Default intermediate artifacts directory is in ~/.cache/nanochat-moe -export USER="limh23" +export USER="" export 
OMP_NUM_THREADS=1 -export NANOCHAT_BASE_DIR="/thullms/$USER/.cache/nanochat-moe" -export NANOCHAT_DATA_DIR="/thullms/$USER/.cache/nanochat-moe-data" +export NANOCHAT_BASE_DIR="$USER/.cache/nanochat-moe" +export NANOCHAT_DATA_DIR="$USER/.cache/nanochat-moe/base_data" mkdir -p $NANOCHAT_BASE_DIR mkdir -p $NANOCHAT_DATA_DIR @@ -22,7 +22,7 @@ mkdir -p $NANOCHAT_DATA_DIR # Use tokenizer from nanochat (not nanochat-moe) # Create a symlink to nanochat's tokenizer directory if it doesn't exist -NANOCHAT_TOKENIZER_DIR="$HOME/.cache/nanochat/tokenizer" +NANOCHAT_TOKENIZER_DIR="$USER/.cache/nanochat-moe/tokenizer" MOE_TOKENIZER_DIR="$NANOCHAT_BASE_DIR/tokenizer" if [ -d "$NANOCHAT_TOKENIZER_DIR" ] && [ ! -e "$MOE_TOKENIZER_DIR" ]; then echo "Creating symlink to nanochat tokenizer: $MOE_TOKENIZER_DIR -> $NANOCHAT_TOKENIZER_DIR" @@ -82,21 +82,20 @@ fi # fi # export HF_ENDPOINT=https://hf-mirror.com -# # ----------------------------------------------------------------------------- -# # Python venv setup with uv +# ----------------------------------------------------------------------------- +# Python venv setup with uv -# # install uv (if not already installed) -# if ! command -v uv &> /dev/null; then -# pip3 install uv -i https://pypi.tuna.tsinghua.edu.cn/simple -# fi -# # create a .venv local virtual environment (if it doesn't exist) -# [ -d ".venv" ] || uv venv -# # install the repo dependencies with China mirror -# export UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple -# uv sync --extra gpu -# # activate venv so that `python` uses the project's venv instead of system python -cd $HOME/nanochat-MoE -source .venv/bin/activate +# install uv (if not already installed) +if ! command -v uv &> /dev/null; then + pip3 install uv -i https://pypi.tuna.tsinghua.edu.cn/simple +fi +# create a .venv local virtual environment (if it doesn't exist) +[ -d ".venv" ] || uv venv +# install the repo dependencies with China mirror +export UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple +uv sync --extra gpu +# activate venv so that `python` uses the project's venv instead of system python +source "${USER}/nanochat/.venv/bin/activate" # # ----------------------------------------------------------------------------- # wandb setup @@ -117,32 +116,32 @@ fi # # with a bunch of system info and a timestamp that marks the start of the run. # python -m nanochat.report reset -# # ----------------------------------------------------------------------------- -# # Tokenizer +# ----------------------------------------------------------------------------- +# Tokenizer -# # Install Rust / Cargo (if not already installed) -# if ! command -v rustc &> /dev/null; then -# curl --proto '=https' --tlsv1.2 -sSf https://rsproxy.cn/rustup-init.sh | sh -s -- -y -# source "$HOME/.cargo/env" -# fi +# Install Rust / Cargo (if not already installed) +if ! 
command -v rustc &> /dev/null; then + curl --proto '=https' --tlsv1.2 -sSf https://rsproxy.cn/rustup-init.sh | sh -s -- -y + source "$HOME/.cargo/env" +fi -# # Build the rustbpe Tokenizer -# uv run maturin develop --release --manifest-path rustbpe/Cargo.toml +# Build the rustbpe Tokenizer +uv run maturin develop --release --manifest-path rustbpe/Cargo.toml -# # Download the first ~2B characters of pretraining dataset -# # look at dev/repackage_data_reference.py for details on how this data was prepared -# # each data shard is ~250M chars -# # so we download 2e9 / 250e6 = 8 data shards at this point -# # each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk -# python -m nanochat.dataset -n 8 -# # Immediately also kick off downloading more shards in the background while tokenizer trains -# # See comment below for why 240 is the right number here -# python -m nanochat.dataset -n 240 & -# DATASET_DOWNLOAD_PID=$! -# # train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data -# python -m scripts.tok_train --max_chars=2000000000 -# # evaluate the tokenizer (report compression ratio etc.) -# python -m scripts.tok_eval +# Download the first ~2B characters of pretraining dataset +# look at dev/repackage_data_reference.py for details on how this data was prepared +# each data shard is ~250M chars +# so we download 2e9 / 250e6 = 8 data shards at this point +# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk +python -m nanochat.dataset -n 8 +# Immediately also kick off downloading more shards in the background while tokenizer trains +# See comment below for why 240 is the right number here +python -m nanochat.dataset -n 240 & +DATASET_DOWNLOAD_PID=$! +# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data +python -m scripts.tok_train --max_chars=2000000000 +# evaluate the tokenizer (report compression ratio etc.) +python -m scripts.tok_eval # ----------------------------------------------------------------------------- # Base model (pretraining) @@ -156,19 +155,108 @@ fi # echo "Waiting for dataset download to complete..." 
# wait $DATASET_DOWNLOAD_PID + MODEL_DIM=${MODEL_DIM:-384} GLOBAL_BS=${GLOBAL_BS:-480} MIN_LR=${MIN_LR:-6e-5} LEARNING_RATE=${LEARNING_RATE:-6e-4} DEPTH=${DEPTH:-${N_LAYER:-6}} MODEL_TAG=${MODEL_TAG:-d${DEPTH}_min_lr${MIN_LR}_max_lr${LEARNING_RATE}} +# # Number of processes/GPUs to use +# NPROC_PER_NODE=${NPROC_PER_NODE:-8} # Number of processes/GPUs to use -NPROC_PER_NODE=${NPROC_PER_NODE:-8} +# Auto-detect number of GPUs: prefer CUDA_VISIBLE_DEVICES, then nvidia-smi, then python torch fallback +if [ -n "${CUDA_VISIBLE_DEVICES:-}" ]; then + # Count entries in CUDA_VISIBLE_DEVICES (comma-separated list) + NPROC_PER_NODE=$(echo "$CUDA_VISIBLE_DEVICES" | awk -F',' '{print NF}') +else + if command -v nvidia-smi &>/dev/null; then + NPROC_PER_NODE=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) + else + if command -v python3 &>/dev/null; then + NPROC_PER_NODE=$(python3 - <<'PY' +import sys +try: + import torch + print(torch.cuda.device_count()) +except Exception: + print(0) +PY +) + else + NPROC_PER_NODE=1 + fi + fi +fi +# Ensure at least 1 +NPROC_PER_NODE=${NPROC_PER_NODE:-1} +if [ "$NPROC_PER_NODE" -lt 1 ]; then + NPROC_PER_NODE=1 +fi # Master port for distributed training (default: 29500) # Set this to avoid port conflicts when running multiple torchrun tasks simultaneously # Example: MASTER_PORT=29501 bash speedrun.sh MASTER_PORT=${MASTER_PORT:-29501} LOG_TAG=${LOG_TAG:-$(date +%Y%m%d_%H%M%S)} -LOG_FILE=${LOG_FILE:-$NANOCHAT_BASE_DIR/${MODEL_TAG}_${LOG_TAG}.log} +LOG_FILE=${LOG_FILE:-$NANOCHAT_BASE_DIR/log/${MODEL_TAG}_${LOG_TAG}.log} # # # pretrain the d20 model MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train >> "$LOG_FILE" 2>&1 + +# ----------------------------------------------------------------------------- +# Midtraining (teach the model conversation special tokens, tool use, multiple choice) + +# download 2.3MB of synthetic identity conversations to impart a personality to nanochat +# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it +# curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl + +# MID_LR=${MID_LR:-3e-4} +# LOAD_FROM_MODEL_TAG=${LOAD_FROM_MODEL_TAG:-d6_min_lr0.0002_max_lr0.002} +# WANDB_RUN=moe_mid_${MID_LR}_${LOAD_FROM_MODEL_TAG} + +# # run midtraining and eval the model +# MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN \ +# --model_tag=$LOAD_FROM_MODEL_TAG --learning_rate=$MID_LR \ +# --device_batch_size=8 --max_seq_len=1024 --total_batch_size=524288 +# MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN \ +# --model_tag=d12_e16_lr2e-4_double_dmodel \ +# --learning_rate=$MID_LR --num_epochs=1 \ +# --device_batch_size=8 --max_seq_len=1024 --total_batch_size=524288 + # --disable_aux_loss=True --disable_router_z_loss=True \ +# MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts_moe.chat_eval -- -i mid + +# ----------------------------------------------------------------------------- +# Supervised Finetuning (domain adaptation to each sequence all by itself per row) +SFT_LR=${SFT_LR:-9e-6} +LOAD_FROM_MID_MODEL_TAG=${LOAD_FROM_MID_MODEL_TAG:-d6_lr0.0003_modeld6_min_lr0.0002_max_lr0.002} +WANDB_RUN=moe_sft_${SFT_LR}_${LOAD_FROM_MID_MODEL_TAG} +# train sft and re-eval right away 
(should see a small bump) + +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN --learning_rate=$SFT_LR --init_lr_frac=1.0 \ + --model_tag=$LOAD_FROM_MID_MODEL_TAG + +# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN --learning_rate=3e-4 --init_lr_frac=1.0 --num_epochs=4 +# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft + +# chat with the model over CLI! Leave out the -p to chat interactively +# python -m scripts.chat_cli -p "Why is the sky blue?" + +# even better, chat with your model over a pretty WebUI ChatGPT style +# python -m scripts.chat_web +# ----------------------------------------------------------------------------- +# Reinforcement Learning. Optional, and currently only on GSM8K +# (optional) +RL_LR=${RL_LR:-9e-6} +LOAD_FROM_SFT_MODEL_TAG=${LOAD_FROM_SFT_MODEL_TAG:-d6_lr5e-05_init1.0_modeld6_lr0.0003_modeld6_min_lr0.0002_max_lr0.002} +WANDB_RUN=moe_rl_${RL_LR}_${LOAD_FROM_SFT_MODEL_TAG} + +# run reinforcement learning +MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN \ + --model_tag=$LOAD_FROM_SFT_MODEL_TAG + +# eval the RL model only on GSM8K +# MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K + +# ----------------------------------------------------------------------------- +# Generate the full report by putting together all the sections +# report.md is the output and will be copied to current directory for convenience +python -m nanochat.report generate \ No newline at end of file
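
# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the patch): how the new GPT.forward contract with
# loss_reduction='none' and return_moe_losses=True is meant to be consumed, mirroring the
# chat_rl.py training-loop changes above. The helper name, type hints, and the num_valid
# computation are illustrative assumptions, not code from the repository.

import torch

def rl_loss_sketch(
    model,
    inputs: torch.Tensor,      # (B, T) int token ids
    targets: torch.Tensor,     # (B, T) int targets; -1 marks ignored positions
    advantages: torch.Tensor,  # (B,) per-sample advantages
    num_passes: int,
    examples_per_rank: int,
) -> torch.Tensor:
    # Patched forward returns: logits, per-token NLL of shape (B, T), and the *weighted*
    # MoE aux / router-z contributions (None when the model is dense, i.e. n_exp == 1).
    _, loss2d, aux_contrib, z_contrib = model(
        inputs, targets, loss_reduction='none', return_moe_losses=True
    )
    logp = -loss2d.view_as(inputs)                    # (B, T) token log-probabilities
    pg_obj = (logp * advantages.unsqueeze(-1)).sum()  # on-policy REINFORCE objective
    num_valid = (targets != -1).sum().clamp(min=1)    # assumed: ignore_index=-1 tokens carry zero loss
    pg_loss = -pg_obj / (num_valid * num_passes * examples_per_rank)
    # Scale the MoE regularizers by 1/(num_passes * examples_per_rank) so they are
    # normalized the same way as the policy-gradient term across micro-batches.
    moe_scale = 1.0 / (num_passes * examples_per_rank)
    aux_term = (aux_contrib if aux_contrib is not None else 0.0) * moe_scale
    z_term = (z_contrib if z_contrib is not None else 0.0) * moe_scale
    return pg_loss + aux_term + z_term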