nanochat/speedrun.sh

173 lines
10 KiB
Bash
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/bin/bash
# 【核心逻辑】nanochat核心启动脚本一站式完成环境配置、数据下载、模型训练/微调全流程
# 【代码规范】遵循Shell脚本规范首行指定bash解释器注释统一用#开头,关键步骤用分隔线区分
# 【设计定位】低成本复刻ChatGPT适配8XH100节点4小时内完成训练总成本约$100
# 【使用场景说明】提供3种启动方式覆盖不同使用需求
# 1) Example launch (simplest):
# 【场景1】基础启动适合短时间测试无后台运行/日志记录,操作最简
# bash speedrun.sh
# 2) Example launch in a screen session (because the run takes ~4 hours):
# 【场景2】Screen会话启动解决训练耗时久的问题断开连接后仍能后台运行同时记录日志
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# 3) Example launch with wandb logging, but see below for setting up wandb first:
# 【场景3】Wandb日志启动集成训练可视化工具需提前配置wandb账号便于监控训练过程
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# 【环境配置】设置OMP线程数为1避免多线程冲突指定中间产物缓存目录避免重复下载
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
# 【容错处理】创建缓存目录:若目录不存在则自动创建,防止后续步骤因目录缺失失败
mkdir -p $NANOCHAT_BASE_DIR
# -----------------------------------------------------------------------------
# 【模块划分】Python虚拟环境配置模块使用uv工具管理依赖保证环境隔离与版本一致性
# Python venv setup with uv
# install uv (if not already installed)
# 【自动化依赖安装】检查uv是否安装未安装则自动下载安装降低环境配置门槛
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# create a .venv local virtual environment (if it doesn't exist)
# 【环境隔离】创建虚拟环境:仅当.venv目录不存在时创建避免重复操作
[ -d ".venv" ] || uv venv
# install the repo dependencies
# 【依赖管理】安装GPU版本依赖通过--extra gpu指定GPU相关依赖适配训练硬件
uv sync --extra gpu
# activate venv so that `python` uses the project's venv instead of system python
# 【环境激活】激活虚拟环境确保后续Python命令使用项目专属环境避免系统环境污染
source .venv/bin/activate
# -----------------------------------------------------------------------------
# 【模块划分】Wandb日志配置模块可选功能支持训练过程可视化监控
# wandb setup
# If you wish to use wandb for logging (it's nice!, recommended).
# 1) Make sure to first log in to wandb, e.g. run:
# `wandb login`
# 2) Set the WANDB_RUN environment variable when running this script, e.g.:
# `WANDB_RUN=d26 bash speedrun.sh`
# 【容错处理】默认日志模式未指定WANDB_RUN时设置为dummy模式跳过日志记录避免启动失败
if [ -z "$WANDB_RUN" ]; then
# by default use "dummy" : it's handled as a special case, skips logging to wandb
WANDB_RUN=dummy
fi
# -----------------------------------------------------------------------------
# 【模块划分】训练报告初始化模块:清空历史报告,记录启动信息与系统参数
# During the course of the run, we will be writing markdown reports to the report/
# directory in the base dir. This command clears it out and writes a header section
# with a bunch of system info and a timestamp that marks the start of the run.
# 【报告管理】重置报告:清空报告目录,写入启动时间戳和系统信息,为后续报告生成做准备
python -m nanochat.report reset
# -----------------------------------------------------------------------------
# 【模块划分】Tokenizer分词器配置模块安装依赖、训练分词器、下载预处理数据
# Tokenizer
# Install Rust / Cargo
# 【依赖安装】安装Rust环境分词器基于Rust开发需先配置编译环境
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
# Build the rustbpe Tokenizer
# 【编译优化】编译Rust版分词器使用maturin编译release版本提升分词性能
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
# Download the first ~2B characters of pretraining dataset
# look at dev/repackage_data_reference.py for details on how this data was prepared
# each data shard is ~250M chars
# so we download 2e9 / 250e6 = 8 data shards at this point
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
# 【数据准备】下载基础预训练数据先下载8个分片约20亿字符满足分词器训练需求
python -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 240 is the right number here
# 【性能优化】后台下载更多数据分词器训练时异步下载240个分片提升整体流程效率
python -m nanochat.dataset -n 240 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
# 【核心逻辑】训练分词器基于20亿字符数据训练词汇量为65536的分词器
python -m scripts.tok_train --max_chars=2000000000
# evaluate the tokenizer (report compression ratio etc.)
# 【效果验证】评估分词器:输出压缩率等指标,验证分词器效果
python -m scripts.tok_eval
# -----------------------------------------------------------------------------
# 【模块划分】基础模型预训练模块基于下载的数据集训练561M参数的基础模型
# Base model (pretraining)
# The d20 model is 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
# (The total number of shards available in the entire dataset is 1822.)
# 【数据校验】等待数据集下载完成确保240个分片全部下载满足预训练数据需求
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
# Number of processes/GPUs to use
# 【硬件适配】设置GPU数量指定8个GPU并行训练适配8XH100节点
NPROC_PER_NODE=8
# pretrain the d20 model
# 【核心逻辑】预训练d20模型基于8卡并行训练561M参数的基础模型
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
# 【效果验证】评估模型损失:在更多训练/验证数据上评估,输出样本结果
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks
# 【效果验证】CORE任务评估在标准CORE任务上验证模型性能
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# -----------------------------------------------------------------------------
# 【模块划分】中期训练模块:教会模型对话特殊令牌、工具使用、多选任务
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
# 【数据准备】下载对话数据获取2.3MB合成身份对话数据,赋予模型对话人格
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# run midtraining and eval the model
# 【核心逻辑】执行中期训练:训练模型掌握对话特殊令牌、工具使用等能力
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
# 【效果验证】评估中期训练效果:验证对话能力是否达标
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# -----------------------------------------------------------------------------
# 【模块划分】监督微调模块:针对单条序列做领域适配,提升模型对话效果
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
# train sft and re-eval right away (should see a small bump)
# 【核心逻辑】执行监督微调:进一步适配对话场景,提升模型效果
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
# 【效果验证】评估微调效果:验证微调后模型性能是否提升
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
# chat with the model over CLI! Leave out the -p to chat interactively
# 【功能体验】CLI交互对话支持命令行与模型交互-p参数指定预设问题
# python -m scripts.chat_cli -p "Why is the sky blue?"
# even better, chat with your model over a pretty WebUI ChatGPT style
# 【功能体验】WebUI交互对话提供ChatGPT风格的可视化界面更友好的交互体验
# python -m scripts.chat_web
# -----------------------------------------------------------------------------
# 【模块划分】强化学习模块可选针对GSM8K任务优化模型
# Reinforcement Learning. Optional, and currently only on GSM8K
# (optional)
# run reinforcement learning
# 【可选逻辑】执行强化学习针对GSM8K数学任务优化模型可选步骤
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K
# 【可选验证】评估强化学习效果仅验证GSM8K任务上的性能提升
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
# -----------------------------------------------------------------------------
# 【模块划分】报告生成模块:整合所有训练阶段结果,生成完整报告
# Generate the full report by putting together all the sections
# report.md is the output and will be copied to current directory for convenience
# 【结果汇总】生成完整报告:整合各阶段日志/评估结果输出report.md到当前目录
python -m nanochat.report generate