mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-22 21:03:22 +00:00
support custom training data, train tokenizer
This commit is contained in:
parent
15e7a22a41
commit
2085e6637a
|
|
@ -28,7 +28,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
|||
total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
|
||||
total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
|
||||
batch_iter = iter(batches)
|
||||
for _ in range(steps):
|
||||
for step in range(steps):
|
||||
x, y = next(batch_iter)
|
||||
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
||||
loss2d = loss2d.view(-1) # flatten
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ print_banner()
|
|||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# Runtime
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
|
||||
# Data
|
||||
data_dir = "" # path to directory containing parquet files with 'text' column (empty string = use default: ~/.cache/nanochat/base_data)
|
||||
# Model architecture
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
|
|
@ -43,7 +45,7 @@ target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for
|
|||
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
# Optimization
|
||||
device_batch_size = 32 # per-device batch size (set to not OOM)
|
||||
total_batch_size = 524288 # 524288 # total desired batch size, in #tokens
|
||||
total_batch_size = 524288 # 2097152 #1048576 # 524288 # total desired batch size, in #tokens
|
||||
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
|
||||
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
|
||||
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
|
||||
|
|
@ -56,7 +58,7 @@ core_metric_every = 2000 # every how many steps to evaluate the core metric (-1
|
|||
core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
||||
sample_every = 2000 # every how many steps to sample from the model
|
||||
# Output
|
||||
model_tag = "" # optionally override the model tag for the output checkpoint directory name
|
||||
model_tag = run # optionally override the model tag for the output checkpoint directory name
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
|
|
@ -139,12 +141,13 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
|||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
# Initialize the DataLoaders for train/val
|
||||
base_dir = get_base_dir()
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
|
||||
# Use custom data_dir if provided, otherwise use default
|
||||
custom_data_dir = data_dir if data_dir else None
|
||||
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device, data_dir=custom_data_dir)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device, data_dir="/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/.cache/base_data") # SHIZHE: always use the default val data dir from FineWeb by Andrej Karpathy
|
||||
x, y = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -17,10 +17,14 @@ parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
|
|||
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
|
||||
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
|
||||
parser.add_argument('--data_dir', type=str, default=None, help='Custom dataset directory (default: None, uses default dataset)')
|
||||
parser.add_argument('--tokenizer_name', type=str, default='tokenizer', help='Name for the tokenizer subdirectory (default: tokenizer)')
|
||||
args = parser.parse_args()
|
||||
print(f"max_chars: {args.max_chars:,}")
|
||||
print(f"doc_cap: {args.doc_cap:,}")
|
||||
print(f"vocab_size: {args.vocab_size:,}")
|
||||
print(f"data_dir: {args.data_dir}")
|
||||
print(f"tokenizer_name: {args.tokenizer_name}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Text iterator
|
||||
|
|
@ -32,7 +36,7 @@ def text_iterator():
|
|||
3) Break when we've seen args.max_chars characters
|
||||
"""
|
||||
nchars = 0
|
||||
for batch in parquets_iter_batched(split="train"):
|
||||
for batch in parquets_iter_batched(split="train", data_dir=args.data_dir):
|
||||
for doc in batch:
|
||||
doc_text = doc
|
||||
if len(doc_text) > args.doc_cap:
|
||||
|
|
@ -54,8 +58,9 @@ print(f"Training time: {train_time:.2f}s")
|
|||
# -----------------------------------------------------------------------------
|
||||
# Save the tokenizer to disk
|
||||
base_dir = get_base_dir()
|
||||
tokenizer_dir = os.path.join(base_dir, "tokenizer")
|
||||
tokenizer_dir = os.path.join(base_dir, args.tokenizer_name)
|
||||
tokenizer.save(tokenizer_dir)
|
||||
print(f"Saved tokenizer to {tokenizer_dir}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Quick inline sanity check
|
||||
|
|
|
|||
10
speedrun.sh
10
speedrun.sh
|
|
@ -9,6 +9,10 @@
|
|||
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
# 3) Example launch with wandb logging, but see below for setting up wandb first:
|
||||
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
set -x
|
||||
|
||||
DATA_NAME=smollm
|
||||
DATA_DIR=/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/data/$DATA_NAME
|
||||
|
||||
# Default intermediate artifacts directory is in ~/.cache/nanochat
|
||||
export OMP_NUM_THREADS=1
|
||||
|
|
@ -39,7 +43,7 @@ source .venv/bin/activate
|
|||
# WANDB_RUN=dummy
|
||||
# fi
|
||||
export WANDB_API_KEY="ec7a9c0701d404122e4fc5c7c7518ed17f5b03ca"
|
||||
export WANDB_RUN=fineweb_d20
|
||||
export WANDB_RUN=fineweb_d20_test
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# During the course of the run, we will be writing markdown reports to the report/
|
||||
|
|
@ -94,9 +98,9 @@ echo "Waiting for dataset download to complete..."
|
|||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# pretrain the d20 model
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN --data_dir=$DATA_DIR
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss --data_dir=$DATA_DIR
|
||||
# evaluate the model on CORE tasks
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
#!/bin/bash
|
||||
#SBATCH --account nvr_lpr_llm
|
||||
#SBATCH --partition batch_short,batch_block1,backfill
|
||||
#SBATCH --job-name=nanochat_1node_fineweb_d20
|
||||
#SBATCH --nodes=2
|
||||
#SBATCH --partition interactive,batch_short,batch_block1,backfill
|
||||
#SBATCH --job-name=nanochat_multinode_d20
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --time=02:00:00
|
||||
#SBATCH --output=nanochat_1node_fineweb_d20-%j.out
|
||||
#SBATCH --time=04:00:00
|
||||
#SBATCH --output=logs/nanochat_1node_d20-%j.out
|
||||
#SBATCH --mem=0
|
||||
#SBATCH --exclusive
|
||||
|
||||
|
|
@ -22,6 +22,10 @@
|
|||
|
||||
set -x # Enable debug output
|
||||
|
||||
DATA_NAME=climbmix
|
||||
export DATA_DIR=/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/data/$DATA_NAME
|
||||
export MATRIX_LR=0.02
|
||||
|
||||
# Default intermediate artifacts directory is in ~/.cache/nanochat
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/nanochat_cache"
|
||||
|
|
@ -72,7 +76,7 @@ uv sync --active
|
|||
# WANDB_RUN=dummy
|
||||
# fi
|
||||
export WANDB_API_KEY="ec7a9c0701d404122e4fc5c7c7518ed17f5b03ca"
|
||||
export WANDB_RUN=fineweb_d20_1node_$SLURM_JOB_ID
|
||||
export WANDB_RUN=${DATA_NAME}_d20_1node_matrixlr${MATRIX_LR}_${SLURM_JOB_ID}
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# During the course of the run, we will be writing markdown reports to the report/
|
||||
|
|
@ -100,15 +104,17 @@ maturin develop --release --manifest-path rustbpe/Cargo.toml
|
|||
# each data shard is ~250M chars
|
||||
# so we download 2e9 / 250e6 = 8 data shards at this point
|
||||
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
|
||||
python -m nanochat.dataset -n 8
|
||||
python -m nanochat.dataset -n 8 --data_dir=$DATA_DIR
|
||||
# Immediately also kick off downloading more shards in the background while tokenizer trains
|
||||
# See comment below for why 240 is the right number here
|
||||
python -m nanochat.dataset -n 240 &
|
||||
python -m nanochat.dataset -n 240 --data_dir=$DATA_DIR &
|
||||
DATASET_DOWNLOAD_PID=$!
|
||||
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
|
||||
python -m scripts.tok_train --max_chars=2000000000
|
||||
# Use unique tokenizer name based on dataset
|
||||
TOKENIZER_NAME="tokenizer_${DATA_NAME}"
|
||||
python -m scripts.tok_train --max_chars=2000000000 --data_dir=$DATA_DIR --tokenizer_name=$TOKENIZER_NAME
|
||||
# evaluate the tokenizer (report compression ratio etc.)
|
||||
python -m scripts.tok_eval
|
||||
python -m scripts.tok_eval --tokenizer_name=$TOKENIZER_NAME
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Base model (pretraining)
|
||||
|
|
@ -135,25 +141,25 @@ wait $DATASET_DOWNLOAD_PID
|
|||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; python -c "import torch; print(torch.cuda.device_count())"'
|
||||
|
||||
# pretrain the d20 model (multi-node)
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.base_train -- --depth=20 --run=$WANDB_RUN'
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_train -- --depth=20 --run=$WANDB_RUN --data_dir=$DATA_DIR --matrix_lr=$MATRIX_LR'
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples (multi-node)
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.base_loss'
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_loss --data_dir=$DATA_DIR'
|
||||
# evaluate the model on CORE tasks (multi-node)
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.base_eval'
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_eval'
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
||||
|
||||
# run midtraining and eval the model (multi-node)
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.mid_train -- --run=$WANDB_RUN'
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i mid'
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.mid_train -- --run=$WANDB_RUN'
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i mid'
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
|
||||
|
||||
# train sft and re-eval right away (should see a small bump) (multi-node)
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.chat_sft -- --run=$WANDB_RUN'
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i sft'
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_sft -- --run=$WANDB_RUN'
|
||||
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i sft'
|
||||
|
||||
# chat with the model over CLI! Leave out the -p to chat interactively
|
||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user