Support custom training data and tokenizer training on it

This commit is contained in:
Shizhe Diao 2025-10-19 07:55:41 -07:00
parent 15e7a22a41
commit 2085e6637a
5 changed files with 46 additions and 28 deletions

View File

@ -28,7 +28,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
batch_iter = iter(batches)
for _ in range(steps):
for step in range(steps):
x, y = next(batch_iter)
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten

View File

@ -34,6 +34,8 @@ print_banner()
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
# Data
data_dir = "" # path to directory containing parquet files with 'text' column (empty string = use default: ~/.cache/nanochat/base_data)
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
@ -43,7 +45,7 @@ target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
# Optimization
device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # 524288 # total desired batch size, in #tokens
total_batch_size = 524288 # 2097152 #1048576 # 524288 # total desired batch size, in #tokens
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
@ -56,7 +58,7 @@ core_metric_every = 2000 # every how many steps to evaluate the core metric (-1
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
model_tag = run # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
@ -139,12 +141,13 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
# Use custom data_dir if provided, otherwise use default
custom_data_dir = data_dir if data_dir else None
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device, data_dir=custom_data_dir)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device, data_dir="/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/.cache/base_data") # SHIZHE: always use the default val data dir from FineWeb by Andrej Karpathy
x, y = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------

View File

@ -17,10 +17,14 @@ parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
parser.add_argument('--data_dir', type=str, default=None, help='Custom dataset directory (default: None, uses default dataset)')
parser.add_argument('--tokenizer_name', type=str, default='tokenizer', help='Name for the tokenizer subdirectory (default: tokenizer)')
args = parser.parse_args()
print(f"max_chars: {args.max_chars:,}")
print(f"doc_cap: {args.doc_cap:,}")
print(f"vocab_size: {args.vocab_size:,}")
print(f"data_dir: {args.data_dir}")
print(f"tokenizer_name: {args.tokenizer_name}")
# -----------------------------------------------------------------------------
# Text iterator
@ -32,7 +36,7 @@ def text_iterator():
3) Break when we've seen args.max_chars characters
"""
nchars = 0
for batch in parquets_iter_batched(split="train"):
for batch in parquets_iter_batched(split="train", data_dir=args.data_dir):
for doc in batch:
doc_text = doc
if len(doc_text) > args.doc_cap:
@ -54,8 +58,9 @@ print(f"Training time: {train_time:.2f}s")
# -----------------------------------------------------------------------------
# Save the tokenizer to disk
base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer")
tokenizer_dir = os.path.join(base_dir, args.tokenizer_name)
tokenizer.save(tokenizer_dir)
print(f"Saved tokenizer to {tokenizer_dir}")
# -----------------------------------------------------------------------------
# Quick inline sanity check

View File

@ -9,6 +9,10 @@
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# 3) Example launch with wandb logging, but see below for setting up wandb first:
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
set -x
DATA_NAME=smollm
DATA_DIR=/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/data/$DATA_NAME
# Default intermediate artifacts directory is in ~/.cache/nanochat
export OMP_NUM_THREADS=1
@ -39,7 +43,7 @@ source .venv/bin/activate
# WANDB_RUN=dummy
# fi
export WANDB_API_KEY="${WANDB_API_KEY:?set WANDB_API_KEY in your environment}"  # SECURITY: never commit a real API key — the previously committed key must be revoked/rotated
export WANDB_RUN=fineweb_d20
export WANDB_RUN=fineweb_d20_test
# -----------------------------------------------------------------------------
# During the course of the run, we will be writing markdown reports to the report/
@ -94,9 +98,9 @@ echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
# pretrain the d20 model
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN --data_dir=$DATA_DIR
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss --data_dir=$DATA_DIR
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval

View File

@ -1,12 +1,12 @@
#!/bin/bash
#SBATCH --account nvr_lpr_llm
#SBATCH --partition batch_short,batch_block1,backfill
#SBATCH --job-name=nanochat_1node_fineweb_d20
#SBATCH --nodes=2
#SBATCH --partition interactive,batch_short,batch_block1,backfill
#SBATCH --job-name=nanochat_multinode_d20
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8
#SBATCH --time=02:00:00
#SBATCH --output=nanochat_1node_fineweb_d20-%j.out
#SBATCH --time=04:00:00
#SBATCH --output=logs/nanochat_1node_d20-%j.out
#SBATCH --mem=0
#SBATCH --exclusive
@ -22,6 +22,10 @@
set -x # Enable debug output
DATA_NAME=climbmix
export DATA_DIR=/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/data/$DATA_NAME
export MATRIX_LR=0.02
# Default intermediate artifacts directory is in ~/.cache/nanochat
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/nanochat_cache"
@ -72,7 +76,7 @@ uv sync --active
# WANDB_RUN=dummy
# fi
export WANDB_API_KEY="${WANDB_API_KEY:?set WANDB_API_KEY in your environment}"  # SECURITY: never commit a real API key — the previously committed key must be revoked/rotated
export WANDB_RUN=fineweb_d20_1node_$SLURM_JOB_ID
export WANDB_RUN=${DATA_NAME}_d20_1node_matrixlr${MATRIX_LR}_${SLURM_JOB_ID}
# -----------------------------------------------------------------------------
# During the course of the run, we will be writing markdown reports to the report/
@ -100,15 +104,17 @@ maturin develop --release --manifest-path rustbpe/Cargo.toml
# each data shard is ~250M chars
# so we download 2e9 / 250e6 = 8 data shards at this point
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
python -m nanochat.dataset -n 8
python -m nanochat.dataset -n 8 --data_dir=$DATA_DIR
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 240 is the right number here
python -m nanochat.dataset -n 240 &
python -m nanochat.dataset -n 240 --data_dir=$DATA_DIR &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max_chars=2000000000
# Use unique tokenizer name based on dataset
TOKENIZER_NAME="tokenizer_${DATA_NAME}"
python -m scripts.tok_train --max_chars=2000000000 --data_dir=$DATA_DIR --tokenizer_name=$TOKENIZER_NAME
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval
python -m scripts.tok_eval --tokenizer_name=$TOKENIZER_NAME
# -----------------------------------------------------------------------------
# Base model (pretraining)
@ -135,25 +141,25 @@ wait $DATASET_DOWNLOAD_PID
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; python -c "import torch; print(torch.cuda.device_count())"'
# pretrain the d20 model (multi-node)
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.base_train -- --depth=20 --run=$WANDB_RUN'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_train -- --depth=20 --run=$WANDB_RUN --data_dir=$DATA_DIR --matrix_lr=$MATRIX_LR'
# evaluate the model on a larger chunk of train/val data and draw some samples (multi-node)
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.base_loss'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_loss --data_dir=$DATA_DIR'
# evaluate the model on CORE tasks (multi-node)
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.base_eval'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_eval'
# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
# run midtraining and eval the model (multi-node)
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.mid_train -- --run=$WANDB_RUN'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i mid'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.mid_train -- --run=$WANDB_RUN'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i mid'
# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
# train sft and re-eval right away (should see a small bump) (multi-node)
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.chat_sft -- --run=$WANDB_RUN'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i sft'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_sft -- --run=$WANDB_RUN'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i sft'
# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"