mirror of https://github.com/karpathy/nanochat.git
synced 2026-03-31 00:55:18 +00:00
update BASE_URL in dataset.py to new modelscope link for data access
This commit is contained in:
parent
c75fe54aa7
commit
c77fbb010b
172 T4_TRAINING_README.md Normal file
@ -0,0 +1,172 @@
# T4 GPU Training Guide

This document explains how to run nanochat training on a server with 4 T4 GPUs.

## Hardware Requirements

- 4 NVIDIA T4 GPUs (16 GB VRAM each)
- At least 64 GB of system memory
- Sufficient storage for the dataset and model checkpoints

## File Overview

### Training scripts

1. **`scripts/t4_train.py`** - base model training script optimized for T4
2. **`scripts/t4_mid_train.py`** - midtraining script optimized for T4
3. **`scripts/t4_chat_sft.py`** - SFT training script optimized for T4

### Launch scripts

1. **`run_t4_training.sh`** - the full T4 training pipeline
2. **`run_t4_quick_test.sh`** - a quick test script for validating the configuration

## T4-Optimized Configuration

### Base model training (t4_train.py)

- **Model depth**: 12 layers (originally 20-32)
- **Sequence length**: 1024 (originally 2048)
- **Device batch size**: 4 (originally 32)
- **Total batch size**: 131,072 tokens (originally 524,288)
- **Evaluation frequency**: every 100 steps (originally every 250)
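The total batch size is met through gradient accumulation. A minimal sketch of the arithmetic, using the base-training defaults above (this mirrors how the training scripts derive their accumulation steps):

```python
# T4 base-training defaults from the table above
device_batch_size = 4
max_seq_len = 1024
num_gpus = 4
total_batch_size = 131_072  # tokens per optimizer step

tokens_per_fwdbwd = device_batch_size * max_seq_len     # tokens per micro-batch per rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * num_gpus  # tokens per micro-batch across ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print(grad_accum_steps)  # 8 micro-batches are accumulated per optimizer step
```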
### Midtraining (t4_mid_train.py)

- **Device batch size**: 2 (originally 32)
- **Total batch size**: 65,536 tokens (originally 524,288)
- **Evaluation frequency**: every 75 steps (originally every 150)

### SFT training (t4_chat_sft.py)

- **Device batch size**: 1 (originally 4)
- **Target examples per step**: 8 (originally 32)
- **Evaluation frequency**: every 50 steps (originally every 100)
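For SFT, accumulation is counted in examples rather than tokens; a sketch of the same derivation `scripts/t4_chat_sft.py` performs with the defaults above:

```python
# T4 SFT defaults from the table above
device_batch_size = 1
num_gpus = 4
target_examples_per_step = 8

examples_per_step = device_batch_size * num_gpus
assert target_examples_per_step % examples_per_step == 0
grad_accum_steps = target_examples_per_step // examples_per_step
print(grad_accum_steps)  # 2 micro-batches per optimizer step
```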
## Usage

### 1. Quick test

First run the quick test to validate the configuration:

```bash
./run_t4_quick_test.sh
```

This runs:
- a depth-8 model
- 100 steps of base training
- 50 steps of midtraining
- 20 steps of SFT training
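Assuming the quick test keeps the 131,072-token total batch size from the table above, the base-training stage consumes roughly:

```python
total_batch_size = 131_072  # tokens per optimizer step (T4 base-training default)
num_iterations = 100        # base-training steps in the quick test
tokens_seen = total_batch_size * num_iterations
print(f"{tokens_seen:,} tokens")  # 13,107,200 tokens
```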
### 2. Full training

If the quick test succeeds, run the full training:

```bash
./run_t4_training.sh
```

This runs:
- a depth-12 model
- the full base training
- the full midtraining
- the full SFT training

### 3. Running individual training stages

You can also run each training stage on its own:

```bash
# Base training
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train

# Midtraining
torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train

# SFT training
torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft
```

## Parameter Tuning

If you run out of GPU memory, you can tune the parameters down further:

### Reduce the batch size
```bash
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --device_batch_size=2
```

### Reduce the model depth
```bash
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --depth=8
```

### Reduce the sequence length
```bash
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --max_seq_len=512
```

## Monitoring Training

### Check GPU status
```bash
nvidia-smi
```

### Read the training logs
During training the scripts print:
- loss
- learning rate
- time per step
- throughput (tokens/sec)
- model FLOPs utilization (MFU)
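MFU compares measured throughput against the hardware's theoretical peak. A rough sketch of the calculation, where the peak-FLOPs figure, the parameter count, and the throughput are illustrative assumptions rather than measurements:

```python
# Illustrative MFU calculation; all numbers here are assumptions, not measurements
num_gpus = 4
peak_flops_per_gpu = 65e12    # assumed T4 fp16 tensor-core peak (~65 TFLOPS)
params = 150e6                # assumed parameter count for a small model
flops_per_token = 6 * params  # ~6N FLOPs per token (forward + backward rule of thumb)
tokens_per_sec = 20_000       # hypothetical measured throughput across all GPUs

mfu = (tokens_per_sec * flops_per_token) / (num_gpus * peak_flops_per_gpu)
print(f"MFU: {mfu:.1%}")
```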
### Wandb logs
If Wandb is configured, detailed training metrics are available in the Wandb UI.

## Troubleshooting

### Out of memory (OOM)
1. Reduce `device_batch_size`
2. Reduce `max_seq_len`
3. Reduce `depth`

### Slow training
1. Check GPU utilization (`nvidia-smi`)
2. Make sure data loading is not the bottleneck
3. Consider raising `eval_every` to evaluate less often

### Distributed training issues
1. Make sure all 4 GPUs are available
2. Check the network connection
3. Make sure the port is not already in use
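Before launching `torchrun`, a quick way to confirm that PyTorch can see all four GPUs:

```python
import torch

# Sanity-check GPU visibility before launching a 4-GPU torchrun job
print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())  # should be 4 on the target server
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))
```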
## Expected Performance

On 4 T4 GPUs:
- **Base training**: roughly 2-4 hours (depending on model size)
- **Midtraining**: roughly 1-2 hours
- **SFT training**: roughly 30 minutes to 1 hour

## Output Files

After training, model checkpoints are saved to:
- base model: `~/.cache/nanochat/t4_checkpoints/`
- midtrained model: `~/.cache/nanochat/t4_mid_checkpoints/`
- SFT model: `~/.cache/nanochat/t4_chatsft_checkpoints/`

## Notes

1. The T4's 16 GB memory limit requires smaller batch sizes
2. Model depth and sequence length have to be scaled down accordingly
3. Training takes longer than on high-end GPUs such as the H100
4. Run the quick test first to validate the configuration

## Support

If you run into problems, check:
1. the GPU driver and CUDA versions
2. PyTorch version compatibility
3. GPU memory usage
4. error messages in the training logs
@ -20,7 +20,7 @@ from nanochat.common import get_base_dir
 # The specifics of the current pretraining dataset

 # The URL on the internet where the data is hosted and downloaded from on demand
-BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
+BASE_URL = "https://www.modelscope.cn/datasets/Thackeray/karpathy-fineweb-edu-100b-shuffle-240shard/resolve/master"
 MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
 index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
 base_dir = get_base_dir()
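For illustration, this is how a shard index maps to a full download URL from the pieces above (`shard_url` is a hypothetical helper for this sketch, not part of `dataset.py`):

```python
# Pieces from nanochat/dataset.py (with the updated modelscope BASE_URL)
BASE_URL = "https://www.modelscope.cn/datasets/Thackeray/karpathy-fineweb-edu-100b-shuffle-240shard/resolve/master"
MAX_SHARD = 1822  # the last datashard is shard_01822.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet"

def shard_url(index: int) -> str:
    # hypothetical helper: join the base URL with a zero-padded shard filename
    assert 0 <= index <= MAX_SHARD, "shard index out of range"
    return f"{BASE_URL}/{index_to_filename(index)}"

print(shard_url(0))          # ends with shard_00000.parquet
print(shard_url(MAX_SHARD))  # ends with shard_01822.parquet
```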
99 run_t4_quick_test.sh Normal file

@ -0,0 +1,99 @@
#!/bin/bash

# T4 GPU quick test script
# Verifies that the T4 configuration works by running a reduced number of training steps

set -e

echo "🧪 Starting T4 GPU quick test..."

# Environment setup
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p "$NANOCHAT_BASE_DIR"

# Install uv if it is not already present
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh

# Set up the virtual environment
[ -d ".venv" ] || uv venv
uv sync --extra gpu
source .venv/bin/activate

# Set the wandb run name
if [ -z "$WANDB_RUN" ]; then
    WANDB_RUN="t4_quick_test_$(date +%Y%m%d_%H%M%S)"
fi
echo "📊 Wandb run name: $WANDB_RUN"

# Reset the report
python -m nanochat.report reset

# Install Rust and build the tokenizer
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml

# Download the evaluation data
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
    echo "📥 Downloading the evaluation bundle..."
    curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
    unzip -q eval_bundle.zip
    rm eval_bundle.zip
    mv eval_bundle "$NANOCHAT_BASE_DIR"
fi

# Download the identity conversation data
curl -L -o "$NANOCHAT_BASE_DIR/identity_conversations.jsonl" https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl

echo "📊 Starting data preparation..."

# Train the tokenizer - with the minimum amount of data
echo "🔤 Training tokenizer..."
python -m nanochat.dataset -n 2 # minimal amount of data
python -m scripts.tok_train --max_chars=100000000 # minimal character count
python -m scripts.tok_eval

echo "🏋️ Starting base model training (quick test)..."

# Base model training - quick test version
echo "📈 Running base training (depth 8, batch size 2, 100 steps)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --run=$WANDB_RUN --depth=8 --device_batch_size=2 --num_iterations=100

echo "📊 Running base loss evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.base_loss

echo "📊 Running base model evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.base_eval

echo "🎯 Starting midtraining (quick test)..."

# Midtraining - quick test version
echo "📈 Running midtraining (batch size 1, 50 steps)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train -- --run=$WANDB_RUN --device_batch_size=1 --num_iterations=50

echo "📊 Running midtraining evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i mid

echo "💬 Starting SFT training (quick test)..."

# SFT training - quick test version
echo "📈 Running SFT training (batch size 1, 20 steps)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft -- --run=$WANDB_RUN --device_batch_size=1 --num_iterations=20

echo "📊 Running SFT evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i sft

echo "📋 Generating the final report..."
python -m nanochat.report generate

echo "🎉 T4 quick test complete!"
echo "📊 View the report: python -m nanochat.report show"
echo "💬 Start the chat UI: python -m scripts.chat_web"

# Show GPU usage
echo "🔍 Current GPU status:"
nvidia-smi --query-gpu=index,name,memory.used,memory.total,utilization.gpu --format=csv,noheader,nounits

echo "✅ Quick test finished!"
99 run_t4_training.sh Normal file

@ -0,0 +1,99 @@
#!/bin/bash

# Full training pipeline script for 4 T4 GPUs
# Adapted from run1000.sh, tuned for the 16GB memory limit of T4 GPUs

set -e # exit on error

echo "🚀 Starting the T4 GPU training pipeline..."

# Environment setup
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p "$NANOCHAT_BASE_DIR"

# Install uv if it is not already present
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh

# Set up the virtual environment
[ -d ".venv" ] || uv venv
uv sync --extra gpu
source .venv/bin/activate

# Set the wandb run name
if [ -z "$WANDB_RUN" ]; then
    WANDB_RUN="t4_training_$(date +%Y%m%d_%H%M%S)"
fi
echo "📊 Wandb run name: $WANDB_RUN"

# Reset the report
python -m nanochat.report reset

# Install Rust and build the tokenizer
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml

# Download the evaluation data
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
    echo "📥 Downloading the evaluation bundle..."
    curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
    unzip -q eval_bundle.zip
    rm eval_bundle.zip
    mv eval_bundle "$NANOCHAT_BASE_DIR"
fi

# Download the identity conversation data
curl -L -o "$NANOCHAT_BASE_DIR/identity_conversations.jsonl" https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl

echo "📊 Starting data preparation..."

# Train the tokenizer - with less data, to suit the T4 setup
echo "🔤 Training tokenizer..."
python -m nanochat.dataset -n 8 # reduced amount of data
python -m scripts.tok_train --max_chars=2000000000 # reduced character count
python -m scripts.tok_eval

echo "🏋️ Starting base model training..."

# Base model training - tuned for T4
echo "📈 Running base training (depth 12, batch size 4)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --run=$WANDB_RUN

echo "📊 Running base loss evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.base_loss

echo "📊 Running base model evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.base_eval

echo "🎯 Starting midtraining..."

# Midtraining - tuned for T4
echo "📈 Running midtraining (batch size 2)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train -- --run=$WANDB_RUN

echo "📊 Running midtraining evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i mid

echo "💬 Starting SFT training..."

# SFT training - tuned for T4
echo "📈 Running SFT training (batch size 1)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft -- --run=$WANDB_RUN

echo "📊 Running SFT evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i sft

echo "📋 Generating the final report..."
python -m nanochat.report generate

echo "🎉 T4 training pipeline complete!"
echo "📊 View the report: python -m nanochat.report show"
echo "💬 Start the chat UI: python -m scripts.chat_web"

# Show GPU usage
echo "🔍 Current GPU status:"
nvidia-smi --query-gpu=index,name,memory.used,memory.total,utilization.gpu --format=csv,noheader,nounits

echo "✅ All training steps finished!"
290 scripts/t4_chat_sft.py Normal file

@ -0,0 +1,290 @@
"""
|
||||
针对4个T4 GPU优化的SFT脚本
|
||||
基于chat_sft.py修改,专门为T4 GPU的16GB显存限制进行优化
|
||||
|
||||
运行方式:
|
||||
torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft
|
||||
|
||||
或者单GPU调试:
|
||||
python -m scripts.t4_chat_sft
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.engine import Engine
|
||||
from scripts.chat_eval import run_chat_eval
|
||||
|
||||
from tasks.common import TaskMixture
|
||||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# T4 GPU优化的SFT配置
|
||||
run = "t4_sft" # wandb run name
|
||||
# input model options
|
||||
source = "mid" # base|mid , which checkpoint to load the model from
|
||||
model_tag = None # model tag to load the model from
|
||||
step = None # step to load the model from
|
||||
# compute/precision
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 1 # 进一步减少批次大小以适应SFT的更大模型
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
|
||||
target_examples_per_step = 8 # 减少目标样本数 (原来32)
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
weight_decay = 0.0
|
||||
init_lr_frac = 0.02
|
||||
# evaluation and logging
|
||||
eval_every = 50 # 更频繁的评估 (原来100)
|
||||
eval_steps = 50 # 减少评估步数 (原来100)
|
||||
eval_metrics_every = 100 # 更频繁的指标评估 (原来200)
|
||||
eval_metrics_max_problems = 512 # 减少最大问题数 (原来1024)
|
||||
# now allow CLI to override the settings via the configurator
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-t4-sft", name=run, config=user_config, save_code=True)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||
orig_model = model # original, uncompiled model
|
||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
|
||||
print0(f"T4 SFT配置:")
|
||||
print0(f" device_batch_size: {device_batch_size}")
|
||||
print0(f" target_examples_per_step: {target_examples_per_step}")
|
||||
print0(f" DDP world size: {ddp_world_size}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Task data mixture we'll train on
|
||||
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
|
||||
train_ds = TaskMixture([
|
||||
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
|
||||
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
|
||||
GSM8K(subset="main", split="train"), # 8K rows
|
||||
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
|
||||
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
|
||||
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DataLoader
|
||||
|
||||
def sft_data_generator(dataset, batch_size):
|
||||
pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
|
||||
# prepares a list of tokenized conversations into a batch and yields
|
||||
def collate_and_yield(batch):
|
||||
nrows = len(batch)
|
||||
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
|
||||
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
|
||||
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
|
||||
for i, (ids, mask) in enumerate(batch):
|
||||
n = len(ids)
|
||||
ids_tensor = torch.tensor(ids, dtype=torch.long)
|
||||
inputs[i, :n-1] = ids_tensor[:-1]
|
||||
# recall -1 is the ignore index, so mask out targets where mask is 0
|
||||
row_targets = ids_tensor[1:]
|
||||
# mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok
|
||||
mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
|
||||
row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0
|
||||
targets[i, :n-1] = row_targets
|
||||
inputs = inputs.to(device) # move to device
|
||||
targets = targets.to(device)
|
||||
return inputs, targets
|
||||
# iterates over the dataset in epochs, tokenizes
|
||||
batch = []
|
||||
while True:
|
||||
for i in range(ddp_rank, len(dataset), ddp_world_size):
|
||||
doc = dataset[i]
|
||||
ids, mask = tokenizer.render_conversation(doc)
|
||||
batch.append((ids, mask))
|
||||
if len(batch) == batch_size:
|
||||
yield collate_and_yield(batch)
|
||||
batch = []
|
||||
|
||||
examples_per_step = device_batch_size * ddp_world_size
|
||||
print0(f"Target examples per step: {target_examples_per_step}")
|
||||
print0(f"Device batch size: {device_batch_size}")
|
||||
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
|
||||
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
|
||||
grad_accum_steps = target_examples_per_step // examples_per_step
|
||||
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
|
||||
|
||||
if num_iterations == -1:
|
||||
# derive num_iterations from num_epochs and the size of the dataset
|
||||
assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
|
||||
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
|
||||
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
|
||||
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer
|
||||
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr,
|
||||
embedding_lr=embedding_lr,
|
||||
matrix_lr=matrix_lr,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
|
||||
# Learning rate scheduler
|
||||
def get_lr_multiplier(it):
|
||||
lrm = 1.0 - it / num_iterations
|
||||
return lrm
|
||||
|
||||
# Go!
|
||||
step = 0
|
||||
train_iter = iter(train_loader)
|
||||
for step in range(num_iterations):
|
||||
last_step = step == num_iterations - 1
|
||||
|
||||
# evaluate the validation loss
|
||||
if last_step or step % eval_every == 0:
|
||||
model.eval()
|
||||
val_iter = iter(build_val_loader())
|
||||
losses = []
|
||||
for _ in range(eval_steps):
|
||||
val_inputs, val_targets = next(val_iter)
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
loss = model(val_inputs, val_targets)
|
||||
losses.append(loss)
|
||||
val_loss = torch.stack(losses).mean() # average over eval_steps
|
||||
if ddp:
|
||||
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
|
||||
val_loss = val_loss.item()
|
||||
print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"val_loss": val_loss,
|
||||
})
|
||||
model.train()
|
||||
|
||||
# evaluate accuracy of the multiple choice tasks (which are quick to run)
|
||||
if last_step or (step > 0 and step % eval_metrics_every == 0):
|
||||
model.eval()
|
||||
metrics = {}
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
**metrics,
|
||||
})
|
||||
model.train()
|
||||
|
||||
if last_step:
|
||||
break
|
||||
|
||||
# evaluate the gradient
|
||||
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
||||
for micro_step in range(grad_accum_steps):
|
||||
train_inputs, train_targets = next(train_iter)
|
||||
with autocast_ctx:
|
||||
loss = model(train_inputs, train_targets)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward() # accumulate the gradient
|
||||
num_tokens += (train_targets >= 0).sum()
|
||||
if ddp:
|
||||
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
|
||||
|
||||
# learning rate scheduler
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
|
||||
# step the optimizers
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
# logging
|
||||
train_loss_item = train_loss.item()
|
||||
num_tokens_item = num_tokens.item()
|
||||
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"lrm": lrm,
|
||||
"train_loss": train_loss_item,
|
||||
"num_tokens": num_tokens_item,
|
||||
})
|
||||
step += 1
|
||||
|
||||
# Save the model at the end of the run
|
||||
if master_process:
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
model_tag = f"t4_d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "t4_chatsft_checkpoints", model_tag)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
model.state_dict(),
|
||||
None, # note: we don't bother to save the optimizer state
|
||||
{
|
||||
"step": step,
|
||||
"val_loss": val_loss,
|
||||
**metrics,
|
||||
"model_config": model_config_kwargs,
|
||||
}
|
||||
)
|
||||
print0(f"✅ T4 SFT完成,模型保存到: {checkpoint_dir}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="T4 Chat SFT", data=[
|
||||
user_config,
|
||||
{
|
||||
"Training rows": len(train_ds),
|
||||
"Number of iterations": num_iterations,
|
||||
"Training loss": train_loss_item,
|
||||
"Validation loss": val_loss,
|
||||
},
|
||||
])
|
||||
|
||||
# Cleanup
|
||||
wandb_run.finish()
|
||||
compute_cleanup()
314 scripts/t4_mid_train.py Normal file

@ -0,0 +1,314 @@
"""
|
||||
针对4个T4 GPU优化的midtrain脚本
|
||||
基于mid_train.py修改,专门为T4 GPU的16GB显存限制进行优化
|
||||
|
||||
运行方式:
|
||||
torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train
|
||||
|
||||
或者单GPU调试:
|
||||
python -m scripts.t4_mid_train
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
import wandb
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
import torch.distributed as dist
|
||||
|
||||
from tasks.common import TaskMixture
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# T4 GPU优化配置
|
||||
run = "t4_midtraining" # wandb run name
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
max_seq_len = 1024 # 减少序列长度以节省显存
|
||||
device_batch_size = 2 # 进一步减少批次大小以适应midtrain的更大模型
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
|
||||
weight_decay = 0.0
|
||||
eval_every = 75 # 更频繁的评估 (原来150)
|
||||
eval_tokens = 5*131072 # 减少评估token数量
|
||||
total_batch_size = 65536 # 减少总批次大小
|
||||
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-t4-mid", name=run, config=user_config)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
|
||||
pretrain_batch_size = meta.get("device_batch_size", None)
|
||||
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"T4 Midtrain配置:")
|
||||
print0(f" Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f" Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f" Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
print0(f" DDP world size: {ddp_world_size}")
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# Midtraining data mixture and DataLoader
|
||||
base_dir = get_base_dir()
|
||||
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
    SmolTalk(split="train"), # 460K rows of general conversations
    MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
    CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
    SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
val_dataset = TaskMixture([
    SmolTalk(split="test"), # 24K rows in test set
    MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
    GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total: 24K + 14K + 1.32K ~= 39K rows

# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
last_step = False # we will toggle this to True when we reach the end of the dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
def mid_data_generator(split):
    global last_step, approx_progress
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    dataset = train_dataset if split == "train" else val_dataset
    dataset_size = len(dataset)
    assert dataset_size > 0
    needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
    token_buffer = deque()
    # CUDA supports memory pinning for faster transfers between CPU and GPU:
    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
    cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
    it = 0 # iteration counter
    while True:
        # Accumulate enough tokens for one iteration before yielding
        while len(token_buffer) < needed_tokens:
            conversation = dataset[cursor]
            ids, _ = tokenizer.render_conversation(conversation)
            token_buffer.extend(ids)
            cursor += ddp_world_size
            if cursor >= dataset_size:
                cursor -= dataset_size # wrap around for another epoch
                if split == "train":
                    last_step = True # toggle last_step to True, which will terminate the training loop
        # Stopping condition to respect num_iterations, if given
        it += 1
        if num_iterations > 0 and it >= num_iterations:
            last_step = True # toggle last_step to True, which will terminate the training loop
        # Build up inputs/targets and yield
        for i in range(needed_tokens):
            scratch[i] = token_buffer.popleft()
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
        inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
        if split == "train":
            if num_iterations > 0:
                approx_progress = it / num_iterations # calculate progress from the max number of iterations
            else:
                approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
        yield inputs, targets

train_loader = mid_data_generator("train")
build_val_loader = lambda: mid_data_generator("val")
progress = 0 # will go from 0 to 1 over the course of the epoch

# Learning rate scheduler
def get_lr_multiplier(progress):
    # first 80% of training: no decay, then linearly ramp down to 0.
    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2

# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum
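The two schedules above are easy to sanity-check in isolation. A minimal standalone probe (the functions are copied verbatim from the script; the probe values are illustrative):

```python
def get_lr_multiplier(progress):
    # first 80% of training: full LR, then linearly ramp down to 0
    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2

def get_muon_momentum(it):
    # ramp Muon momentum from 0.85 to 0.95 over the first 300 steps
    frac = min(it / 300, 1)
    return (1 - frac) * 0.85 + frac * 0.95

print(get_lr_multiplier(0.5))   # full LR during the first 80%
print(get_lr_multiplier(0.9))   # halfway down the linear ramp
print(get_muon_momentum(0), get_muon_momentum(300))
```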
# -----------------------------------------------------------------------------
# Training loop
x, y = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
    flops_so_far = num_flops_per_token * total_batch_size * step

    # Synchronize last_step across all ranks to avoid hangs in the distributed setting
    if ddp:
        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
        last_step = bool(last_step_tensor.item())

    # once in a while: evaluate the val bpb (all ranks participate)
    if eval_every > 0 and (last_step or step % eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # save checkpoint at the end of the run (only on master process)
    if master_process and last_step and not dry_run:
        output_dirname = f"t4_d{depth}"
        checkpoint_dir = os.path.join(base_dir, "t4_mid_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            [opt.state_dict() for opt in optimizers],
            {
                "step": step,
                "val_bpb": val_bpb,
                "model_config": {
                    "sequence_len": max_seq_len,
                    "vocab_size": tokenizer.get_vocab_size(),
                    "n_layer": depth,
                    "n_head": model.config.n_head,
                    "n_kv_head": model.config.n_kv_head,
                    "n_embd": model.config.n_embd,
                },
                "user_config": user_config,
            }
        )
        print0(f"✅ T4 midtraining done, model saved to: {checkpoint_dir}")

    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()
        loss = loss / grad_accum_steps
        loss.backward()
        x, y = next(train_loader)
        progress = max(progress, approx_progress)

    # step the optimizers
    lrm = get_lr_multiplier(progress)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    muon_momentum = get_muon_momentum(step)
    for group in muon_optimizer.param_groups:
        group["momentum"] = muon_momentum
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0

    # State
    step += 1

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
    pct_done = 100 * progress
    tok_per_sec = int(world_tokens_per_fwdbwd / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    # T4 peak throughput is about 65 TFLOPS (FP16 tensor cores)
    promised_flops_per_sec_t4 = 65e12 * ddp_world_size
    mfu = 100 * flops_per_sec / promised_flops_per_sec_t4
    if step > 10:
        total_training_time += dt
    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    if step % 10 == 0:
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        })

# print final stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to report
if not dry_run:
    from nanochat.report import get_report
    get_report().log(section="T4 Midtraining", data=[
        user_config,
        {
            "Number of iterations": step,
            "DDP world size": ddp_world_size,
        },
        {
            "Minimum validation bpb": min_val_bpb,
        }
    ])

# cleanup
wandb_run.finish()
compute_cleanup()
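The loop above reports a bias-corrected EMA of the training loss. Because the EMA starts at zero, its raw value is biased low for the first steps; the Adam-style `1 - beta**(step+1)` divisor undoes that. A standalone sketch with an illustrative constant loss stream:

```python
ema_beta = 0.9  # same decay factor as in the script
smooth = 0.0    # EMA initialized at zero, so early raw values are biased low
for step, loss in enumerate([4.0, 4.0, 4.0]):  # illustrative constant loss
    smooth = ema_beta * smooth + (1 - ema_beta) * loss
    debiased = smooth / (1 - ema_beta ** (step + 1))  # Adam-style bias correction

print(smooth)    # still well below 4.0 after only 3 steps
print(debiased)  # the correction recovers the true loss
```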
350
scripts/t4_train.py
Normal file
"""
|
||||
针对4个T4 GPU优化的训练脚本
|
||||
基于base_train.py修改,专门为T4 GPU的16GB显存限制进行优化
|
||||
|
||||
运行方式:
|
||||
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train
|
||||
|
||||
或者单GPU调试:
|
||||
python -m scripts.t4_train
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from scripts.base_eval import evaluate_model
|
||||
print_banner()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# T4 GPU优化配置
|
||||
run = "t4_training" # wandb run name
|
||||
# Runtime
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
# Model architecture - 针对T4优化
|
||||
depth = 12 # 减少深度以适应T4显存限制 (原来20-32)
|
||||
max_seq_len = 1024 # 减少序列长度以节省显存 (原来2048)
|
||||
# Training horizon
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
target_flops = -1.0 # calculate num_iterations to reach target_flops
|
||||
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20)
|
||||
# Optimization - 针对T4优化
|
||||
device_batch_size = 4 # 大幅减少批次大小以适应T4的16GB显存 (原来32)
|
||||
total_batch_size = 131072 # 减少总批次大小 (原来524288)
|
||||
embedding_lr = 0.2
|
||||
unembedding_lr = 0.004
|
||||
weight_decay = 0.0
|
||||
matrix_lr = 0.02
|
||||
grad_clip = 1.0
|
||||
warmup_ratio = 0.0
|
||||
warmdown_ratio = 0.2
|
||||
final_lr_frac = 0.0
|
||||
# Evaluation - 针对T4优化
|
||||
eval_every = 100 # 更频繁的评估 (原来250)
|
||||
eval_tokens = 5*131072 # 减少评估token数量 (原来20*524288)
|
||||
core_metric_every = 1000 # 更频繁的核心指标评估 (原来2000)
|
||||
core_metric_max_per_task = 250 # 减少每个任务的最大样本数 (原来500)
|
||||
sample_every = 1000 # 更频繁的采样 (原来2000)
|
||||
# Output
|
||||
model_tag = "t4_d12" # T4训练的模型标签
|
||||
# now allow CLI to override the settings via the configurator
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
# -----------------------------------------------------------------------------

# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-t4", name=run, config=user_config)

# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")

# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (kept consistent with the original config)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128
num_kv_heads = num_heads # default is 1:1 GQA ratio
print0("T4-optimized configuration:")
print0(f" num_layers: {num_layers}")
print0(f" model_dim: {model_dim}")
print0(f" num_heads: {num_heads}")
print0(f" num_kv_heads: {num_kv_heads}")
print0(f" max_seq_len: {max_seq_len}")
print0(f" device_batch_size: {device_batch_size}")

# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
print0(f"DDP world size: {ddp_world_size}")
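With the config above and a 4-GPU launch (one rank per T4, as the launch command assumes), the batch accounting works out as follows; this is a standalone recomputation with variable names mirroring the script:

```python
device_batch_size = 4    # per-GPU micro-batch (rows)
max_seq_len = 1024       # tokens per row
ddp_world_size = 4       # assumed: one rank per T4
total_batch_size = 131072

tokens_per_fwdbwd = device_batch_size * max_seq_len            # tokens per rank per micro-step
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size   # tokens per micro-step across all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd # micro-steps per optimizer step
print(tokens_per_fwdbwd, world_tokens_per_fwdbwd, grad_accum_steps)  # 4096 16384 8
```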
# -----------------------------------------------------------------------------
# Initialize the Model
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
with torch.device("meta"):
    model_config = GPTConfig(**model_config_kwargs)
    model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False)
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# Calculate number of iterations
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
    num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
    target_tokens = target_param_data_ratio * num_params
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}")
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers

# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
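To make the data:param branch above concrete, here is the same arithmetic with an illustrative parameter count (in the script, num_params is derived from the actual model, so a d12 run will differ):

```python
target_param_data_ratio = 20       # Chinchilla-style tokens-per-parameter target
total_batch_size = 131072          # tokens per optimizer step (from the config)
num_params = 120_000_000           # illustrative; not the real d12 count

target_tokens = target_param_data_ratio * num_params  # 2.4B training tokens
num_iterations = target_tokens // total_batch_size    # optimizer steps to reach the target
print(num_iterations)  # 18310
```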
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers

# Learning rate scheduler
def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac

# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum
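The base-training LR schedule above is trapezoidal: an optional linear warmup, a long flat region, then a linear warmdown over the last warmdown_ratio of the run. A standalone probe with an illustrative horizon (the ratios match the config defaults):

```python
num_iterations = 1000  # illustrative horizon
warmup_ratio, warmdown_ratio, final_lr_frac = 0.0, 0.2, 0.0  # config defaults

def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters  # linear warmup
    elif it <= num_iterations - warmdown_iters:
        return 1.0                      # flat region
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac  # linear warmdown

print(get_lr_multiplier(400))   # flat region
print(get_lr_multiplier(900))   # halfway down the warmdown
print(get_lr_multiplier(1000))  # ends at final_lr_frac
```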
# -----------------------------------------------------------------------------
# Training loop
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
    last_step = step == num_iterations
    flops_so_far = num_flops_per_token * total_batch_size * step

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or step % eval_every == 0:
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # once in a while: estimate the CORE metric (all ranks participate)
    results = {}
    if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
        model.eval()
        with autocast_ctx:
            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "core_metric": results["core_metric"],
            "centered_results": results["centered_results"],
        })
        model.train()

    # once in a while: sample from the model (only on master process)
    if master_process and (last_step or (step > 0 and step % sample_every == 0)):
        model.eval()
        prompts = [
            "The capital of France is",
            "The chemical symbol of gold is",
            "If yesterday was Friday, then tomorrow will be",
            "The opposite of hot is",
            "The planets of the solar system are:",
            "My favorite color is",
            "If 5*x + 3 = 13, then x is",
        ]
        engine = Engine(orig_model, tokenizer)
        for prompt in prompts:
            tokens = tokenizer(prompt, prepend="<|bos|>")
            with autocast_ctx:
                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
            print0(tokenizer.decode(sample[0]))
        model.train()

    # save checkpoint at the end of the run (only on master process)
    if master_process and last_step:
        output_dirname = model_tag if model_tag else f"t4_d{depth}"
        checkpoint_dir = os.path.join(base_dir, "t4_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            [opt.state_dict() for opt in optimizers],
            {
                "step": step,
                "val_bpb": val_bpb,
                "model_config": model_config_kwargs,
                "user_config": user_config,
                "device_batch_size": device_batch_size,
                "max_seq_len": max_seq_len,
            }
        )
        print0(f"✅ T4 training done, model saved to: {checkpoint_dir}")

    if last_step:
        break
    # -------------------------------------------------------------------------
    # single training step
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()
        loss = loss / grad_accum_steps
        loss.backward()
        x, y = next(train_loader)

    # gradient clipping
    if grad_clip > 0.0:
        torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)

    # step the optimizers
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    muon_momentum = get_muon_momentum(step)
    for group in muon_optimizer.param_groups:
        group["momentum"] = muon_momentum
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(world_tokens_per_fwdbwd / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    # T4 peak throughput is about 65 TFLOPS (FP16 tensor cores)
    promised_flops_per_sec_t4 = 65e12 * ddp_world_size
    mfu = 100 * flops_per_sec / promised_flops_per_sec_t4
    if step > 10:
        total_training_time += dt
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    if step % 50 == 0: # more frequent wandb logging
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        })

# print final stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to report
from nanochat.report import get_report
get_report().log(section="T4 Base model training", data=[
    user_config,
    {
        "Number of parameters": num_params,
        "Number of FLOPs per token": f"{num_flops_per_token:e}",
        "Calculated number of iterations": num_iterations,
        "Number of training tokens": total_tokens,
        "Tokens : Params ratio": total_batch_size * num_iterations / num_params,
        "DDP world size": ddp_world_size,
        "warmup_ratio": warmup_ratio,
        "warmdown_ratio": warmdown_ratio,
        "final_lr_frac": final_lr_frac,
    },
    {
        "Minimum validation bpb": min_val_bpb,
        "Final validation bpb": val_bpb,
        "CORE metric estimate": results.get("core_metric", None),
        "MFU %": f"{mfu:.2f}%",
        "Total training flops": f"{flops_so_far:e}",
        "Total training time": f"{total_training_time/60:.2f}m",
        "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
    }
])

# cleanup
wandb_run.finish()
compute_cleanup()