"""
|
|
Midtrain the model. Same as pretraining but simpler.
|
|
Run as:
|
|
|
|
python -m scripts.mid_train
|
|
|
|
Or torchrun for training:
|
|
|
|
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
|
|
"""

from collections import deque
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time

import wandb
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext

from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.loss_eval import evaluate_bpb
from nanochat.manager import MANAGER

from tasks.common import TaskMixture
from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee

# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
num_epochs = 1 # number of full passes over the midtraining dataset (only used if num_iterations < 0)
max_seq_len = 2048
device_batch_size = 32
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
learning_rate = 3e-4
betas = (0.9, 0.95)
weight_decay = 0.0
warmup_ratio = 0.0 # LR warmup (ratio of total training progress in [0, 1]). 0 disables warmup.

# Debug knobs for MoE loss components (defaults preserve existing behavior)
disable_aux_loss = False
disable_router_z_loss = False
override_aux_loss_weight = -1.0 # <0 means do not override
override_router_z_loss_weight = -1.0 # <0 means do not override

eval_every = 150 # -1 = disable
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
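# The config values above can be overridden from the command line via configurator.py,
# using the same --key=value form shown in the docstring, e.g. (values are illustrative only):
#   torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --num_epochs=2 --eval_every=300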
# -----------------------------------------------------------------------------

# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)

# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)

# Optional overrides for MoE auxiliary losses (useful when total loss plateaus)
if hasattr(model, "config"):
    if disable_aux_loss and getattr(model.config, "n_exp", 1) > 1:
        print0("Disabling MoE aux loss for this midtraining run")
        model.config.use_aux_loss = False
    if disable_router_z_loss and getattr(model.config, "n_exp", 1) > 1:
        print0("Disabling MoE router z loss for this midtraining run")
        model.config.use_router_z_loss = False
    if override_aux_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1:
        print0(f"Overriding MoE aux_loss_weight to {override_aux_loss_weight}")
        model.config.aux_loss_weight = float(override_aux_loss_weight)
    if override_router_z_loss_weight >= 0 and getattr(model.config, "n_exp", 1) > 1:
        print0(f"Overriding MoE router_z_loss_weight to {override_router_z_loss_weight}")
        model.config.router_z_loss_weight = float(override_router_z_loss_weight)

print0(f"MoE training loss is configured to use aux_loss: {getattr(model.config, 'use_aux_loss', False)} with weight {getattr(model.config, 'aux_loss_weight', 0.0)}, router_z_loss: {getattr(model.config, 'use_router_z_loss', False)} with weight {getattr(model.config, 'router_z_loss_weight', 0.0)}")

pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
# num_flops_per_token = model.estimate_flops(max_seq_len)
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per micro-batch for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # tokens per micro-batch across all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)

# Sanity print: tokenizer ids must fit inside model vocab (esp. when vocab_size=50304 padded GPT-2)
print0(f"Model vocab_size: {model.config.vocab_size}")
print0(f"Tokenizer vocab_size: {tokenizer.get_vocab_size()}")

# Initialize the Optimizer (AdamW for all parameters) - BEFORE DDP wrapping (matching nanoMoE)
adamw_optimizer = model.configure_optimizers(
    weight_decay=weight_decay,
    learning_rate=learning_rate,
    betas=betas,
    device_type=device_type,
)
optimizers = [adamw_optimizer]
# # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
# optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
# adamw_optimizer, muon_optimizer = optimizers
# # Override the initial learning rate as a fraction of the base learning rate
# for opt in optimizers:
#     for group in opt.param_groups:
#         group["lr"] = group["lr"] * init_lr_frac
#         group["initial_lr"] = group["lr"] # save the initial learning rate so we can decay easily later

# Midtraining data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
    SmolTalk(split="train"), # 460K rows of general conversations
    MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
    CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
    SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 2*1K + 200K + 80K = 850K rows
val_dataset = TaskMixture([
    SmolTalk(split="test"), # 24K rows in test set
    MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
    GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total: 24K + 5.2K + 0.42K ~= 30K rows
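# Note: TaskMixture is assumed here to simply concatenate its tasks and expose them through
# len() and integer indexing, which is all the data generator below relies on (len(dataset)
# and dataset[cursor]). Listing a task twice, as with the identity conversations above, is
# how a subset gets multiple passes within one epoch of the mixture.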

# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create
# these global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of training
current_epoch = 1 # will go from 1 to num_epochs
def mid_data_generator(split):
    global last_step, approx_progress, current_epoch
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    dataset = train_dataset if split == "train" else val_dataset
    dataset_size = len(dataset)
    assert dataset_size > 0
    needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
    token_buffer = deque()
    # A lightweight resumable state dict (similar spirit to base_train.py)
    dataloader_state_dict = {"split": split}
    # CUDA supports memory pinning for faster transfers between CPU and GPU:
    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
    cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
    it = 0 # iteration counter
    while True:
        # Accumulate enough tokens for one iteration before yielding
        while len(token_buffer) < needed_tokens:
            conversation = dataset[cursor]
            ids, _ = tokenizer.render_conversation(conversation)
            token_buffer.extend(ids)
            cursor += ddp_world_size
            if cursor >= dataset_size:
                cursor -= dataset_size # wrap around for another epoch
                if split == "train":
                    # Track epochs (unless num_iterations explicitly caps steps)
                    if num_iterations < 0:
                        current_epoch += 1
                        if current_epoch > num_epochs:
                            last_step = True # terminate after the requested number of epochs
                    else:
                        last_step = True # legacy behavior when num_iterations is set explicitly
        # Stopping condition to respect num_iterations, if given
        it += 1
        if num_iterations > 0 and it >= num_iterations:
            last_step = True # toggle last_step to True, which will terminate the training loop
        # Build up inputs/targets and yield
        for i in range(needed_tokens):
            scratch[i] = token_buffer.popleft()
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]

        # Early token-id range check on CPU to avoid opaque torch.compile CUDA OOB asserts.
        # Only do this for a few batches to keep overhead minimal.
        if it <= 5:
            min_id = int(inputs_cpu.min().item())
            max_id = int(inputs_cpu.max().item())
            vocab_limit = int(model.config.vocab_size)
            if not (0 <= min_id and max_id < vocab_limit):
                raise ValueError(
                    f"Token id out of range: min={min_id}, max={max_id}, expected within [0, {vocab_limit}). "
                    f"Tokenizer vocab_size={int(tokenizer.get_vocab_size())}. "
                    "This usually means the tokenizer used for midtraining doesn't match the model vocab."
                )

        inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
        if split == "train":
            if num_iterations > 0:
                approx_progress = it / num_iterations # calculate progress from the max number of iterations
            else:
                # progress across epochs, in [0, 1]
                denom = max(float(num_epochs), 1.0)
                approx_progress = min(((current_epoch - 1) + (cursor / dataset_size)) / denom, 1.0)
        dataloader_state_dict.update({
            "cursor": int(cursor),
            "it": int(it),
            "current_epoch": int(current_epoch),
            "last_step": bool(last_step),
            "approx_progress": float(approx_progress),
            # Keep the remaining buffered tokens for exact resume semantics.
            "token_buffer": list(token_buffer),
        })
        yield inputs, targets, dataloader_state_dict

train_loader = mid_data_generator("train")
build_val_loader = lambda: mid_data_generator("val")
progress = 0 # will go from 0 to 1 over the course of training
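# For reference, each yield from mid_data_generator is (inputs, targets, state): inputs is an
# int32 tensor of shape (device_batch_size, max_seq_len) on the device, targets is the int64
# tensor shifted by one token, and state is the small resumable dict built above.
# build_val_loader deliberately constructs a fresh generator, so every evaluation pass starts
# from the beginning of the validation mixture.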

# Learning rate scheduler
def get_lr_multiplier(progress):
    # Warmup: linearly ramp from 0 -> 1 over the first `warmup_ratio` portion of training.
    if warmup_ratio and warmup_ratio > 0:
        warmup_mult = min(max(progress / warmup_ratio, 0.0), 1.0)
    else:
        warmup_mult = 1.0
    # Decay: no decay for the first 80% of training, then linearly ramp down to 0.
    decay_mult = 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
    return warmup_mult * decay_mult
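# A couple of reference points, assuming the default warmup_ratio=0.0: the multiplier stays at
# 1.0 until 80% progress, is 0.5 at 90% progress, and reaches 0.0 at the end of training.
# With e.g. warmup_ratio=0.05, the first 5% of training ramps the LR linearly from 0 to 1.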

# Momentum scheduler for the Muon optimizer (only relevant if the Muon path above is enabled)
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum

# -----------------------------------------------------------------------------
# Training loop
x, y, dataloader_state_dict = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
smooth_train_ce_loss = 0 # EMA of CE loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
val_bpb = None # populated during evaluation (keep defined for checkpoint metadata)
step = 0
while True:
    # flops_so_far = num_flops_per_token * total_batch_size * step

    # Synchronize last_step across all ranks to avoid hangs in the distributed setting
    if ddp:
        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
        last_step = bool(last_step_tensor.item())
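    # Because the all-reduce takes the MAX, a single rank exhausting its data is enough to flip
    # last_step everywhere, so the eval / checkpoint / break logic below runs in lockstep on
    # every rank of the process group.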

    # once in a while: evaluate the val bpb (all ranks participate)
    if eval_every > 0 and (last_step or step % eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            # "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # save checkpoint at the end of the run (only on master process)
    if master_process and last_step and not dry_run:
        # output_dirname = f"d{depth}" # e.g. d12
        aux_tag = "noaux" if disable_aux_loss else "aux"
        z_tag = "noz" if disable_router_z_loss else "z"
        # output_dirname = f"d{depth}_{aux_tag}_{z_tag}_lr{learning_rate}_model{model_tag}"
        output_dirname = f"d{depth}_lr{learning_rate}_model{model_tag}"
        checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)

        # Save metadata in the same shape as base_train.py for consistency.
        model_config_for_save = {}
        for k in [
            # Core GPT config
            "block_size",
            "vocab_size",
            "n_layer",
            "n_head",
            "n_kv_head",
            "n_embd",
            "dropout",
            "bias",
            # MoE config (if present)
            "n_exp",
            "top_k",
            "use_aux_loss",
            "use_router_z_loss",
            "use_noisy_top_k",
            "aux_loss_weight",
            "router_z_loss_weight",
            "train_capacity",
            "eval_capacity",
            "min_capacity",
            "stride",
            "use_switch_tfm_init",
            "switch_tfm_init_scale",
            "router_use_full_prec",
        ]:
            if hasattr(orig_model.config, k):
                v = getattr(orig_model.config, k)
                if isinstance(v, (int, float, bool, str)):
                    model_config_for_save[k] = v

        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            adamw_optimizer.state_dict(), # TODO: make sure saving across ranks is done correctly
            {
                "step": step,
                "model_config": model_config_for_save,
                "user_config": user_config, # inputs to the training script
                "device_batch_size": device_batch_size,
                "max_seq_len": max_seq_len,
                "loop_state": {
                    "min_val_bpb": min_val_bpb,
                    "smooth_train_loss": smooth_train_loss,
                    "smooth_train_ce_loss": smooth_train_ce_loss,
                    "total_training_time": total_training_time,
                    "progress": progress,
                    "current_epoch": int(current_epoch),
                },
                "dataloader_state_dict": dataloader_state_dict,
            }
        )

    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    synchronize()
    t0 = time.time()
    total_loss_accum = 0.0
    ce_loss_accum = 0.0
    aux_loss_contrib_accum = 0.0
    router_z_loss_contrib_accum = 0.0
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            logits, total_loss = model(x, y) # returns (logits, loss)
            # Cross-entropy only (language modeling objective)
            ce_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                y.view(-1),
                ignore_index=-1,
            )
        # Cache logging values (average across micro-steps)
        total_loss_accum += float(total_loss.detach().item())
        ce_loss_accum += float(ce_loss.detach().item())
        aux_sum = getattr(MANAGER, "last_aux_loss_sum", 0.0)
        z_sum = getattr(MANAGER, "last_router_z_loss_sum", 0.0)
        # Convert sums into the *weighted* contribution that is actually added to total_loss
        if getattr(model.config, "n_exp", 1) > 1 and getattr(model.config, "use_aux_loss", False):
            if torch.is_tensor(aux_sum):
                aux_loss_contrib_accum += float(getattr(model.config, "aux_loss_weight", 0.0)) * aux_sum.detach().item()
            else:
                aux_loss_contrib_accum += float(getattr(model.config, "aux_loss_weight", 0.0)) * float(aux_sum)
        if getattr(model.config, "n_exp", 1) > 1 and getattr(model.config, "use_router_z_loss", False):
            if torch.is_tensor(z_sum):
                router_z_loss_contrib_accum += float(getattr(model.config, "router_z_loss_weight", 0.0)) * z_sum.detach().item()
            else:
                router_z_loss_contrib_accum += float(getattr(model.config, "router_z_loss_weight", 0.0)) * float(z_sum)

        loss = total_loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
        progress = max(progress, approx_progress) # only increase progress monotonically

    # micro-step averages for logging
    train_total_loss = total_loss_accum / grad_accum_steps
    train_ce_loss = ce_loss_accum / grad_accum_steps
    train_aux_loss_contrib = aux_loss_contrib_accum / grad_accum_steps
    train_router_z_loss_contrib = router_z_loss_contrib_accum / grad_accum_steps
    # step the optimizer(s)
    lrm = get_lr_multiplier(progress)
    current_lr = learning_rate * init_lr_frac * lrm
    for group in adamw_optimizer.param_groups:
        group["lr"] = current_lr
    adamw_optimizer.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

    # State
    step += 1

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_total_loss # EMA the total loss
    smooth_train_ce_loss = ema_beta * smooth_train_ce_loss + (1 - ema_beta) * train_ce_loss # EMA the CE loss
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
    debiased_smooth_ce_loss = smooth_train_ce_loss / (1 - ema_beta**(step + 1)) # debias the EMA
    pct_done = 100 * progress
    tok_per_sec = int(total_batch_size / dt)
    # flops_per_sec = num_flops_per_token * total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
    # mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    print0(
        f"step {step:05d} ({pct_done:.2f}%) | "
        f"loss: {debiased_smooth_loss:.6f} | ce: {debiased_smooth_ce_loss:.6f} | "
        f"aux: {train_aux_loss_contrib:.6f} | z: {train_router_z_loss_contrib:.6f} | "
        f"lr: {current_lr:.6g} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | total time: {total_training_time/60:.2f}m"
    )
    if step % 10 == 0:
        wandb_run.log({
            "step": step,
            # "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/ce_loss": debiased_smooth_ce_loss,
            "train/aux_loss_contrib": train_aux_loss_contrib,
            "train/router_z_loss_contrib": train_router_z_loss_contrib,
            "train/lr": current_lr,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            # "train/mfu": mfu,
        })

# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to report
if not dry_run:
    from nanochat.report import get_report
    get_report().log(section="Midtraining", data=[
        user_config, # CLI args
        { # stats about the training setup
            "Number of iterations": step,
            "DDP world size": ddp_world_size,
        },
        { # stats about training outcomes
            "Minimum validation bpb": min_val_bpb,
        }
    ])

# cleanup
wandb_run.finish() # wandb run finish
compute_cleanup()