mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
change to allow 24GB VRAM gpu(3090/4090) to run training/eval
This commit is contained in:
parent
d6d86cbf4c
commit
de6597533f
|
|
@ -16,7 +16,7 @@ from nanochat.loss_eval import evaluate_bpb
|
|||
from nanochat.engine import Engine
|
||||
|
||||
# Configuration
|
||||
device_batch_size = 32
|
||||
device_batch_size = 4
|
||||
split_tokens = 20*524288 # number of tokens to evaluate per split
|
||||
model_tag = None # optional model tag for the output directory name
|
||||
model_step = None # optional model step for the output directory name
|
||||
|
|
|
|||
|
|
@ -35,7 +35,8 @@ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable
|
|||
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
# Optimization
|
||||
device_batch_size = 32 # per-device batch size (set to not OOM)
|
||||
device_batch_size = 4 # per-device batch size (set to not OOM)
|
||||
device_batch_size_val = 1 # per-device batch size (set to not OOM)
|
||||
total_batch_size = 524288 # total desired batch size, in #tokens
|
||||
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
|
||||
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
|
||||
|
|
@ -134,7 +135,7 @@ adamw_optimizer, muon_optimizer = optimizers
|
|||
base_dir = get_base_dir()
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size_val, max_seq_len, split="val")
|
||||
x, y = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -208,7 +209,8 @@ for step in range(num_iterations + 1):
|
|||
|
||||
# once in a while: sample from the model (only on master process)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
if master_process and (last_step or (step > 0 and step % sample_every == 0)):
|
||||
# if master_process and (last_step or (step > 0 and step % sample_every == 0)):
|
||||
if master_process and last_step:
|
||||
model.eval()
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
|
|
@ -284,7 +286,7 @@ for step in range(num_iterations + 1):
|
|||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd * grad_accum_steps / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
|
|||
step = None # step to load the model from (base model or midtrained model)
|
||||
# compute/precision
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
device_batch_size = 2 # max to avoid OOM
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
|
|||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
max_seq_len = 2048
|
||||
device_batch_size = 32
|
||||
device_batch_size = 4
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
|
|
@ -248,7 +248,7 @@ while True:
|
|||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * progress
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd * grad_accum_steps / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
|
|
|
|||
133
speedrun_1x3090.sh
Normal file
133
speedrun_1x3090.sh
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
#!/bin/bash
|
||||
|
||||
# This script is the "Best ChatGPT clone that $100 can buy",
|
||||
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.
|
||||
|
||||
# 1) Example launch (simplest):
|
||||
# bash speedrun.sh
|
||||
# 2) Example launch in a screen session (because the run takes ~4 hours):
|
||||
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
# 3) Example launch with wandb logging, but see below for setting up wandb first:
|
||||
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
|
||||
# Default intermediate artifacts directory is in ~/.cache/nanochat
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Python venv setup with uv
|
||||
|
||||
# install uv (if not already installed)
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# create a .venv local virtual environment (if it doesn't exist)
|
||||
[ -d ".venv" ] || uv venv
|
||||
# install the repo dependencies
|
||||
uv sync
|
||||
# activate venv so that `python` uses the project's venv instead of system python
|
||||
source .venv/bin/activate
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# wandb setup
|
||||
# If you wish to use wandb for logging (it's nice!, recommended).
|
||||
# 1) Make sure to first log in to wandb, e.g. run:
|
||||
# `wandb login`
|
||||
# 2) Set the WANDB_RUN environment variable when running this script, e.g.:
|
||||
# `WANDB_RUN=d26 bash speedrun.sh`
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
# by default use "dummy" : it's handled as a special case, skips logging to wandb
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# During the course of the run, we will be writing markdown reports to the report/
|
||||
# directory in the base dir. This command clears it out and writes a header section
|
||||
# with a bunch of system info and a timestamp that marks the start of the run.
|
||||
python -m nanochat.report reset
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tokenizer
|
||||
|
||||
# Install Rust / Cargo
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
|
||||
# Build the rustbpe Tokenizer
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
|
||||
# Download the first ~2B characters of pretraining dataset
|
||||
# look at dev/repackage_data_reference.py for details on how this data was prepared
|
||||
# each data shard is ~250M chars
|
||||
# so we download 2e9 / 250e6 = 8 data shards at this point
|
||||
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
|
||||
python -m nanochat.dataset -n 8
|
||||
# Immediately also kick off downloading more shards in the background while tokenizer trains
|
||||
# See comment below for why 240 is the right number here
|
||||
python -m nanochat.dataset -n 240 &
|
||||
DATASET_DOWNLOAD_PID=$!
|
||||
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
|
||||
python -m scripts.tok_train --max_chars=2000000000
|
||||
# evaluate the tokenizer (report compression ratio etc.)
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Base model (pretraining)
|
||||
|
||||
# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB)
|
||||
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
|
||||
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
|
||||
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
|
||||
unzip -q eval_bundle.zip
|
||||
rm eval_bundle.zip
|
||||
mv eval_bundle $NANOCHAT_BASE_DIR
|
||||
fi
|
||||
|
||||
# The d20 model is 561M parameters.
|
||||
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
|
||||
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
|
||||
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
|
||||
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
|
||||
# (The total number of shards available in the entire dataset is 1822.)
|
||||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# pretrain the d20 model
|
||||
python -m scripts.base_train --depth=20 --run=$WANDB_RUN
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
python -m scripts.base_loss
|
||||
# evaluate the model on CORE tasks
|
||||
python -m scripts.base_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
||||
|
||||
# run midtraining and eval the model
|
||||
python -m scripts.mid_train --run=$WANDB_RUN
|
||||
python -m scripts.chat_eval -i mid
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
|
||||
|
||||
# train sft and re-eval right away (should see a small bump)
|
||||
python -m scripts.chat_sft --run=$WANDB_RUN
|
||||
python -m scripts.chat_eval -i sft
|
||||
|
||||
# chat with the model over CLI! Leave out the -p to chat interactively
|
||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
||||
# even better, chat with your model over a pretty WebUI ChatGPT style
|
||||
# python -m scripts.chat_web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Reinforcement Learning. Optional, and currently only on GSM8K
|
||||
# (optional)
|
||||
|
||||
# run reinforcement learning
|
||||
python -m scripts.chat_rl --run=$WANDB_RUN
|
||||
# eval the RL model only on GSM8K
|
||||
python -m scripts.chat_eval -i rl -a GSM8K
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generate the full report by putting together all the sections
|
||||
# report.md is the output and will be copied to current directory for convenience
|
||||
python -m nanochat.report generate
|
||||
Loading…
Reference in New Issue
Block a user