mirror of https://github.com/karpathy/nanochat.git
synced 2026-01-22 19:34:17 +00:00

midtraining, sft, rl scripts and the final version of the nanochat-MoE

This commit is contained in:
parent 9b05e7c625
commit 708385a0d2
.gitignore (vendored, 4 changes)
@@ -22,4 +22,6 @@ hf-export/**/*
 lm_eval_pretrain/
 benchmark.md
 d*.png
 loss*.png
+
+wandb/
nanochat/checkpoint_manager.py

@@ -75,6 +75,11 @@ def build_model(checkpoint_dir, step, device, phase):
     - meta data saved during base model training
     """
     assert phase in ["train", "eval"], f"Invalid phase: {phase}"
+
+    # Allow callers to pass strings like "cuda" / "cuda:0".
+    if isinstance(device, str):
+        device = torch.device(device)
+
     model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
     if device.type in {"cpu", "mps"}:
         # Convert bfloat16 tensors to float for CPU inference
@@ -85,6 +90,25 @@ def build_model(checkpoint_dir, step, device, phase):
     # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
     model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
     model_config_kwargs = meta_data["model_config"]
+
+    # Backward-compat: older checkpoints may not store MoE fields (e.g. n_exp) in meta.json.
+    # If missing, infer from checkpoint tensor shapes to avoid load_state_dict size mismatches.
+    if "n_exp" not in model_config_kwargs:
+        inferred_n_exp = None
+        for k, v in model_data.items():
+            if not isinstance(v, torch.Tensor):
+                continue
+            # Router gate weight: [n_exp, n_embd]
+            if k.endswith("mlp.router.w_g.weight") and v.ndim == 2:
+                inferred_n_exp = int(v.shape[0])
+                break
+            # Expert FC: [n_exp, n_embd, 4*n_embd]
+            if k.endswith("mlp.experts.c_fc") and v.ndim == 3:
+                inferred_n_exp = int(v.shape[0])
+                break
+        if inferred_n_exp is not None:
+            model_config_kwargs["n_exp"] = inferred_n_exp
+            log0(f"Inferred missing model_config.n_exp={inferred_n_exp} from checkpoint weights")
     log0(f"Building model with config: {model_config_kwargs}")
     model_config = GPTConfig(**model_config_kwargs)
     with torch.device("meta"):
@@ -137,6 +161,9 @@ def find_last_step(checkpoint_dir):
 # convenience functions that take into account nanochat's directory structure

 def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
+    # Normalize for callers that pass strings (e.g. "cuda").
+    if isinstance(device, str):
+        device = torch.device(device)
     if model_tag is None:
         # guess the model tag by defaulting to the largest model
         model_tag = find_largest_model(checkpoints_dir)
nanochat/engine.py

@@ -190,6 +190,13 @@ class Engine:
         self.model = model
         self.tokenizer = tokenizer # needed for tool use

+    @staticmethod
+    def _unwrap_logits(model_out):
+        """Some model variants return (logits, loss). Engine only needs logits."""
+        if isinstance(model_out, (tuple, list)):
+            return model_out[0]
+        return model_out
+
     @torch.inference_mode()
     def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
         """Same as generate, but does single prefill and then clones the KV cache."""
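A note on why the unwrap shim is needed: the MoE-flavored GPT.forward in this commit always returns a (logits, loss) tuple, while the original engine expected bare logits. A minimal sketch of the calling pattern the helper enables (greedy_next_token is a hypothetical name, not part of the commit):

    import torch

    def greedy_next_token(model, ids):
        out = model(ids)
        # Engine._unwrap_logits in miniature: accept logits or (logits, loss).
        logits = out[0] if isinstance(out, (tuple, list)) else out
        return torch.argmax(logits[:, -1, :], dim=-1)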
@@ -216,7 +223,7 @@ class Engine:
             **kv_model_kwargs,
         )
         ids = torch.tensor([tokens], dtype=torch.long, device=device)
-        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
+        logits = self._unwrap_logits(self.model.forward(ids, kv_cache=kv_cache_prefill))
         logits = logits[:, -1, :]
         next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
         sampled_tokens = next_ids[:, 0].tolist()
@@ -253,7 +260,7 @@ class Engine:
                 first_iteration = False
             else:
                 # Forward the model and get the next token for each row
-                logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
+                logits = self._unwrap_logits(self.model.forward(ids, kv_cache=kv_cache_decode)) # (B, T, vocab_size)
                 logits = logits[:, -1, :] # (B, vocab_size) at last time step
                 next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
                 sampled_tokens = next_ids[:, 0].tolist()
nanochat/gpt.py (103 changes)
@@ -93,7 +93,7 @@ class CausalSelfAttention(nn.Module):
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))

-    def forward(self, x):
+    def forward(self, x, kv_cache=None, layer_idx: int | None = None):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -102,6 +102,30 @@ class CausalSelfAttention(nn.Module):
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

+        # Optional KV-cache path for fast autoregressive decoding in Engine.
+        # We only use the cache for decode-time single-token forward (T==1).
+        # For prefill (T>1), we still populate the cache but use normal causal attention.
+        if kv_cache is not None:
+            assert layer_idx is not None, "layer_idx is required when using kv_cache"
+            # Insert the new keys/values and get a view of the full cache so far.
+            full_k, full_v = kv_cache.insert_kv(layer_idx, k, v) # (B, nh, T_total, hs)
+            if T == 1:
+                # Query is length-1; no causal mask needed (no future keys are present).
+                if self.flash:
+                    y = torch.nn.functional.scaled_dot_product_attention(
+                        q, full_k, full_v,
+                        attn_mask=None,
+                        dropout_p=0.0,
+                        is_causal=False,
+                    )
+                else:
+                    att = (q @ full_k.transpose(-2, -1)) * (1.0 / math.sqrt(full_k.size(-1)))
+                    att = F.softmax(att, dim=-1)
+                    y = att @ full_v
+                y = y.transpose(1, 2).contiguous().view(B, T, C)
+                y = self.resid_dropout(self.c_proj(y))
+                return y
+
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
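The cache object itself is not part of this hunk; the attention path only assumes it exposes insert_kv(layer_idx, k, v) returning the full (k, v) so far, plus a get_pos() used later in GPT.forward. A minimal sketch of a cache satisfying that interface (all shapes and fields here are assumptions, not the repo's actual KVCache):

    import torch

    class KVCacheSketch:
        def __init__(self, n_layer, batch_size, n_head, max_seq_len, head_dim, device, dtype):
            shape = (n_layer, batch_size, n_head, max_seq_len, head_dim)
            self.k = torch.zeros(shape, device=device, dtype=dtype)
            self.v = torch.zeros(shape, device=device, dtype=dtype)
            self.pos = 0  # tokens cached so far, shared by all layers

        def insert_kv(self, layer_idx, k, v):
            T = k.size(2)
            end = self.pos + T
            self.k[layer_idx, :, :, self.pos:end] = k
            self.v[layer_idx, :, :, self.pos:end] = v
            if layer_idx == self.k.size(0) - 1:  # advance once per full forward pass
                self.pos = end
            return self.k[layer_idx, :, :, :end], self.v[layer_idx, :, :, :end]

        def get_pos(self):
            return self.pos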
@@ -378,8 +402,8 @@ class Block(nn.Module):
         else:
             self.mlp = MLP(config)

-    def forward(self, x):
-        x = x + self.attn(self.ln_1(x))
+    def forward(self, x, kv_cache=None, layer_idx: int | None = None):
+        x = x + self.attn(self.ln_1(x), kv_cache=kv_cache, layer_idx=layer_idx)
         x = x + self.mlp(self.ln_2(x))
         return x

@@ -443,6 +467,9 @@ class GPT(nn.Module):
             n_params -= self.transformer.wpe.weight.numel()
         return n_params

+    def get_device(self):
+        return self.transformer.wte.weight.device
+
     @torch.no_grad()
     def _init_weights(self, module):
         # optionally use switch transformer-style initialization
@@ -585,18 +612,32 @@ class GPT(nn.Module):
         return optimizer


-    def forward(self, idx, targets=None, return_full_logits: bool = False):
+    def forward(
+        self,
+        idx,
+        targets=None,
+        return_full_logits: bool = False,
+        loss_reduction: str = 'mean',
+        return_moe_losses: bool = False,
+        kv_cache=None,
+    ):
         device = idx.device
         b, t = idx.size()
         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
-        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
+
+        # When decoding with a KV cache, the position indices must continue from kv_cache.pos.
+        if kv_cache is not None and t == 1:
+            pos0 = int(kv_cache.get_pos())
+            pos = torch.tensor([pos0], dtype=torch.long, device=device)
+        else:
+            pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

         # forward the GPT model itself
         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
         x = self.transformer.drop(tok_emb + pos_emb)
-        for block in self.transformer.h:
-            x = block(x)
+        for layer_idx, block in enumerate(self.transformer.h):
+            x = block(x, kv_cache=kv_cache, layer_idx=layer_idx)
         x = self.transformer.ln_f(x)

         if targets is not None or return_full_logits:
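Putting the pieces together, a rough greedy decode loop against this forward (a sketch assuming the cache interface above and that forward returns (logits, loss); not code from this commit):

    import torch

    @torch.inference_mode()
    def greedy_decode(model, prompt_ids, max_new_tokens, kv_cache):
        logits, _ = model(prompt_ids, kv_cache=kv_cache)  # prefill: T > 1, causal attention
        out = prompt_ids[0].tolist()
        for _ in range(max_new_tokens):
            next_id = int(torch.argmax(logits[:, -1, :], dim=-1))
            out.append(next_id)
            ids = torch.tensor([[next_id]], dtype=torch.long, device=prompt_ids.device)
            logits, _ = model(ids, kv_cache=kv_cache)  # decode: T == 1, pos continues from get_pos()
        return out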
@@ -604,15 +645,47 @@ class GPT(nn.Module):
             logits = self.lm_head(x)
             if targets is not None:
                 # if we are given some desired targets also calculate the loss
-                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+                if loss_reduction not in {'mean', 'none'}:
+                    raise ValueError(f"Unsupported loss_reduction: {loss_reduction}")
+                loss = F.cross_entropy(
+                    logits.view(-1, logits.size(-1)),
+                    targets.view(-1),
+                    ignore_index=-1,
+                    reduction=loss_reduction,
+                )
+                if loss_reduction == 'none':
+                    # Return token-level NLL (B, T). Do NOT add MoE auxiliary losses here.
+                    loss = loss.view(b, t)

-                # add the auxiliary load balancing loss and router z loss to the main loss
-                if self.config.n_exp > 1 and self.config.use_aux_loss:
-                    loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
-                    MANAGER.reset_aux_loss()
-                if self.config.n_exp > 1 and self.config.use_router_z_loss:
-                    loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
-                    MANAGER.reset_router_z_loss()
+                moe_aux = None
+                moe_z = None
+                if return_moe_losses and self.config.n_exp > 1:
+                    # Return the *weighted* contributions that would normally be added
+                    # to the scalar training loss. Caller decides how to scale/accumulate.
+                    if self.config.use_aux_loss:
+                        moe_aux = self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
+                    else:
+                        moe_aux = torch.zeros((), device=loss.device, dtype=loss.dtype)
+                    if self.config.use_router_z_loss:
+                        moe_z = self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
+                    else:
+                        moe_z = torch.zeros((), device=loss.device, dtype=loss.dtype)
+
+                    # Always reset to avoid cross-step accumulation.
+                    if self.config.n_exp > 1:
+                        MANAGER.reset_aux_loss()
+                        MANAGER.reset_router_z_loss()
+
+                if return_moe_losses:
+                    return logits, loss, moe_aux, moe_z
+                else:
+                    # add the auxiliary load balancing loss and router z loss to the main loss
+                    if self.config.n_exp > 1 and self.config.use_aux_loss:
+                        loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
+                        MANAGER.reset_aux_loss()
+                    if self.config.n_exp > 1 and self.config.use_router_z_loss:
+                        loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
+                        MANAGER.reset_router_z_loss()
         else:
             loss = None
             if self.config.n_exp > 1:
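For reference, the two calling conventions the widened signature supports, as exercised later in this commit by the SFT/midtraining scripts and the RL script respectively:

    # Scalar training call: MoE aux/z terms are folded into the returned loss.
    logits, loss = model(x, y)

    # RL-style call: per-token NLL plus the weighted MoE terms, kept separate.
    logits, loss2d, moe_aux, moe_z = model(
        inputs, targets, loss_reduction='none', return_moe_losses=True,
    )
    logp = -loss2d  # (B, T) token log-probabilities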
nanochat/loss_eval.py

@@ -5,6 +5,15 @@ import math
 import torch
 import torch.distributed as dist

+
+def _unpack_xy(batch):
+    # Supports loaders that yield (x, y) or (x, y, *extras)
+    if isinstance(batch, (tuple, list)):
+        if len(batch) < 2:
+            raise ValueError(f"Expected batch to have at least 2 items (x, y), got {len(batch)}")
+        return batch[0], batch[1]
+    raise ValueError(f"Expected batch to be a tuple/list (x, y, ...), got {type(batch)}")
+
 @torch.no_grad()
 def evaluate_bpb(model, batches, steps, token_bytes=None):
     """
@@ -37,7 +46,7 @@ def evaluate_bpb(model, batches, steps, token_bytes=None):
     batch_iter = iter(batches)

     for _ in range(steps):
-        x, y = next(batch_iter)
+        x, y = _unpack_xy(next(batch_iter))
         # Model returns (logits, loss) tuple
         logits, loss = model(x, y)
         # Calculate per-token loss from logits for bpb calculation
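_unpack_xy is what lets evaluate_bpb keep working after the midtraining loader (later in this commit) starts yielding a third item; both shapes now pass through the same code:

    x, y = _unpack_xy((inputs, targets))                        # classic (x, y) loader
    x, y = _unpack_xy((inputs, targets, dataloader_state_dict)) # loader with resume state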
nanochat/manager.py

@@ -7,6 +7,10 @@ class MOEManager:
     def __init__(self):
         self.aux_loss = []
         self.router_z_loss = []
+        # Cache the most recently aggregated sums for logging/debugging.
+        # These values persist across reset_* calls.
+        self.last_aux_loss_sum = 0.0
+        self.last_router_z_loss_sum = 0.0

     def reset_aux_loss(self):
         self.aux_loss = []
@@ -21,10 +25,14 @@ class MOEManager:
         self.router_z_loss.append(loss)

     def aggregate_aux_loss(self):
-        return sum(self.aux_loss)
+        s = sum(self.aux_loss)
+        self.last_aux_loss_sum = s
+        return s

     def aggregate_router_z_loss(self):
-        return sum(self.router_z_loss)
+        s = sum(self.router_z_loss)
+        self.last_router_z_loss_sum = s
+        return s

 MANAGER = MOEManager()
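The intended per-step lifecycle implied by this change: router layers append raw losses during forward, the loss path aggregates (now also caching the sum) and resets, and logging reads the cached value afterwards. A sketch (the lists are real attributes per the hunk above; how layers append is assumed):

    MANAGER.aux_loss.append(layer_aux)     # done inside each MoE router during forward
    total = MANAGER.aggregate_aux_loss()   # sums and caches into last_aux_loss_sum
    MANAGER.reset_aux_loss()               # clears the list for the next step
    logged = MANAGER.last_aux_loss_sum     # survives the reset, safe for logging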
scripts/chat_eval.py

@@ -113,8 +113,9 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
     prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)

     # Get the logits for the whole batch of conversations in parallel (efficiency win here)
+    # GPT.forward returns (logits, loss); we need full logits across the sequence for categorical eval.
     with torch.no_grad():
-        logits = model(prompt_ids) # (B, T, V)
+        logits, _ = model(prompt_ids, return_full_logits=True) # (B, T, V)

     # Focus on just the letters corresponding to choices
     # Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters
scripts/chat_rl.py

@@ -23,7 +23,7 @@ import wandb
 import torch
 import torch.distributed as dist

-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
 from nanochat.checkpoint_manager import save_checkpoint, load_model
 from nanochat.engine import Engine
 from tasks.gsm8k import GSM8K
@@ -32,6 +32,8 @@ from tasks.gsm8k import GSM8K
 run = "dummy" # wandb run name
 source = "sft" # mid|sft
 dtype = "bfloat16"
+model_tag = None
+device_type = "" # cuda|cpu|mps (empty => autodetect)
 device_batch_size = 8 # no forward pass will go above this to not OOM
 examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
 num_samples = 16 # number of samples per example (/question)
@@ -41,12 +43,21 @@ top_k = 50 # TODO: try None?
-unembedding_lr = 0.004
-embedding_lr = 0.2
-matrix_lr = 0.02
+learning_rate = 9e-5
+betas = (0.9, 0.95)
 weight_decay = 0.0
 init_lr_frac = 0.05
 num_epochs = 1 # how many epochs of gsm8k to train on
 save_every = 60 # every how many steps to save the model
 eval_every = 60 # every how many steps to evaluate the model for val pass@k
 eval_examples = 400 # number of examples used for evaluating pass@k

+# Debug knobs for MoE loss components (defaults preserve existing behavior)
+disable_aux_loss = False
+disable_router_z_loss = False
+override_aux_loss_weight = -1.0 # <0 means do not override
+override_router_z_loss_weight = -1.0 # <0 means do not override
+
 # now allow CLI to override the settings via the configurator lol
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
@@ -54,6 +65,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # -----------------------------------------------------------------------------

 # Init compute/precision
+device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
 dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
@@ -64,9 +76,26 @@ use_dummy_wandb = run == "dummy" or not master_process
 wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)

 # Init model and tokenizer
-model, tokenizer, meta = load_model(source, device, phase="eval")
+model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag)
 engine = Engine(model, tokenizer) # for sampling rollouts

+# Optional overrides for MoE auxiliary losses (useful when total loss plateaus)
+if hasattr(model, "config"):
+    if disable_aux_loss and getattr(model.config, "n_exp", 1) > 1:
+        print0("Disabling MoE aux loss for this midtraining run")
+        model.config.use_aux_loss = False
+    if disable_router_z_loss and getattr(model.config, "n_exp", 1) > 1:
+        print0("Disabling MoE router z loss for this midtraining run")
+        model.config.use_router_z_loss = False
+    if override_aux_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1:
+        print0(f"Overriding MoE aux_loss_weight to {override_aux_loss_weight}")
+        model.config.aux_loss_weight = float(override_aux_loss_weight)
+    if override_router_z_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1:
+        print0(f"Overriding MoE router_z_loss_weight to {override_router_z_loss_weight}")
+        model.config.router_z_loss_weight = float(override_router_z_loss_weight)
+
+print0(f"MoE training loss is configured to use aux_loss: {getattr(model.config, 'use_aux_loss', False)} with weight {getattr(model.config, 'aux_loss_weight', 0.0)}, router_z_loss: {getattr(model.config, 'use_router_z_loss', False)} with weight {getattr(model.config, 'router_z_loss_weight', 0.0)}")
+
 # -----------------------------------------------------------------------------
 # Rollout / sampling generator loop that yields batches of examples for training
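Since these knobs go through the configurator like every other scalar setting, they can be flipped per run from the command line; a hypothetical invocation (module path assumed; flag syntax per the key=value convention used elsewhere in the repo):

    python -m scripts.chat_rl --disable_aux_loss=True --override_router_z_loss_weight=0.0005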
@@ -188,18 +217,29 @@ def run_gsm8k_eval(task, tokenizer, engine,
 # Training loop

 # Init the optimizer
-optimizers = model.setup_optimizers(
-    unembedding_lr=unembedding_lr,
-    embedding_lr=embedding_lr,
-    matrix_lr=matrix_lr,
+adamw_optimizer = model.configure_optimizers(
     weight_decay=weight_decay,
+    learning_rate=learning_rate,
+    betas=betas,
+    device_type=device_type,
 )

 # Set the initial learning rate as a fraction of the base learning rate
+optimizers = [adamw_optimizer]
 for opt in optimizers:
     for group in opt.param_groups:
         group["lr"] = group["lr"] * init_lr_frac
-        group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
+        group["initial_lr"] = group["lr"]
+# optimizers = model.setup_optimizers(
+#     unembedding_lr=unembedding_lr,
+#     embedding_lr=embedding_lr,
+#     matrix_lr=matrix_lr,
+#     weight_decay=weight_decay,
+# )
+
+# # Set the initial learning rate as a fraction of the base learning rate
+# for opt in optimizers:
+#     for group in opt.param_groups:
+#         group["lr"] = group["lr"] * init_lr_frac
+#         group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later

 # Learning rate scheduler: simple rampdown to zero over num_steps
 def get_lr_multiplier(it):
@@ -241,6 +281,13 @@ for step in range(num_steps):
     # Forward/Backward on rollouts over multiple examples in the dataset
     rewards_list = []
     sequence_lengths = []
+
+    # Track loss components for logging
+    pg_loss_list = []
+    aux_contrib_list = []
+    z_contrib_list = []
+    total_loss_list = []
+
     for example_step in range(examples_per_rank):
         # Get one batch corresponding to one example in the training dataset
         sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
@@ -258,7 +305,12 @@ for step in range(num_steps):
             advantages = advantages_all[b0:b1]
             # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
             with autocast_ctx:
-                _, loss2d = model(inputs, targets, loss_reduction='none')
+                _, loss2d, aux_contrib, z_contrib = model(
+                    inputs,
+                    targets,
+                    loss_reduction='none',
+                    return_moe_losses=True,
+                )
             logp = -loss2d.view_as(inputs) # (B, T)
             # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
             pg_obj = (logp * advantages.unsqueeze(-1)).sum()
@@ -267,9 +319,28 @@ for step in range(num_steps):
             pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
             # Note, there is no need to add PPO ratio+clip because we are on policy
             # Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize)
-            loss = -pg_obj
+            pg_loss = -pg_obj
+
+            # Add MoE routing regularizers as separate scalar terms.
+            # We scale by num_passes/examples_per_rank to match the pg_loss normalization.
+            moe_scale = 1.0 / (num_passes * examples_per_rank)
+            aux_term = (aux_contrib if aux_contrib is not None else 0.0) * moe_scale
+            z_term = (z_contrib if z_contrib is not None else 0.0) * moe_scale
+            loss = pg_loss + aux_term + z_term
+
             loss.backward()
-            print0(f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}")
+
+            # For logging (detach to avoid autograd sync issues)
+            pg_loss_list.append(float(pg_loss.detach().item()))
+            aux_contrib_list.append(float((aux_term.detach().item()) if torch.is_tensor(aux_term) else aux_term))
+            z_contrib_list.append(float((z_term.detach().item()) if torch.is_tensor(z_term) else z_term))
+            total_loss_list.append(float(loss.detach().item()))
+
+            print0(
+                f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | "
+                f"loss: {loss.item():.6f} | pg: {pg_loss.item():.6f} | aux: {float(aux_term):.6f} | z: {float(z_term):.6f} | "
+                f"Average reward: {rewards.mean().item()}"
+            )
         # For logging
         rewards_list.append(rewards_all.mean().item())
         sequence_lengths.extend(len(seq) for seq in sequences_all)
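The moe_scale choice mirrors how pg_obj is normalized: backward() runs once per (example, pass) on this rank, so dividing each MoE term by num_passes * examples_per_rank makes the accumulated gradient equal the per-step average rather than a sum that grows with the number of micro-passes. In symbols, over G = num_passes * examples_per_rank backward calls:

    # sum over G calls of (aux_contrib_i / G)  ==  mean_i(aux_contrib_i)
    # matching pg_obj, which is already divided by (num_valid * num_passes * examples_per_rank)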
@@ -277,6 +348,12 @@ for step in range(num_steps):
     # A bunch of logging for how the rollouts went this step
     mean_reward = sum(rewards_list) / len(rewards_list)
     mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths)
+
+    mean_pg_loss = sum(pg_loss_list) / max(len(pg_loss_list), 1)
+    mean_aux = sum(aux_contrib_list) / max(len(aux_contrib_list), 1)
+    mean_z = sum(z_contrib_list) / max(len(z_contrib_list), 1)
+    mean_total_loss = sum(total_loss_list) / max(len(total_loss_list), 1)
+
     if ddp: # aggregate across ranks
         mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device)
         mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device)
@@ -284,11 +361,28 @@ for step in range(num_steps):
         dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG)
         mean_reward = mean_reward_tensor.item()
         mean_sequence_length = mean_sequence_length_tensor.item()
+
+        mean_pg_loss_tensor = torch.tensor(mean_pg_loss, dtype=torch.float, device=device)
+        mean_aux_tensor = torch.tensor(mean_aux, dtype=torch.float, device=device)
+        mean_z_tensor = torch.tensor(mean_z, dtype=torch.float, device=device)
+        mean_total_loss_tensor = torch.tensor(mean_total_loss, dtype=torch.float, device=device)
+        dist.all_reduce(mean_pg_loss_tensor, op=dist.ReduceOp.AVG)
+        dist.all_reduce(mean_aux_tensor, op=dist.ReduceOp.AVG)
+        dist.all_reduce(mean_z_tensor, op=dist.ReduceOp.AVG)
+        dist.all_reduce(mean_total_loss_tensor, op=dist.ReduceOp.AVG)
+        mean_pg_loss = mean_pg_loss_tensor.item()
+        mean_aux = mean_aux_tensor.item()
+        mean_z = mean_z_tensor.item()
+        mean_total_loss = mean_total_loss_tensor.item()
     print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}")
     wandb_run.log({
         "step": step,
         "reward": mean_reward,
         "sequence_length": mean_sequence_length,
+        "train/pg_loss": mean_pg_loss,
+        "train/aux_loss_contrib": mean_aux,
+        "train/router_z_loss_contrib": mean_z,
+        "train/total_loss": mean_total_loss,
     })

     # Update the model parameters
@@ -308,8 +402,17 @@ for step in range(num_steps):
     if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
         base_dir = get_base_dir()
         depth = model.config.n_layer
-        model_tag = f"d{depth}" # base the model tag on the depth of the base model
-        checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
+        # model_tag = f"d{depth}" # base the model tag on the depth of the base model
+        if disable_aux_loss:
+            aux_tag = "noaux"
+        else:
+            aux_tag = "aux"
+        if disable_router_z_loss:
+            z_tag = "noz"
+        else:
+            z_tag = "z"
+        output_dirname = f"d{depth}_{aux_tag}_{z_tag}_lr{learning_rate}_model{model_tag}"
+        checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
         model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
         save_checkpoint(
             checkpoint_dir,
scripts/chat_sft.py

@@ -50,11 +50,20 @@ embedding_lr = 0.2
 matrix_lr = 0.02
 weight_decay = 0.0
 init_lr_frac = 0.02
+learning_rate = 9e-5 # 5e-5 for d6 model, 9e-5 for d12 model
+betas = (0.9, 0.95)
 # evaluation and logging there of
 eval_every = 100
 eval_steps = 100
 eval_metrics_every = 200
 eval_metrics_max_problems = 1024

+# Debug knobs for MoE loss components (defaults preserve existing behavior)
+disable_aux_loss = False
+disable_router_z_loss = False
+override_aux_loss_weight = -1.0 # <0 means do not override
+override_router_z_loss_weight = -1.0 # <0 means do not override
+
 # now allow CLI to override the settings via the configurator lol
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
@@ -78,6 +87,23 @@ orig_model = model # original, uncompiled model
 # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
 engine = Engine(model, tokenizer) # will be used for inline model evaluation only

+# Optional overrides for MoE auxiliary losses (useful when total loss plateaus)
+if hasattr(model, "config"):
+    if disable_aux_loss and getattr(model.config, "n_exp", 1) > 1:
+        print0("Disabling MoE aux loss for this midtraining run")
+        model.config.use_aux_loss = False
+    if disable_router_z_loss and getattr(model.config, "n_exp", 1) > 1:
+        print0("Disabling MoE router z loss for this midtraining run")
+        model.config.use_router_z_loss = False
+    if override_aux_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1:
+        print0(f"Overriding MoE aux_loss_weight to {override_aux_loss_weight}")
+        model.config.aux_loss_weight = float(override_aux_loss_weight)
+    if override_router_z_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1:
+        print0(f"Overriding MoE router_z_loss_weight to {override_router_z_loss_weight}")
+        model.config.router_z_loss_weight = float(override_router_z_loss_weight)
+
+print0(f"MoE training loss is configured to use aux_loss: {getattr(model.config, 'use_aux_loss', False)} with weight {getattr(model.config, 'aux_loss_weight', 0.0)}, router_z_loss: {getattr(model.config, 'use_router_z_loss', False)} with weight {getattr(model.config, 'router_z_loss_weight', 0.0)}")
+
 # -----------------------------------------------------------------------------
 # Task data mixture we'll train on
 identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
@@ -97,6 +123,10 @@ val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't use all of it)

 def sft_data_generator(dataset, batch_size):
     pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
+    # Ensure we never feed sequences longer than the model block size.
+    # render_conversation returns a sequence that includes a BOS token, and we then create
+    # inputs/targets by shifting by 1, so cap to block_size+1 tokens.
+    max_tokens = int(getattr(model.config, "block_size", getattr(model.config, "sequence_len", 1024))) + 1
     # prepares a list of tokenized conversations into a batch and yields
     def collate_and_yield(batch):
         nrows = len(batch)
@@ -121,7 +151,7 @@ def sft_data_generator(dataset, batch_size):
     while True:
         for i in range(ddp_rank, len(dataset), ddp_world_size):
             doc = dataset[i]
-            ids, mask = tokenizer.render_conversation(doc)
+            ids, mask = tokenizer.render_conversation(doc, max_tokens=max_tokens)
             batch.append((ids, mask))
             if len(batch) == batch_size:
                 yield collate_and_yield(batch)
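The block_size + 1 cap works because collation turns a rendered conversation of N tokens into N - 1 shifted (input, target) positions; a sketch of that step (tensor names assumed):

    ids = ids[:max_tokens]               # at most block_size + 1 tokens, BOS included
    inputs, targets = ids[:-1], ids[1:]  # each at most block_size positions long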
@@ -144,18 +174,30 @@ build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)

 # -----------------------------------------------------------------------------
-# Initialize the Optimizer
-
-optimizers = model.setup_optimizers(
-    unembedding_lr=unembedding_lr,
-    embedding_lr=embedding_lr,
-    matrix_lr=matrix_lr,
+# Initialize the Optimizer (AdamW for all parameters) - BEFORE DDP wrapping (matching nanoMoE)
+adamw_optimizer = model.configure_optimizers(
     weight_decay=weight_decay,
+    learning_rate=learning_rate,
+    betas=betas,
+    device_type=device_type,
 )
 # Set the initial learning rate as a fraction of the base learning rate
+optimizers = [adamw_optimizer]
 for opt in optimizers:
     for group in opt.param_groups:
         group["lr"] = group["lr"] * init_lr_frac
-        group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
+        group["initial_lr"] = group["lr"]
+
+# optimizers = model.setup_optimizers(
+#     unembedding_lr=unembedding_lr,
+#     embedding_lr=embedding_lr,
+#     matrix_lr=matrix_lr,
+#     weight_decay=weight_decay,
+# )
+# # Set the initial learning rate as a fraction of the base learning rate
+# for opt in optimizers:
+#     for group in opt.param_groups:
+#         group["lr"] = group["lr"] * init_lr_frac
+#         group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later

 # -----------------------------------------------------------------------------
 # Training loop
@@ -179,7 +221,7 @@ for step in range(num_iterations):
         for _ in range(eval_steps):
             val_inputs, val_targets = next(val_iter)
             with torch.no_grad(), autocast_ctx:
-                loss = model(val_inputs, val_targets)
+                _, loss = model(val_inputs, val_targets)
             losses.append(loss)
         val_loss = torch.stack(losses).mean() # average over eval_steps
         if ddp:
@@ -216,7 +258,7 @@ for step in range(num_iterations):
     for micro_step in range(grad_accum_steps):
         train_inputs, train_targets = next(train_iter)
         with autocast_ctx:
-            loss = model(train_inputs, train_targets)
+            _, loss = model(train_inputs, train_targets)
         train_loss = loss.detach() # for logging
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward() # accumulate the gradient
@@ -251,8 +293,18 @@ for step in range(num_iterations):
 if master_process:
     base_dir = get_base_dir()
     depth = model.config.n_layer
-    model_tag = f"d{depth}" # base the model tag on the depth of the base model
-    checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
+    # output_dirname = f"d{depth}" # base the model tag on the depth of the base model
+    if disable_aux_loss:
+        aux_tag = "noaux"
+    else:
+        aux_tag = "aux"
+    if disable_router_z_loss:
+        z_tag = "noz"
+    else:
+        z_tag = "z"
+    # output_dirname = f"d{depth}_{aux_tag}_{z_tag}_lr{learning_rate}_model{model_tag}"
+    output_dirname = f"d{depth}_lr{learning_rate}_init{init_lr_frac}_model{model_tag}"
+    checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
     model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
     save_checkpoint(
         checkpoint_dir,
scripts/mid_train.py

@@ -15,12 +15,14 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import time
 import wandb
 import torch
+import torch.nn.functional as F
 from contextlib import nullcontext
 from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.checkpoint_manager import load_model
+from nanochat.manager import MANAGER
 import torch.distributed as dist

 from tasks.common import TaskMixture
@@ -37,13 +39,24 @@ model_tag = None # model tag to load the model from (base model or midtrained model)
 step = None # step to load the model from (base model or midtrained model)
 dtype = "bfloat16"
 num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
+num_epochs = 1 # number of full passes over the midtraining dataset (only used if num_iterations < 0)
 max_seq_len = 2048
 device_batch_size = 32
-unembedding_lr = 0.004
-embedding_lr = 0.2
-matrix_lr = 0.02
 init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
+learning_rate = 3e-4
+betas = (0.9, 0.95)
 weight_decay = 0.0
+warmup_ratio = 0.0 # LR warmup (ratio of total training progress in [0, 1]). 0 disables warmup.
+
+# Debug knobs for MoE loss components (defaults preserve existing behavior)
+disable_aux_loss = False
+disable_router_z_loss = False
+override_aux_loss_weight = -1.0 # <0 means do not override
+override_router_z_loss_weight = -1.0 # <0 means do not override
+
 eval_every = 150 # -1 = disable
 eval_tokens = 20*524288
 total_batch_size = 524288
@@ -67,13 +80,30 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)

 # Load the model and tokenizer
 model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)

+# Optional overrides for MoE auxiliary losses (useful when total loss plateaus)
+if hasattr(model, "config"):
+    if disable_aux_loss and getattr(model.config, "n_exp", 1) > 1:
+        print0("Disabling MoE aux loss for this midtraining run")
+        model.config.use_aux_loss = False
+    if disable_router_z_loss and getattr(model.config, "n_exp", 1) > 1:
+        print0("Disabling MoE router z loss for this midtraining run")
+        model.config.use_router_z_loss = False
+    if override_aux_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1:
+        print0(f"Overriding MoE aux_loss_weight to {override_aux_loss_weight}")
+        model.config.aux_loss_weight = float(override_aux_loss_weight)
+    if override_router_z_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1:
+        print0(f"Overriding MoE router_z_loss_weight to {override_router_z_loss_weight}")
+        model.config.router_z_loss_weight = float(override_router_z_loss_weight)
+
+print0(f"MoE training loss is configured to use aux_loss: {getattr(model.config, 'use_aux_loss', False)} with weight {getattr(model.config, 'aux_loss_weight', 0.0)}, router_z_loss: {getattr(model.config, 'use_router_z_loss', False)} with weight {getattr(model.config, 'router_z_loss_weight', 0.0)}")
 pretrain_batch_size = meta.get("device_batch_size", None)
 if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
     print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
 orig_model = model
 model = torch.compile(model, dynamic=False)
 depth = model.config.n_layer
-num_flops_per_token = model.estimate_flops()
+# num_flops_per_token = model.estimate_flops(max_seq_len)
 tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
 world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
 assert total_batch_size % world_tokens_per_fwdbwd == 0
@@ -83,14 +113,26 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
 print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
 token_bytes = get_token_bytes(device=device)

-# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
-optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
-adamw_optimizer, muon_optimizer = optimizers
-# Override the initial learning rate as a fraction of the base learning rate
-for opt in optimizers:
-    for group in opt.param_groups:
-        group["lr"] = group["lr"] * init_lr_frac
-        group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
+# Sanity print: tokenizer ids must fit inside model vocab (esp. when vocab_size=50304 padded GPT-2)
+print0(f"Model vocab_size: {model.config.vocab_size}")
+print0(f"Tokenizer vocab_size: {tokenizer.get_vocab_size()}")
+
+# Initialize the Optimizer (AdamW for all parameters) - BEFORE DDP wrapping (matching nanoMoE)
+adamw_optimizer = model.configure_optimizers(
+    weight_decay=weight_decay,
+    learning_rate=learning_rate,
+    betas=betas,
+    device_type=device_type,
+)
+optimizers = [adamw_optimizer]
+# # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
+# optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
+# adamw_optimizer, muon_optimizer = optimizers
+# # Override the initial learning rate as a fraction of the base learning rate
+# for opt in optimizers:
+#     for group in opt.param_groups:
+#         group["lr"] = group["lr"] * init_lr_frac
+#         group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later

 # Midtraining data mixture and DataLoader
 base_dir = get_base_dir()
@@ -114,14 +156,17 @@ val_dataset = TaskMixture([
 # these two global variables and update them from within the data generator.
 last_step = False # we will toggle this to True when we reach the end of the dataset
 approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
+current_epoch = 1 # will go from 1 to num_epochs
 def mid_data_generator(split):
-    global last_step, approx_progress
+    global last_step, approx_progress, current_epoch
     assert split in {"train", "val"}, "split must be 'train' or 'val'"
     dataset = train_dataset if split == "train" else val_dataset
     dataset_size = len(dataset)
     assert dataset_size > 0
     needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
     token_buffer = deque()
+    # A lightweight resumable state dict (similar spirit to base_train.py)
+    dataloader_state_dict = {"split": split}
     # CUDA supports memory pinning for faster transfers between CPU and GPU:
     scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
     cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
@@ -136,7 +181,13 @@ def mid_data_generator(split):
             if cursor >= dataset_size:
                 cursor -= dataset_size # wrap around for another epoch
                 if split == "train":
-                    last_step = True # toggle last_step to True, which will terminate the training loop
+                    # Track epochs (unless num_iterations explicitly caps steps)
+                    if num_iterations < 0:
+                        current_epoch += 1
+                        if current_epoch > num_epochs:
+                            last_step = True # terminate after requested epochs
+                    else:
+                        last_step = True # legacy behavior when num_iterations is set elsewhere
             # Stopping condition to respect num_iterations, if given
             it += 1
             if num_iterations > 0 and it >= num_iterations:
@@ -146,14 +197,39 @@ def mid_data_generator(split):
                 scratch[i] = token_buffer.popleft()
             inputs_cpu = scratch[:-1].to(dtype=torch.int32)
             targets_cpu = scratch[1:]
+
+            # Early token-id range check on CPU to avoid opaque torch.compile CUDA OOB asserts.
+            # Only do this for a few batches to keep overhead minimal.
+            if it <= 5:
+                min_id = int(inputs_cpu.min().item())
+                max_id = int(inputs_cpu.max().item())
+                vocab_limit = int(model.config.vocab_size)
+                if not (0 <= min_id and max_id < vocab_limit):
+                    raise ValueError(
+                        f"Token id out of range: min={min_id}, max={max_id}, expected within [0, {vocab_limit}). "
+                        f"Tokenizer vocab_size={int(tokenizer.get_vocab_size())}. "
+                        "This usually means the tokenizer used for midtraining doesn't match the model vocab."
+                    )
+
             inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
             targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
             if split == "train":
                 if num_iterations > 0:
                     approx_progress = it / num_iterations # calculate progress from the max number of iterations
                 else:
-                    approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
-            yield inputs, targets
+                    # progress across epochs, in [0, 1]
+                    denom = max(float(num_epochs), 1.0)
+                    approx_progress = min(((current_epoch - 1) + (cursor / dataset_size)) / denom, 1.0)
+            dataloader_state_dict.update({
+                "cursor": int(cursor),
+                "it": int(it),
+                "current_epoch": int(current_epoch),
+                "last_step": bool(last_step),
+                "approx_progress": float(approx_progress),
+                # Keep the remaining buffered tokens for exact resume semantics.
+                "token_buffer": list(token_buffer),
+            })
+            yield inputs, targets, dataloader_state_dict

 train_loader = mid_data_generator("train")
 build_val_loader = lambda: mid_data_generator("val")
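Nothing in this script consumes the state dict yet beyond stashing it in checkpoint metadata; a hypothetical peek at the saved fields after a run (how meta is loaded back depends on the checkpoint format, which is assumed here):

    state = meta["dataloader_state_dict"]  # as saved by save_checkpoint below
    print(state["current_epoch"], state["cursor"], state["it"], len(state["token_buffer"]))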
@@ -161,8 +237,15 @@ progress = 0 # will go from 0 to 1 over the course of the epoch

 # Learning rate scheduler
 def get_lr_multiplier(progress):
-    # first 80% of training: no decay, then linearly ramp down to 0.
-    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
+    # Warmup: linearly ramp from 0 -> 1 over the first `warmup_ratio` portion of training.
+    if warmup_ratio and warmup_ratio > 0:
+        warmup_mult = min(max(progress / warmup_ratio, 0.0), 1.0)
+    else:
+        warmup_mult = 1.0
+
+    # Decay: first 80% of training no decay, then linearly ramp down to 0.
+    decay_mult = 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
+    return warmup_mult * decay_mult

 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
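As a quick sanity check of the combined multiplier, with warmup_ratio = 0.05 the schedule traces out:

    # progress:     0.00  0.025  0.05  0.50  0.80  0.90  1.00
    # warmup_mult:  0.0   0.5    1.0   1.0   1.0   1.0   1.0
    # decay_mult:   1.0   1.0    1.0   1.0   1.0   0.5   0.0
    # product:      0.0   0.5    1.0   1.0   1.0   0.5   0.0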
@@ -172,14 +255,16 @@ def get_muon_momentum(it):

 # -----------------------------------------------------------------------------
 # Training loop
-x, y = next(train_loader) # prefetch the very first batch of data
+x, y, dataloader_state_dict = next(train_loader) # prefetch the very first batch of data
 min_val_bpb = float("inf")
 smooth_train_loss = 0 # EMA of training loss
+smooth_train_ce_loss = 0 # EMA of CE loss
 ema_beta = 0.9 # EMA decay factor
 total_training_time = 0 # total wall-clock time of training
+val_bpb = None # populated during evaluation (keep defined for checkpoint metadata)
 step = 0
 while True:
-    flops_so_far = num_flops_per_token * total_batch_size * step
+    # flops_so_far = num_flops_per_token * total_batch_size * step

     # Synchronize last_step across all ranks to avoid hangs in the distributed setting
     if ddp:
@@ -199,7 +284,7 @@ while True:
             min_val_bpb = val_bpb
         wandb_run.log({
             "step": step,
-            "total_training_flops": flops_so_far,
+            # "total_training_flops": flops_so_far,
             "total_training_time": total_training_time,
             "val/bpb": val_bpb,
         })
@@ -207,25 +292,72 @@ while True:

     # save checkpoint at the end of the run (only on master process)
     if master_process and last_step and not dry_run:
-        output_dirname = f"d{depth}" # e.g. d12
+        # output_dirname = f"d{depth}" # e.g. d12
+        if disable_aux_loss:
+            aux_tag = "noaux"
+        else:
+            aux_tag = "aux"
+        if disable_router_z_loss:
+            z_tag = "noz"
+        else:
+            z_tag = "z"
+        # output_dirname = f"d{depth}_{aux_tag}_{z_tag}_lr{learning_rate}_model{model_tag}"
+        output_dirname = f"d{depth}_lr{learning_rate}_model{model_tag}"
         checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
+
+        # Save metadata in the same shape as base_train.py for consistency.
+        model_config_for_save = {}
+        for k in [
+            # Core GPT config
+            "block_size",
+            "vocab_size",
+            "n_layer",
+            "n_head",
+            "n_kv_head",
+            "n_embd",
+            "dropout",
+            "bias",
+            # MoE config (if present)
+            "n_exp",
+            "top_k",
+            "use_aux_loss",
+            "use_router_z_loss",
+            "use_noisy_top_k",
+            "aux_loss_weight",
+            "router_z_loss_weight",
+            "train_capacity",
+            "eval_capacity",
+            "min_capacity",
+            "stride",
+            "use_switch_tfm_init",
+            "switch_tfm_init_scale",
+            "router_use_full_prec",
+        ]:
+            if hasattr(orig_model.config, k):
+                v = getattr(orig_model.config, k)
+                if isinstance(v, (int, float, bool, str)):
+                    model_config_for_save[k] = v
+
         save_checkpoint(
             checkpoint_dir,
             step,
             orig_model.state_dict(),
-            [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
+            adamw_optimizer.state_dict(), # TODO: make sure saving across ranks is done correctly
             {
                 "step": step,
                 "val_bpb": val_bpb, # loss at last step
-                "model_config": {
-                    "sequence_len": max_seq_len,
-                    "vocab_size": tokenizer.get_vocab_size(),
-                    "n_layer": depth,
-                    "n_head": model.config.n_head,
-                    "n_kv_head": model.config.n_kv_head,
-                    "n_embd": model.config.n_embd,
-                },
+                "model_config": model_config_for_save,
                 "user_config": user_config, # inputs to the training script
+                "device_batch_size": device_batch_size,
+                "max_seq_len": max_seq_len,
+                "loop_state": {
+                    "min_val_bpb": min_val_bpb,
+                    "smooth_train_loss": smooth_train_loss,
+                    "smooth_train_ce_loss": smooth_train_ce_loss,
+                    "total_training_time": total_training_time,
+                    "progress": progress,
+                    "current_epoch": int(current_epoch),
+                },
+                "dataloader_state_dict": dataloader_state_dict,
             }
         )
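This whitelist is what makes the build_model backward-compat shim earlier in the commit unnecessary for checkpoints written from here on: meta now carries every serializable GPTConfig field, MoE ones included, so reloading reduces to:

    model_config = GPTConfig(**meta_data["model_config"])  # n_exp et al. present, no shape inference needed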
@@ -237,24 +369,52 @@ while True:
     # evaluate the gradient
     synchronize()
     t0 = time.time()
+    total_loss_accum = 0.0
+    ce_loss_accum = 0.0
+    aux_loss_contrib_accum = 0.0
+    router_z_loss_contrib_accum = 0.0
     for micro_step in range(grad_accum_steps):
         with autocast_ctx:
-            loss = model(x, y)
-        train_loss = loss.detach() # for logging
-        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
+            logits, total_loss = model(x, y) # returns (logits, loss)
+            # Cross-entropy only (language modeling objective)
+            ce_loss = F.cross_entropy(
+                logits.view(-1, logits.size(-1)),
+                y.view(-1),
+                ignore_index=-1,
+            )
+        # Cache logging values (average across micro-steps)
+        total_loss_accum += float(total_loss.detach().item())
+        ce_loss_accum += float(ce_loss.detach().item())
+        aux_sum = getattr(MANAGER, "last_aux_loss_sum", 0.0)
+        z_sum = getattr(MANAGER, "last_router_z_loss_sum", 0.0)
+        # Convert sums into the *weighted* contribution that is actually added to total_loss
+        if getattr(model.config, "n_exp", 1) > 1 and getattr(model.config, "use_aux_loss", False):
+            if torch.is_tensor(aux_sum):
+                aux_loss_contrib_accum += float(getattr(model.config, "aux_loss_weight", 0.0)) * aux_sum.detach().item()
+            else:
+                aux_loss_contrib_accum += float(getattr(model.config, "aux_loss_weight", 0.0)) * float(aux_sum)
+        if getattr(model.config, "n_exp", 1) > 1 and getattr(model.config, "use_router_z_loss", False):
+            if torch.is_tensor(z_sum):
+                router_z_loss_contrib_accum += float(getattr(model.config, "router_z_loss_weight", 0.0)) * z_sum.detach().item()
+            else:
+                router_z_loss_contrib_accum += float(getattr(model.config, "router_z_loss_weight", 0.0)) * float(z_sum)
+
+        loss = total_loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward()
-        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
+        x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
     progress = max(progress, approx_progress) # only increase progress monotonically
-    # step the optimizers
+
+    # micro-step averages for logging
+    train_total_loss = total_loss_accum / grad_accum_steps
+    train_ce_loss = ce_loss_accum / grad_accum_steps
+    train_aux_loss_contrib = aux_loss_contrib_accum / grad_accum_steps
+    train_router_z_loss_contrib = router_z_loss_contrib_accum / grad_accum_steps
+    # step the optimizer(s)
     lrm = get_lr_multiplier(progress)
-    for opt in optimizers:
-        for group in opt.param_groups:
-            group["lr"] = group["initial_lr"] * lrm
-    muon_momentum = get_muon_momentum(step)
-    for group in muon_optimizer.param_groups:
-        group["momentum"] = muon_momentum
-    for opt in optimizers:
-        opt.step()
+    current_lr = learning_rate * init_lr_frac * lrm
+    for group in adamw_optimizer.param_groups:
+        group["lr"] = current_lr
+    adamw_optimizer.step()
     model.zero_grad(set_to_none=True)
     synchronize()
     t1 = time.time()
@@ -265,26 +425,37 @@ while True:
     step += 1

     # logging
-    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
+    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_total_loss # EMA the total loss
+    smooth_train_ce_loss = ema_beta * smooth_train_ce_loss + (1 - ema_beta) * train_ce_loss # EMA the CE loss
     debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
+    debiased_smooth_ce_loss = smooth_train_ce_loss / (1 - ema_beta**(step + 1)) # debias the EMA
     pct_done = 100 * progress
     tok_per_sec = int(total_batch_size / dt)
-    flops_per_sec = num_flops_per_token * total_batch_size / dt
+    # flops_per_sec = num_flops_per_token * total_batch_size / dt
     promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
-    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
+    # mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
-    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
+    print0(
+        f"step {step:05d} ({pct_done:.2f}%) | "
+        f"loss: {debiased_smooth_loss:.6f} | ce: {debiased_smooth_ce_loss:.6f} | "
+        f"aux: {train_aux_loss_contrib:.6f} | z: {train_router_z_loss_contrib:.6f} | "
+        f"lr: {current_lr:.6g} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | total time: {total_training_time/60:.2f}m"
+    )
     if step % 10 == 0:
         wandb_run.log({
             "step": step,
-            "total_training_flops": flops_so_far,
+            # "total_training_flops": flops_so_far,
             "total_training_time": total_training_time,
             "train/loss": debiased_smooth_loss,
+            "train/ce_loss": debiased_smooth_ce_loss,
+            "train/aux_loss_contrib": train_aux_loss_contrib,
+            "train/router_z_loss_contrib": train_router_z_loss_contrib,
+            "train/lr": current_lr,
             "train/lrm": lrm,
             "train/dt": dt,
             "train/tok_per_sec": tok_per_sec,
-            "train/mfu": mfu,
+            # "train/mfu": mfu,
         })

     # print a few more stats
speedrun_moe.sh (174 changes)
@@ -11,10 +11,10 @@
 # WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh

 # Default intermediate artifacts directory is in ~/.cache/nanochat-moe
-export USER="limh23"
+export USER=""
 export OMP_NUM_THREADS=1
-export NANOCHAT_BASE_DIR="/thullms/$USER/.cache/nanochat-moe"
-export NANOCHAT_DATA_DIR="/thullms/$USER/.cache/nanochat-moe-data"
+export NANOCHAT_BASE_DIR="$USER/.cache/nanochat-moe"
+export NANOCHAT_DATA_DIR="$USER/.cache/nanochat-moe/base_data"
 mkdir -p $NANOCHAT_BASE_DIR
 mkdir -p $NANOCHAT_DATA_DIR
@@ -22,7 +22,7 @@ mkdir -p $NANOCHAT_DATA_DIR

 # Use tokenizer from nanochat (not nanochat-moe)
 # Create a symlink to nanochat's tokenizer directory if it doesn't exist
-NANOCHAT_TOKENIZER_DIR="$HOME/.cache/nanochat/tokenizer"
+NANOCHAT_TOKENIZER_DIR="$USER/.cache/nanochat-moe/tokenizer"
 MOE_TOKENIZER_DIR="$NANOCHAT_BASE_DIR/tokenizer"
 if [ -d "$NANOCHAT_TOKENIZER_DIR" ] && [ ! -e "$MOE_TOKENIZER_DIR" ]; then
     echo "Creating symlink to nanochat tokenizer: $MOE_TOKENIZER_DIR -> $NANOCHAT_TOKENIZER_DIR"
@@ -82,21 +82,20 @@ fi
 # fi
 # export HF_ENDPOINT=https://hf-mirror.com

-# # -----------------------------------------------------------------------------
-# # Python venv setup with uv
+# -----------------------------------------------------------------------------
+# Python venv setup with uv

-# # install uv (if not already installed)
-# if ! command -v uv &> /dev/null; then
-#     pip3 install uv -i https://pypi.tuna.tsinghua.edu.cn/simple
-# fi
-# # create a .venv local virtual environment (if it doesn't exist)
-# [ -d ".venv" ] || uv venv
-# # install the repo dependencies with China mirror
-# export UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
-# uv sync --extra gpu
-# # activate venv so that `python` uses the project's venv instead of system python
-cd $HOME/nanochat-MoE
-source .venv/bin/activate
+# install uv (if not already installed)
+if ! command -v uv &> /dev/null; then
+    pip3 install uv -i https://pypi.tuna.tsinghua.edu.cn/simple
+fi
+# create a .venv local virtual environment (if it doesn't exist)
+[ -d ".venv" ] || uv venv
+# install the repo dependencies with China mirror
+export UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
+uv sync --extra gpu
+# activate venv so that `python` uses the project's venv instead of system python
+source "${USER}/nanochat/.venv/bin/activate"

 # # -----------------------------------------------------------------------------
 # wandb setup
@@ -117,32 +116,32 @@ fi
# # with a bunch of system info and a timestamp that marks the start of the run.
# python -m nanochat.report reset

# # -----------------------------------------------------------------------------
# # Tokenizer
# -----------------------------------------------------------------------------
# Tokenizer

# # Install Rust / Cargo (if not already installed)
# if ! command -v rustc &> /dev/null; then
#     curl --proto '=https' --tlsv1.2 -sSf https://rsproxy.cn/rustup-init.sh | sh -s -- -y
#     source "$HOME/.cargo/env"
# fi
# Install Rust / Cargo (if not already installed)
if ! command -v rustc &> /dev/null; then
    curl --proto '=https' --tlsv1.2 -sSf https://rsproxy.cn/rustup-init.sh | sh -s -- -y
    source "$HOME/.cargo/env"
fi

# # Build the rustbpe Tokenizer
# uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
# Build the rustbpe Tokenizer
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml

# # Download the first ~2B characters of pretraining dataset
# # look at dev/repackage_data_reference.py for details on how this data was prepared
# # each data shard is ~250M chars
# # so we download 2e9 / 250e6 = 8 data shards at this point
# # each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
# python -m nanochat.dataset -n 8
# # Immediately also kick off downloading more shards in the background while tokenizer trains
# # See comment below for why 240 is the right number here
# python -m nanochat.dataset -n 240 &
# DATASET_DOWNLOAD_PID=$!
# # train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
# python -m scripts.tok_train --max_chars=2000000000
# # evaluate the tokenizer (report compression ratio etc.)
# python -m scripts.tok_eval
# Download the first ~2B characters of pretraining dataset
# look at dev/repackage_data_reference.py for details on how this data was prepared
# each data shard is ~250M chars
# so we download 2e9 / 250e6 = 8 data shards at this point
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
python -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 240 is the right number here
python -m nanochat.dataset -n 240 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max_chars=2000000000
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval

# -----------------------------------------------------------------------------
# Base model (pretraining)
@@ -156,19 +155,108 @@ fi
# echo "Waiting for dataset download to complete..."
|
||||
# wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
|
||||
MODEL_DIM=${MODEL_DIM:-384}
|
||||
GLOBAL_BS=${GLOBAL_BS:-480}
|
||||
MIN_LR=${MIN_LR:-6e-5}
|
||||
LEARNING_RATE=${LEARNING_RATE:-6e-4}
|
||||
DEPTH=${DEPTH:-${N_LAYER:-6}}
|
||||
MODEL_TAG=${MODEL_TAG:-d${DEPTH}_min_lr${MIN_LR}_max_lr${LEARNING_RATE}}
|
||||
# # Number of processes/GPUs to use
|
||||
# NPROC_PER_NODE=${NPROC_PER_NODE:-8}
|
||||
# Number of processes/GPUs to use
|
||||
NPROC_PER_NODE=${NPROC_PER_NODE:-8}
|
||||
# Auto-detect number of GPUs: prefer CUDA_VISIBLE_DEVICES, then nvidia-smi, then python torch fallback
|
||||
if [ -n "${CUDA_VISIBLE_DEVICES:-}" ]; then
|
||||
# Count entries in CUDA_VISIBLE_DEVICES (comma-separated list)
|
||||
NPROC_PER_NODE=$(echo "$CUDA_VISIBLE_DEVICES" | awk -F',' '{print NF}')
|
||||
else
|
||||
if command -v nvidia-smi &>/dev/null; then
|
||||
NPROC_PER_NODE=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
|
||||
else
|
||||
if command -v python3 &>/dev/null; then
|
||||
NPROC_PER_NODE=$(python3 - <<'PY'
|
||||
import sys
|
||||
try:
|
||||
import torch
|
||||
print(torch.cuda.device_count())
|
||||
except Exception:
|
||||
print(0)
|
||||
PY
|
||||
)
|
||||
else
|
||||
NPROC_PER_NODE=1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
# Ensure at least 1
|
||||
NPROC_PER_NODE=${NPROC_PER_NODE:-1}
|
||||
if [ "$NPROC_PER_NODE" -lt 1 ]; then
|
||||
NPROC_PER_NODE=1
|
||||
fi
|
||||
# Master port for distributed training (default: 29500)
|
||||
# Set this to avoid port conflicts when running multiple torchrun tasks simultaneously
|
||||
# Example: MASTER_PORT=29501 bash speedrun.sh
|
||||
MASTER_PORT=${MASTER_PORT:-29501}
|
||||
LOG_TAG=${LOG_TAG:-$(date +%Y%m%d_%H%M%S)}
|
||||
LOG_FILE=${LOG_FILE:-$NANOCHAT_BASE_DIR/${MODEL_TAG}_${LOG_TAG}.log}
|
||||
LOG_FILE=${LOG_FILE:-$NANOCHAT_BASE_DIR/log/${MODEL_TAG}_${LOG_TAG}.log}
|
||||
# # # pretrain the d20 model
|
||||
MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train >> "$LOG_FILE" 2>&1
|
||||
|
||||
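Since the detection above recomputes NPROC_PER_NODE, the practical knobs are CUDA_VISIBLE_DEVICES, MASTER_PORT, and the model/LR variables. Hypothetical invocations (these exact command lines are not part of the commit):

CUDA_VISIBLE_DEVICES=0,1 bash speedrun_moe.sh              # detection counts 2 GPUs from the list
MASTER_PORT=29502 MODEL_TAG=d6_sweep bash speedrun_moe.sh  # concurrent run on a different port
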
# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)

# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
# curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl

# MID_LR=${MID_LR:-3e-4}
# LOAD_FROM_MODEL_TAG=${LOAD_FROM_MODEL_TAG:-d6_min_lr0.0002_max_lr0.002}
# WANDB_RUN=moe_mid_${MID_LR}_${LOAD_FROM_MODEL_TAG}

# # run midtraining and eval the model
# MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN \
#     --model_tag=$LOAD_FROM_MODEL_TAG --learning_rate=$MID_LR \
#     --device_batch_size=8 --max_seq_len=1024 --total_batch_size=524288
# MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN \
#     --model_tag=d12_e16_lr2e-4_double_dmodel \
#     --learning_rate=$MID_LR --num_epochs=1 \
#     --disable_aux_loss=True --disable_router_z_loss=True \
#     --device_batch_size=8 --max_seq_len=1024 --total_batch_size=524288
# MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts_moe.chat_eval -- -i mid

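The --disable_aux_loss / --disable_router_z_loss flags in the commented command above refer to the two router regularizers whose contributions are logged as train/aux_loss_contrib and train/router_z_loss_contrib earlier in this diff. A minimal sketch of the usual formulations (Switch-Transformer-style load balancing plus router z-loss); the function name and shapes are illustrative, not the repo's exact implementation:

import torch
import torch.nn.functional as F

def router_losses(router_logits):
    # router_logits: [num_tokens, n_exp], illustrative shape
    n_exp = router_logits.size(-1)
    probs = F.softmax(router_logits, dim=-1)
    # aux (load-balancing) loss: n_exp * sum_e f_e * P_e, where f_e is the
    # fraction of tokens whose top-1 expert is e, and P_e is the mean router prob
    f = F.one_hot(probs.argmax(dim=-1), n_exp).float().mean(dim=0)
    P = probs.mean(dim=0)
    aux_loss = n_exp * (f * P).sum()
    # z-loss: penalizes large router logits to keep the softmax numerically tame
    z_loss = (torch.logsumexp(router_logits, dim=-1) ** 2).mean()
    return aux_loss, z_loss
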
# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
SFT_LR=${SFT_LR:-9e-6}
LOAD_FROM_MID_MODEL_TAG=${LOAD_FROM_MID_MODEL_TAG:-d6_lr0.0003_modeld6_min_lr0.0002_max_lr0.002}
WANDB_RUN=moe_sft_${SFT_LR}_${LOAD_FROM_MID_MODEL_TAG}
# train sft and re-eval right away (should see a small bump)

torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN --learning_rate=$SFT_LR --init_lr_frac=1.0 \
    --model_tag=$LOAD_FROM_MID_MODEL_TAG

# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN --learning_rate=3e-4 --init_lr_frac=1.0 --num_epochs=4
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft

# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"

# even better, chat with your model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web
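Note how the default tags chain across stages: the SFT default above embeds the midtrain tag, and the RL default below embeds the full SFT tag, so each stage can locate its upstream checkpoint by name. To rerun a single stage against a different upstream checkpoint, override the corresponding variable; the tag value here is hypothetical and must match an existing checkpoint directory:

# Hypothetical single-stage override; the tag must name a real checkpoint dir:
SFT_LR=1e-5 LOAD_FROM_MID_MODEL_TAG=d6_lr0.0005_modeld6_min_lr0.0002_max_lr0.002 bash speedrun_moe.sh
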
# -----------------------------------------------------------------------------
# Reinforcement Learning. Optional, and currently only on GSM8K
RL_LR=${RL_LR:-9e-6}
LOAD_FROM_SFT_MODEL_TAG=${LOAD_FROM_SFT_MODEL_TAG:-d6_lr5e-05_init1.0_modeld6_lr0.0003_modeld6_min_lr0.0002_max_lr0.002}
WANDB_RUN=moe_rl_${RL_LR}_${LOAD_FROM_SFT_MODEL_TAG}

# run reinforcement learning
MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN \
    --model_tag=$LOAD_FROM_SFT_MODEL_TAG

# eval the RL model only on GSM8K
# MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K

# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections
# report.md is the output and will be copied to current directory for convenience
python -m nanochat.report generate