ready to run

This commit is contained in:
Muheng 2026-01-08 13:34:34 +00:00
parent 8f1378235e
commit 9196ff6fc0
3 changed files with 43 additions and 77 deletions

View File

@@ -20,6 +20,14 @@ def log0(message):
if int(os.environ.get('RANK', 0)) == 0:
logger.info(message)
def _optimizer_path(checkpoint_dir, step):
return os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
def _legacy_optimizer_path(checkpoint_dir, step, rank):
return os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
@@ -32,11 +40,11 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
logger.info(f"Saved metadata to: {meta_path}")
# Note that optimizer state is sharded across ranks, so each rank must save its own.
if optimizer_data is not None:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
torch.save(optimizer_data, optimizer_path)
logger.info(f"Saved optimizer state to: {optimizer_path}")
# Save optimizer state once per step (non-sharded optimizer).
if optimizer_data is not None:
optimizer_path = _optimizer_path(checkpoint_dir, step)
torch.save(optimizer_data, optimizer_path)
logger.info(f"Saved optimizer state to: {optimizer_path}")
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
# Load the model state
@@ -45,7 +53,11 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
# Load the optimizer state if requested
optimizer_data = None
if load_optimizer:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
optimizer_path = _optimizer_path(checkpoint_dir, step)
if not os.path.exists(optimizer_path):
optimizer_path = _legacy_optimizer_path(checkpoint_dir, step, rank)
if rank != 0 and not os.path.exists(optimizer_path):
optimizer_path = _legacy_optimizer_path(checkpoint_dir, step, 0)
optimizer_data = torch.load(optimizer_path, map_location=device)
# Load the metadata
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
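To make the fallback order concrete, here is a standalone sketch of the load-side resolution implemented above (resolve_optimizer_path is a hypothetical name, not part of this commit): prefer the new non-sharded file, then the calling rank's legacy shard, then rank 0's legacy shard for non-zero ranks.

import os

def resolve_optimizer_path(checkpoint_dir, step, rank):
    # 1) Prefer the new non-sharded file written by save_checkpoint above
    path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
    if not os.path.exists(path):
        # 2) Fall back to this rank's legacy per-rank shard
        path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
    if rank != 0 and not os.path.exists(path):
        # 3) Last resort for non-zero ranks: reuse rank 0's legacy shard
        path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank0.pt")
    return path

# e.g. step=1000 resolves to optim_001000.pt, else optim_001000_rank<r>.pt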

View File

@@ -39,6 +39,16 @@ from scripts.base_eval import evaluate_model
print_banner()
# Allow env overrides for common LR knobs used in cluster runs.
def _get_env_float(name, default):
val = os.getenv(name)
if val is None or val == "":
return default
try:
return float(val)
except ValueError as exc:
raise ValueError(f"Invalid {name} env value: {val}") from exc
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@@ -75,14 +85,14 @@ embedding_lr = 0.0006 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.0006 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.1 # weight decay (matches nanoMoE weight_decay=1e-1)
matrix_lr = 0.0006 # learning rate for the matrix parameters (Muon)
learning_rate = 6e-4 # learning rate for AdamW optimizer (matches nanoMoE: 6e-4)
learning_rate = _get_env_float("LEARNING_RATE", 6e-4) # learning rate for AdamW optimizer (matches nanoMoE: 6e-4)
betas = (0.9, 0.95) # betas for AdamW optimizer (matches nanoMoE: beta1=0.9, beta2=0.95)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
decay_lr = True # whether to decay the learning rate (matches train_nano_moe.py)
# Learning rate decay parameters (matching train.py and train_nano_moe.py)
warmup_iters = 2000 # how many steps to warm up for (matches train.py default)
lr_decay_iters = 50000 # learning rate decay iterations (matches train_nano_moe.py)
min_lr = 6e-5 # minimum learning rate (matches train.py default, which equals 6e-4 * 0.1)
min_lr = _get_env_float("MIN_LR", 6e-5) # minimum learning rate (matches train.py default, which equals 6e-4 * 0.1)
final_lr_frac = 0.1 # final learning rate as fraction of initial learning rate (for compatibility)
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
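These knobs normally feed a nanoGPT-style warmup-plus-cosine schedule; a hedged sketch is below (the exact get_lr used by this training script may differ in details):

import math

def get_lr(it, learning_rate=6e-4, min_lr=6e-5, warmup_iters=2000, lr_decay_iters=50000):
    if it < warmup_iters:                       # linear warmup up to learning_rate
        return learning_rate * (it + 1) / warmup_iters
    if it > lr_decay_iters:                     # past the decay horizon: hold the floor
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes 1 -> 0
    return min_lr + coeff * (learning_rate - min_lr)        # cosine from max lr to min_lr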
@@ -93,11 +103,11 @@ log_interval = 10 # every how many steps to log training metrics (matches nanoMo
core_metric_every = -1 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = -1 # examples per task in estimating the core metric
sample_every = 200000000 # every how many steps to sample from the model
save_every = 1000 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
save_every = 10000 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
# System
compile = True # use PyTorch 2.0 to compile the model to be faster (matches nanoMoE)
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
model_tag = f"d6_min_lr{min_lr}_max_lr{learning_rate}" # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
@@ -345,26 +355,6 @@ while True:
})
model.train()
# once in a while: sample from the model (only on master process)
# use the original uncompiled model because the inputs keep changing shape
if master_process and (last_step or (step > 0 and step % sample_every == 0)):
model.eval()
prompts = [
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
print0(tokenizer.decode(sample[0]))
model.train()
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
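With save_every raised from 1000 to 10000, this condition checkpoints ten times less often. A quick illustrative check of the cadence (should_save is a hypothetical wrapper around the same expression):

def should_save(step, last_step, resume_from_step=-1, save_every=10000):
    return last_step or (step > 0 and step != resume_from_step
                         and save_every > 0 and step % save_every == 0)

steps = [0, 1000, 9999, 10000, 20000, 25000]
print([s for s in steps if should_save(s, last_step=False)])  # [10000, 20000]
print(should_save(25000, last_step=True))                     # True: always save at the end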

View File

@@ -11,16 +11,18 @@
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# Default intermediate artifacts directory is in ~/.cache/nanochat-moe
USER = "dpq23"
export USER="limh23"
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="/thullms/$USER/.cache/nanochat-moe"
export NANOCHAT_DATA_DIR="/thullms/$USER"
export NANOCHAT_DATA_DIR="/thullms/$USER/.cache/nanochat-moe-data"
mkdir -p $NANOCHAT_BASE_DIR
mkdir -p $NANOCHAT_DATA_DIR
# Use tokenizer from nanochat (not nanochat-moe)
# Create a symlink to nanochat's tokenizer directory if it doesn't exist
NANOCHAT_TOKENIZER_DIR="/thullms/$USER/.cache/nanochat/tokenizer"
NANOCHAT_TOKENIZER_DIR="$HOME/.cache/nanochat/tokenizer"
MOE_TOKENIZER_DIR="$NANOCHAT_BASE_DIR/tokenizer"
if [ -d "$NANOCHAT_TOKENIZER_DIR" ] && [ ! -e "$MOE_TOKENIZER_DIR" ]; then
echo "Creating symlink to nanochat tokenizer: $MOE_TOKENIZER_DIR -> $NANOCHAT_TOKENIZER_DIR"
@@ -93,6 +95,7 @@ fi
# export UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
# uv sync --extra gpu
# # activate venv so that `python` uses the project's venv instead of system python
cd $HOME/nanochat-MoE
source .venv/bin/activate
# # -----------------------------------------------------------------------------
@@ -153,53 +156,14 @@ fi
# echo "Waiting for dataset download to complete..."
# wait $DATASET_DOWNLOAD_PID
MIN_LR=${MIN_LR:-6e-5}
LEARNING_RATE=${LEARNING_RATE:-6e-4}
# Number of processes/GPUs to use
NPROC_PER_NODE=2
NPROC_PER_NODE=8
# Master port for distributed training (default: 29500)
# Set this to avoid port conflicts when running multiple torchrun tasks simultaneously
# Example: MASTER_PORT=29501 bash speedrun.sh
MASTER_PORT=${MASTER_PORT:-29501}
# # # pretrain the d20 model
MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts_moe.base_train -- --depth=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts_moe.base_loss
# evaluate the model on CORE tasks
MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts_moe.base_eval
# # -----------------------------------------------------------------------------
# # Midtraining (teach the model conversation special tokens, tool use, multiple choice)
# # download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# # see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
# curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# # run midtraining and eval the model
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# # -----------------------------------------------------------------------------
# # Supervised Finetuning (domain adaptation to each sequence all by itself per row)
# # train sft and re-eval right away (should see a small bump)
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
# # chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"
# even better, chat with your model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web
# # -----------------------------------------------------------------------------
# # Reinforcement Learning. Optional, and currently only on GSM8K
# # (optional)
# # run reinforcement learning
# MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
# # eval the RL model only on GSM8K
# MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
# # -----------------------------------------------------------------------------
# # Generate the full report by putting together all the sections
# report.md is the output and will be copied to current directory for convenience
python -m nanochat_moe.report generate
MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train >> $NANOCHAT_BASE_DIR/d6_min_lr${MIN_LR}_max_lr${LEARNING_RATE}.log 2>&1
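For reference, a small Python illustration of how the output names line up: base_train's model_tag is built from the parsed floats, while the .log name above is built from the raw shell strings, so the default spellings differ slightly. Overrides passed when invoking the script (e.g. MIN_LR=3e-5 LEARNING_RATE=3e-4 bash speedrun.sh) reach base_train through _get_env_float, assuming they are exported in the calling environment, and change both names.

# Checkpoint dir tag produced by base_train's model_tag f-string:
min_lr, learning_rate = 6e-5, 6e-4
print(f"d6_min_lr{min_lr}_max_lr{learning_rate}")          # d6_min_lr6e-05_max_lr0.0006
# The speedrun.sh log file above uses the raw shell strings instead:
print("d6_min_lr" + "6e-5" + "_max_lr" + "6e-4" + ".log")  # d6_min_lr6e-5_max_lr6e-4.log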