mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
标注:speedrun.sh - 核心逻辑 / 设计思路 / 代码规范
This commit is contained in:
parent
4a87a0d19f
commit
0d4f0250df
49
speedrun.sh
49
speedrun.sh
|
|
@ -1,58 +1,75 @@
|
|||
#!/bin/bash
|
||||
# 【核心逻辑】nanochat核心启动脚本:一站式完成环境配置、数据下载、模型训练/微调全流程
|
||||
# 【代码规范】遵循Shell脚本规范:首行指定bash解释器,注释统一用#开头,关键步骤用分隔线区分
|
||||
# 【设计定位】低成本复刻ChatGPT:适配8XH100节点,4小时内完成训练,总成本约$100
|
||||
|
||||
# This script is the "Best ChatGPT clone that $100 can buy",
|
||||
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.
|
||||
|
||||
# 【使用场景说明】提供3种启动方式,覆盖不同使用需求
|
||||
# 1) Example launch (simplest):
|
||||
# 【场景1】基础启动:适合短时间测试,无后台运行/日志记录,操作最简
|
||||
# bash speedrun.sh
|
||||
# 2) Example launch in a screen session (because the run takes ~4 hours):
|
||||
# 【场景2】Screen会话启动:解决训练耗时久的问题,断开连接后仍能后台运行,同时记录日志
|
||||
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
# 3) Example launch with wandb logging, but see below for setting up wandb first:
|
||||
# 【场景3】Wandb日志启动:集成训练可视化工具,需提前配置wandb账号,便于监控训练过程
|
||||
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
|
||||
# Default intermediate artifacts directory is in ~/.cache/nanochat
|
||||
# 【环境配置】设置OMP线程数为1,避免多线程冲突;指定中间产物缓存目录,避免重复下载
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
# 【容错处理】创建缓存目录:若目录不存在则自动创建,防止后续步骤因目录缺失失败
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 【模块划分】Python虚拟环境配置模块:使用uv工具管理依赖,保证环境隔离与版本一致性
|
||||
# Python venv setup with uv
|
||||
|
||||
# install uv (if not already installed)
|
||||
# 【自动化依赖安装】检查uv是否安装:未安装则自动下载安装,降低环境配置门槛
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# create a .venv local virtual environment (if it doesn't exist)
|
||||
# 【环境隔离】创建虚拟环境:仅当.venv目录不存在时创建,避免重复操作
|
||||
[ -d ".venv" ] || uv venv
|
||||
# install the repo dependencies
|
||||
# 【依赖管理】安装GPU版本依赖:通过--extra gpu指定GPU相关依赖,适配训练硬件
|
||||
uv sync --extra gpu
|
||||
# activate venv so that `python` uses the project's venv instead of system python
|
||||
# 【环境激活】激活虚拟环境:确保后续Python命令使用项目专属环境,避免系统环境污染
|
||||
source .venv/bin/activate
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 【模块划分】Wandb日志配置模块:可选功能,支持训练过程可视化监控
|
||||
# wandb setup
|
||||
# If you wish to use wandb for logging (it's nice!, recommended).
|
||||
# 1) Make sure to first log in to wandb, e.g. run:
|
||||
# `wandb login`
|
||||
# 2) Set the WANDB_RUN environment variable when running this script, e.g.:
|
||||
# `WANDB_RUN=d26 bash speedrun.sh`
|
||||
# 【容错处理】默认日志模式:未指定WANDB_RUN时,设置为dummy模式跳过日志记录,避免启动失败
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
# by default use "dummy" : it's handled as a special case, skips logging to wandb
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 【模块划分】训练报告初始化模块:清空历史报告,记录启动信息与系统参数
|
||||
# During the course of the run, we will be writing markdown reports to the report/
|
||||
# directory in the base dir. This command clears it out and writes a header section
|
||||
# with a bunch of system info and a timestamp that marks the start of the run.
|
||||
# 【报告管理】重置报告:清空报告目录,写入启动时间戳和系统信息,为后续报告生成做准备
|
||||
python -m nanochat.report reset
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 【模块划分】Tokenizer(分词器)配置模块:安装依赖、训练分词器、下载预处理数据
|
||||
# Tokenizer
|
||||
|
||||
# Install Rust / Cargo
|
||||
# 【依赖安装】安装Rust环境:分词器基于Rust开发,需先配置编译环境
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
|
||||
# Build the rustbpe Tokenizer
|
||||
# 【编译优化】编译Rust版分词器:使用maturin编译release版本,提升分词性能
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
|
||||
# Download the first ~2B characters of pretraining dataset
|
||||
|
|
@ -60,17 +77,22 @@ uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
|||
# each data shard is ~250M chars
|
||||
# so we download 2e9 / 250e6 = 8 data shards at this point
|
||||
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
|
||||
# 【数据准备】下载基础预训练数据:先下载8个分片(约20亿字符),满足分词器训练需求
|
||||
python -m nanochat.dataset -n 8
|
||||
# Immediately also kick off downloading more shards in the background while tokenizer trains
|
||||
# See comment below for why 240 is the right number here
|
||||
# 【性能优化】后台下载更多数据:分词器训练时异步下载240个分片,提升整体流程效率
|
||||
python -m nanochat.dataset -n 240 &
|
||||
DATASET_DOWNLOAD_PID=$!
|
||||
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
|
||||
# 【核心逻辑】训练分词器:基于20亿字符数据,训练词汇量为65536的分词器
|
||||
python -m scripts.tok_train --max_chars=2000000000
|
||||
# evaluate the tokenizer (report compression ratio etc.)
|
||||
# 【效果验证】评估分词器:输出压缩率等指标,验证分词器效果
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 【模块划分】基础模型预训练模块:基于下载的数据集训练561M参数的基础模型
|
||||
# Base model (pretraining)
|
||||
|
||||
# The d20 model is 561M parameters.
|
||||
|
|
@ -79,53 +101,72 @@ python -m scripts.tok_eval
|
|||
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
|
||||
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
|
||||
# (The total number of shards available in the entire dataset is 1822.)
|
||||
# 【数据校验】等待数据集下载完成:确保240个分片全部下载,满足预训练数据需求
|
||||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# Number of processes/GPUs to use
|
||||
# 【硬件适配】设置GPU数量:指定8个GPU并行训练,适配8XH100节点
|
||||
NPROC_PER_NODE=8
|
||||
|
||||
# pretrain the d20 model
|
||||
# 【核心逻辑】预训练d20模型:基于8卡并行,训练561M参数的基础模型
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
# 【效果验证】评估模型损失:在更多训练/验证数据上评估,输出样本结果
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
# evaluate the model on CORE tasks
|
||||
# 【效果验证】CORE任务评估:在标准CORE任务上验证模型性能
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 【模块划分】中期训练模块:教会模型对话特殊令牌、工具使用、多选任务
|
||||
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
||||
|
||||
# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
|
||||
# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
|
||||
# 【数据准备】下载对话数据:获取2.3MB合成身份对话数据,赋予模型对话人格
|
||||
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
|
||||
# run midtraining and eval the model
|
||||
# 【核心逻辑】执行中期训练:训练模型掌握对话特殊令牌、工具使用等能力
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
|
||||
# 【效果验证】评估中期训练效果:验证对话能力是否达标
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 【模块划分】监督微调模块:针对单条序列做领域适配,提升模型对话效果
|
||||
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
|
||||
|
||||
# train sft and re-eval right away (should see a small bump)
|
||||
# 【核心逻辑】执行监督微调:进一步适配对话场景,提升模型效果
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||
# 【效果验证】评估微调效果:验证微调后模型性能是否提升
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
|
||||
|
||||
# chat with the model over CLI! Leave out the -p to chat interactively
|
||||
# 【功能体验】CLI交互对话:支持命令行与模型交互,-p参数指定预设问题
|
||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
||||
# even better, chat with your model over a pretty WebUI ChatGPT style
|
||||
# 【功能体验】WebUI交互对话:提供ChatGPT风格的可视化界面,更友好的交互体验
|
||||
# python -m scripts.chat_web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 【模块划分】强化学习模块(可选):针对GSM8K任务优化模型
|
||||
# Reinforcement Learning. Optional, and currently only on GSM8K
|
||||
# (optional)
|
||||
|
||||
# run reinforcement learning
|
||||
# 【可选逻辑】执行强化学习:针对GSM8K数学任务优化模型(可选步骤)
|
||||
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
|
||||
# eval the RL model only on GSM8K
|
||||
# 【可选验证】评估强化学习效果:仅验证GSM8K任务上的性能提升
|
||||
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 【模块划分】报告生成模块:整合所有训练阶段结果,生成完整报告
|
||||
# Generate the full report by putting together all the sections
|
||||
# report.md is the output and will be copied to current directory for convenience
|
||||
# 【结果汇总】生成完整报告:整合各阶段日志/评估结果,输出report.md到当前目录
|
||||
python -m nanochat.report generate
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user