update BASE_URL in dataset.py to new modelscope link for data access

z 2025-10-26 17:18:39 +08:00
parent c75fe54aa7
commit c77fbb010b
7 changed files with 1325 additions and 1 deletions

T4_TRAINING_README.md Normal file

@@ -0,0 +1,172 @@
# T4 GPU Training Guide
This document explains how to run nanochat training on a server with 4 T4 GPUs.
## Hardware Requirements
- 4x NVIDIA T4 GPUs (16GB VRAM each)
- At least 64GB of system RAM
- Enough storage for the dataset and model checkpoints
## File Overview
### Training scripts
1. **`scripts/t4_train.py`** - base model training script optimized for T4
2. **`scripts/t4_mid_train.py`** - midtraining script optimized for T4
3. **`scripts/t4_chat_sft.py`** - SFT training script optimized for T4
### Launch scripts
1. **`run_t4_training.sh`** - the full T4 training pipeline
2. **`run_t4_quick_test.sh`** - a quick smoke test for validating the configuration
## T4-Optimized Configuration
### Base model training (t4_train.py)
- **Model depth**: 12 layers (down from 20-32)
- **Sequence length**: 1024 (down from 2048)
- **Per-device batch size**: 4 (down from 32)
- **Total batch size**: 131,072 tokens (down from 524,288)
- **Evaluation frequency**: every 100 steps (previously every 250)
### Midtraining (t4_mid_train.py)
- **Per-device batch size**: 2 (down from 32)
- **Total batch size**: 65,536 tokens (down from 524,288)
- **Evaluation frequency**: every 75 steps (previously every 150)
### SFT training (t4_chat_sft.py)
- **Per-device batch size**: 1 (down from 4)
- **Target examples per step**: 8 (down from 32)
- **Evaluation frequency**: every 50 steps (previously every 100)
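These batch-size numbers fit together through gradient accumulation. As a sanity check, the decomposition for the base-training configuration above can be sketched as:

```python
# Sketch: how total_batch_size decomposes on 4 T4 GPUs, mirroring the
# grad-accum arithmetic in the training scripts. Numbers are the base
# training values from the table above.
device_batch_size = 4
max_seq_len = 1024
world_size = 4  # number of GPUs

tokens_per_fwdbwd = device_batch_size * max_seq_len * world_size  # tokens per micro-step
total_batch_size = 131_072  # tokens per optimizer step

assert total_batch_size % tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // tokens_per_fwdbwd
print(grad_accum_steps)  # 8 micro-batches per optimizer step
```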
## Usage
### 1. Quick test
First run the quick test to validate the configuration:
```bash
./run_t4_quick_test.sh
```
This runs:
- a depth-8 model
- 100 steps of base training
- 50 steps of midtraining
- 20 steps of SFT training
### 2. Full training
If the quick test succeeds, run the full training:
```bash
./run_t4_training.sh
```
This runs:
- a depth-12 model
- the full base training
- the full midtraining
- the full SFT training
### 3. Running individual stages
You can also run each training stage on its own:
```bash
# Base training
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train
# Midtraining
torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train
# SFT training
torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft
```
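The `--key=value` arguments accepted after `--` are applied through nanochat's configurator pattern: each script defines plain module-level defaults, then config keys named on the command line replace them. A simplified, illustrative stand-in (not the actual `nanochat/configurator.py`):

```python
# Simplified sketch of the override idiom used by these scripts: defaults
# are plain values, and each CLI override is cast to the default's type.
defaults = {"depth": 12, "device_batch_size": 4, "max_seq_len": 1024}

def apply_overrides(config, argv):
    for arg in argv:
        key, val = arg.lstrip("-").split("=")
        assert key in config, f"unknown config key: {key}"
        config[key] = type(config[key])(val)  # cast using the default's type
    return config

config = apply_overrides(dict(defaults), ["--depth=8", "--max_seq_len=512"])
print(config)  # {'depth': 8, 'device_batch_size': 4, 'max_seq_len': 512}
```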
## Parameter Tuning
If you run out of GPU memory, you can dial the parameters down further:
### Reduce the batch size
```bash
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --device_batch_size=2
```
### Reduce the model depth
```bash
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --depth=8
```
### Reduce the sequence length
```bash
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --max_seq_len=512
```
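All three flags attack the same quantity: activation memory, which grows roughly linearly in each of batch size, sequence length, and depth. A back-of-the-envelope sketch (illustrative proportions only, not a precise memory model):

```python
# Rough proportional model of activation memory: cost ~ batch * seq * depth.
# Halving any single knob roughly halves activations, which is why each
# flag above is a valid OOM escape hatch.
def relative_activation_cost(device_batch_size, max_seq_len, depth):
    return device_batch_size * max_seq_len * depth

baseline = relative_activation_cost(4, 1024, 12)
print(relative_activation_cost(2, 1024, 12) / baseline)  # 0.5 (half the batch)
print(relative_activation_cost(4, 512, 12) / baseline)   # 0.5 (half the seq len)
```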
## Monitoring Training
### Check GPU status
```bash
nvidia-smi
```
### Training logs
During training the scripts print:
- the loss
- the learning rate
- time per step
- throughput (tokens/sec)
- model FLOPS utilization (MFU)
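MFU is achieved FLOP/s divided by the hardware peak; the T4 scripts assume ~65 TFLOPS per GPU as the peak (the constant from `t4_mid_train.py`). A sketch with hypothetical throughput numbers:

```python
# Sketch of the MFU calculation used in the training loop. The model size
# and tokens/sec below are hypothetical placeholders; 65e12 is the per-GPU
# peak the scripts assume for the T4.
num_params = 120e6                    # hypothetical ~120M-parameter model
num_flops_per_token = 6 * num_params  # standard 6N FLOPs/token estimate
tokens_per_sec = 50_000               # hypothetical measured throughput
peak_flops_per_sec = 65e12 * 4        # 4 T4 GPUs

mfu = 100 * num_flops_per_token * tokens_per_sec / peak_flops_per_sec
print(f"MFU: {mfu:.2f}%")
```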
### Wandb logs
If Wandb is configured, detailed training metrics are available in the Wandb UI.
## Troubleshooting
### Out of memory (OOM)
1. Reduce `device_batch_size`
2. Reduce `max_seq_len`
3. Reduce `depth`
### Slow training
1. Check GPU utilization (`nvidia-smi`)
2. Make sure data loading is not the bottleneck
3. Consider increasing `eval_every` to evaluate less often
### Distributed training issues
1. Make sure all 4 GPUs are available
2. Check the network connection
3. Make sure the rendezvous port is not already in use
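For the port issue, one common workaround is to hand the launcher an explicitly free port (torchrun accepts `--master_port`; with `--standalone` it normally picks its own). A small preflight sketch that asks the OS for a free port:

```python
# Preflight sketch: ask the OS for a currently-free TCP port, which can
# then be passed to the launcher if the default rendezvous port is taken.
import socket

def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))  # port 0 lets the OS pick a free port
        return s.getsockname()[1]

print(find_free_port())
```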
## Expected Performance
On 4 T4 GPUs:
- **Base training**: roughly 2-4 hours (depending on model size)
- **Midtraining**: roughly 1-2 hours
- **SFT training**: roughly 30 minutes to 1 hour
## Output Files
After training completes, model checkpoints are saved to:
- Base model: `~/.cache/nanochat/t4_checkpoints/`
- Midtrained model: `~/.cache/nanochat/t4_mid_checkpoints/`
- SFT model: `~/.cache/nanochat/t4_chatsft_checkpoints/`
## Notes
1. The T4's VRAM limit means smaller batch sizes are required
2. Model depth and sequence length have to be reduced accordingly
3. Training takes longer than on high-end GPUs such as the H100
4. Run the quick test first to validate the configuration
## Support
If you run into problems, check:
1. the GPU driver and CUDA versions
2. PyTorch version compatibility
3. GPU memory usage
4. error messages in the training logs

nanochat/dataset.py

@@ -20,7 +20,7 @@ from nanochat.common import get_base_dir
# The specifics of the current pretraining dataset
# The URL on the internet where the data is hosted and downloaded from on demand
-BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
+BASE_URL = "https://www.modelscope.cn/datasets/Thackeray/karpathy-fineweb-edu-100b-shuffle-240shard/resolve/master"
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
base_dir = get_base_dir()
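For reference, the shard naming scheme above expands as follows (constants copied from the diff):

```python
# Sketch of the shard filename/URL scheme defined in nanochat/dataset.py.
BASE_URL = "https://www.modelscope.cn/datasets/Thackeray/karpathy-fineweb-edu-100b-shuffle-240shard/resolve/master"
MAX_SHARD = 1822  # the last datashard is shard_01822.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet"

print(index_to_filename(0))                          # shard_00000.parquet
print(f"{BASE_URL}/{index_to_filename(MAX_SHARD)}")  # URL of the last shard
```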

run_t4_quick_test.sh Normal file

@@ -0,0 +1,99 @@
#!/bin/bash
# T4 GPU quick test script
# Validates that the T4 configuration works by running a small number of training steps
set -e
echo "🧪 Starting T4 GPU quick test..."
# Environment setup
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p "$NANOCHAT_BASE_DIR"
# Install uv if it is missing
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# Set up the virtual environment
[ -d ".venv" ] || uv venv
uv sync --extra gpu
source .venv/bin/activate
# Set the wandb run name
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN="t4_quick_test_$(date +%Y%m%d_%H%M%S)"
fi
echo "📊 Wandb run name: $WANDB_RUN"
# Reset the report
python -m nanochat.report reset
# Install Rust and build the tokenizer
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
# Download the evaluation bundle
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
echo "📥 Downloading evaluation bundle..."
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
unzip -q eval_bundle.zip
rm eval_bundle.zip
mv eval_bundle "$NANOCHAT_BASE_DIR"
fi
# Download the identity conversation data
curl -L -o "$NANOCHAT_BASE_DIR/identity_conversations.jsonl" https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
echo "📊 Starting data preparation..."
# Train the tokenizer - use a minimal amount of data
echo "🔤 Training tokenizer..."
python -m nanochat.dataset -n 2 # minimal number of data shards
python -m scripts.tok_train --max_chars=100000000 # minimal character count
python -m scripts.tok_eval
echo "🏋️ Starting base model training (quick test)..."
# Base model training - quick test version
echo "📈 Running base training (depth 8, batch size 2, 100 steps)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --run=$WANDB_RUN --depth=8 --device_batch_size=2 --num_iterations=100
echo "📊 Running base loss evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.base_loss
echo "📊 Running base model evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.base_eval
echo "🎯 Starting midtraining (quick test)..."
# Midtraining - quick test version
echo "📈 Running midtraining (batch size 1, 50 steps)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train -- --run=$WANDB_RUN --device_batch_size=1 --num_iterations=50
echo "📊 Running midtraining evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i mid
echo "💬 Starting SFT training (quick test)..."
# SFT training - quick test version
echo "📈 Running SFT training (batch size 1, 20 steps)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft -- --run=$WANDB_RUN --device_batch_size=1 --num_iterations=20
echo "📊 Running SFT evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i sft
echo "📋 Generating final report..."
python -m nanochat.report generate
echo "🎉 T4 quick test complete"
echo "📊 View the report: python -m nanochat.report show"
echo "💬 Launch the chat UI: python -m scripts.chat_web"
# Show GPU usage
echo "🔍 Current GPU status:"
nvidia-smi --query-gpu=index,name,memory.used,memory.total,utilization.gpu --format=csv,noheader,nounits
echo "✅ Quick test finished!"

run_t4_training.sh Normal file

@@ -0,0 +1,99 @@
#!/bin/bash
# Full training pipeline script for 4 T4 GPUs
# Based on run1000.sh, adapted for the 16GB VRAM limit of the T4
set -e # exit on error
echo "🚀 Starting T4 GPU training pipeline..."
# Environment setup
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p "$NANOCHAT_BASE_DIR"
# Install uv if it is missing
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# Set up the virtual environment
[ -d ".venv" ] || uv venv
uv sync --extra gpu
source .venv/bin/activate
# Set the wandb run name
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN="t4_training_$(date +%Y%m%d_%H%M%S)"
fi
echo "📊 Wandb run name: $WANDB_RUN"
# Reset the report
python -m nanochat.report reset
# Install Rust and build the tokenizer
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
# Download the evaluation bundle
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
echo "📥 Downloading evaluation bundle..."
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
unzip -q eval_bundle.zip
rm eval_bundle.zip
mv eval_bundle "$NANOCHAT_BASE_DIR"
fi
# Download the identity conversation data
curl -L -o "$NANOCHAT_BASE_DIR/identity_conversations.jsonl" https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
echo "📊 Starting data preparation..."
# Train the tokenizer - use less data to suit the T4 setup
echo "🔤 Training tokenizer..."
python -m nanochat.dataset -n 8 # reduced number of data shards
python -m scripts.tok_train --max_chars=2000000000 # reduced character count
python -m scripts.tok_eval
echo "🏋️ Starting base model training..."
# Base model training - tuned for T4
echo "📈 Running base training (depth 12, batch size 4)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --run=$WANDB_RUN
echo "📊 Running base loss evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.base_loss
echo "📊 Running base model evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.base_eval
echo "🎯 Starting midtraining..."
# Midtraining - tuned for T4
echo "📈 Running midtraining (batch size 2)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train -- --run=$WANDB_RUN
echo "📊 Running midtraining evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i mid
echo "💬 Starting SFT training..."
# SFT training - tuned for T4
echo "📈 Running SFT training (batch size 1)..."
torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft -- --run=$WANDB_RUN
echo "📊 Running SFT evaluation..."
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i sft
echo "📋 Generating final report..."
python -m nanochat.report generate
echo "🎉 T4 training pipeline complete"
echo "📊 View the report: python -m nanochat.report show"
echo "💬 Launch the chat UI: python -m scripts.chat_web"
# Show GPU usage
echo "🔍 Current GPU status:"
nvidia-smi --query-gpu=index,name,memory.used,memory.total,utilization.gpu --format=csv,noheader,nounits
echo "✅ All training stages complete!"

scripts/t4_chat_sft.py Normal file

@@ -0,0 +1,290 @@
"""
SFT script optimized for 4 T4 GPUs.
Based on chat_sft.py, adapted for the 16GB VRAM limit of the T4.
Run with:
torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft
or single-GPU for debugging:
python -m scripts.t4_chat_sft
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
from tasks.common import TaskMixture
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
# T4-optimized SFT configuration
run = "t4_sft" # wandb run name
# input model options
source = "mid" # base|mid , which checkpoint to load the model from
model_tag = None # model tag to load the model from
step = None # step to load the model from
# compute/precision
device_type = "" # cuda|cpu|mps (empty => autodetect)
dtype = "bfloat16"
device_batch_size = 1 # reduced further to fit SFT's larger memory footprint
# optimization
num_epochs = 1
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
target_examples_per_step = 8 # reduced target examples per step (originally 32)
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.02
# evaluation and logging
eval_every = 50 # evaluate more often (originally 100)
eval_steps = 50 # fewer evaluation steps (originally 100)
eval_metrics_every = 100 # evaluate metrics more often (originally 200)
eval_metrics_max_problems = 512 # fewer max problems (originally 1024)
# now allow CLI to override the settings via the configurator
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-t4-sft", name=run, config=user_config, save_code=True)
# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
print0("T4 SFT configuration:")
print0(f" device_batch_size: {device_batch_size}")
print0(f" target_examples_per_step: {target_examples_per_step}")
print0(f" DDP world size: {ddp_world_size}")
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
train_ds = TaskMixture([
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
# -----------------------------------------------------------------------------
# DataLoader
def sft_data_generator(dataset, batch_size):
    pad_token_id = tokenizer.encode_special("<|assistant_end|>") # using <|assistant_end|> as the pad token is fine; these positions are masked in the loss
# prepares a list of tokenized conversations into a batch and yields
def collate_and_yield(batch):
nrows = len(batch)
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
for i, (ids, mask) in enumerate(batch):
n = len(ids)
ids_tensor = torch.tensor(ids, dtype=torch.long)
inputs[i, :n-1] = ids_tensor[:-1]
# recall -1 is the ignore index, so mask out targets where mask is 0
row_targets = ids_tensor[1:]
# mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok
mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0
targets[i, :n-1] = row_targets
inputs = inputs.to(device) # move to device
targets = targets.to(device)
return inputs, targets
# iterates over the dataset in epochs, tokenizes
batch = []
while True:
for i in range(ddp_rank, len(dataset), ddp_world_size):
doc = dataset[i]
ids, mask = tokenizer.render_conversation(doc)
batch.append((ids, mask))
if len(batch) == batch_size:
yield collate_and_yield(batch)
batch = []
examples_per_step = device_batch_size * ddp_world_size
print0(f"Target examples per step: {target_examples_per_step}")
print0(f"Device batch size: {device_batch_size}")
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
grad_accum_steps = target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
if num_iterations == -1:
# derive num_iterations from num_epochs and the size of the dataset
assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
# -----------------------------------------------------------------------------
# Initialize the Optimizer
optimizers = model.setup_optimizers(
unembedding_lr=unembedding_lr,
embedding_lr=embedding_lr,
matrix_lr=matrix_lr,
weight_decay=weight_decay,
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# -----------------------------------------------------------------------------
# Training loop
# Learning rate scheduler
def get_lr_multiplier(it):
lrm = 1.0 - it / num_iterations
return lrm
# Go!
step = 0
train_iter = iter(train_loader)
for step in range(num_iterations):
last_step = step == num_iterations - 1
# evaluate the validation loss
if last_step or step % eval_every == 0:
model.eval()
val_iter = iter(build_val_loader())
losses = []
for _ in range(eval_steps):
val_inputs, val_targets = next(val_iter)
with torch.no_grad(), autocast_ctx:
loss = model(val_inputs, val_targets)
losses.append(loss)
val_loss = torch.stack(losses).mean() # average over eval_steps
if ddp:
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
val_loss = val_loss.item()
print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
wandb_run.log({
"step": step,
"val_loss": val_loss,
})
model.train()
# evaluate accuracy of the multiple choice tasks (which are quick to run)
if last_step or (step > 0 and step % eval_metrics_every == 0):
model.eval()
metrics = {}
with torch.no_grad(), autocast_ctx:
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log({
"step": step,
**metrics,
})
model.train()
if last_step:
break
# evaluate the gradient
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
for micro_step in range(grad_accum_steps):
train_inputs, train_targets = next(train_iter)
with autocast_ctx:
loss = model(train_inputs, train_targets)
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward() # accumulate the gradient
num_tokens += (train_targets >= 0).sum()
if ddp:
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
# learning rate scheduler
lrm = get_lr_multiplier(step)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
# step the optimizers
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
# logging
train_loss_item = train_loss.item()
num_tokens_item = num_tokens.item()
    print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f} | lrm: {lrm:.6f} | num_tokens: {num_tokens_item:,}")
wandb_run.log({
"step": step,
"lrm": lrm,
"train_loss": train_loss_item,
"num_tokens": num_tokens_item,
})
step += 1
# Save the model at the end of the run
if master_process:
base_dir = get_base_dir()
depth = model.config.n_layer
model_tag = f"t4_d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "t4_chatsft_checkpoints", model_tag)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(
checkpoint_dir,
step,
model.state_dict(),
None, # note: we don't bother to save the optimizer state
{
"step": step,
"val_loss": val_loss,
**metrics,
"model_config": model_config_kwargs,
}
)
    print0(f"✅ T4 SFT complete, model saved to: {checkpoint_dir}")
# Log to report
from nanochat.report import get_report
get_report().log(section="T4 Chat SFT", data=[
user_config,
{
"Training rows": len(train_ds),
"Number of iterations": num_iterations,
"Training loss": train_loss_item,
"Validation loss": val_loss,
},
])
# Cleanup
wandb_run.finish()
compute_cleanup()

scripts/t4_mid_train.py Normal file

@@ -0,0 +1,314 @@
"""
Midtraining script optimized for 4 T4 GPUs.
Based on mid_train.py, adapted for the 16GB VRAM limit of the T4.
Run with:
torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train
or single-GPU for debugging:
python -m scripts.t4_mid_train
"""
from collections import deque
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.checkpoint_manager import load_model
import torch.distributed as dist
from tasks.common import TaskMixture
from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
# T4-optimized configuration
run = "t4_midtraining" # wandb run name
device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
max_seq_len = 1024 # reduced sequence length to save VRAM
device_batch_size = 2 # reduced further to fit midtraining's larger memory footprint
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0
eval_every = 75 # evaluate more often (originally 150)
eval_tokens = 5*131072 # fewer evaluation tokens
total_batch_size = 65536 # reduced total batch size
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-t4-mid", name=run, config=user_config)
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0("T4 midtraining configuration:")
print0(f" Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f" Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f" Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
print0(f" DDP world size: {ddp_world_size}")
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers
# Override the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Midtraining data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total: 24K + 14K + 1.32K ~= 39K rows
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
last_step = False # we will toggle this to True when we reach the end of the dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
def mid_data_generator(split):
global last_step, approx_progress
assert split in {"train", "val"}, "split must be 'train' or 'val'"
dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset)
assert dataset_size > 0
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
token_buffer = deque()
# CUDA supports memory pinning for faster transfers between CPU and GPU:
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
it = 0 # iteration counter
while True:
# Accumulate enough tokens for one iteration before yielding
while len(token_buffer) < needed_tokens:
conversation = dataset[cursor]
ids, _ = tokenizer.render_conversation(conversation)
token_buffer.extend(ids)
cursor += ddp_world_size
if cursor >= dataset_size:
cursor -= dataset_size # wrap around for another epoch
if split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop
# Stopping condition to respect num_iterations, if given
it += 1
if num_iterations > 0 and it >= num_iterations:
last_step = True # toggle last_step to True, which will terminate the training loop
# Build up inputs/targets and yield
for i in range(needed_tokens):
scratch[i] = token_buffer.popleft()
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
if split == "train":
if num_iterations > 0:
approx_progress = it / num_iterations # calculate progress from the max number of iterations
else:
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
yield inputs, targets
train_loader = mid_data_generator("train")
build_val_loader = lambda: mid_data_generator("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
def get_lr_multiplier(progress):
# first 80% of training: no decay, then linearly ramp down to 0.
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
# -----------------------------------------------------------------------------
# Training loop
x, y = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
flops_so_far = num_flops_per_token * total_batch_size * step
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
if ddp:
last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
last_step = bool(last_step_tensor.item())
# once in a while: evaluate the val bpb (all ranks participate)
if eval_every > 0 and (last_step or step % eval_every == 0):
model.eval()
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
})
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not dry_run:
output_dirname = f"t4_d{depth}"
checkpoint_dir = os.path.join(base_dir, "t4_mid_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(),
[opt.state_dict() for opt in optimizers],
{
"step": step,
"val_bpb": val_bpb,
"model_config": {
"sequence_len": max_seq_len,
"vocab_size": tokenizer.get_vocab_size(),
"n_layer": depth,
"n_head": model.config.n_head,
"n_kv_head": model.config.n_kv_head,
"n_embd": model.config.n_embd,
},
"user_config": user_config,
}
)
        print0(f"✅ T4 midtraining complete, model saved to: {checkpoint_dir}")
if last_step:
break
# -------------------------------------------------------------------------
# single training step
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach()
loss = loss / grad_accum_steps
loss.backward()
x, y = next(train_loader)
progress = max(progress, approx_progress)
# step the optimizers
lrm = get_lr_multiplier(progress)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
group["momentum"] = muon_momentum
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
dt = t1 - t0
# State
step += 1
# logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
pct_done = 100 * progress
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
    # T4 peak performance is roughly 65 TFLOPS (bfloat16)
promised_flops_per_sec_t4 = 65e12 * ddp_world_size
mfu = 100 * flops_per_sec / promised_flops_per_sec_t4
if step > 10:
total_training_time += dt
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
if step % 10 == 0:
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"train/loss": debiased_smooth_loss,
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
})
# print final stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
if not dry_run:
from nanochat.report import get_report
get_report().log(section="T4 Midtraining", data=[
user_config,
{
"Number of iterations": step,
"DDP world size": ddp_world_size,
},
{
"Minimum validation bpb": min_val_bpb,
}
])
# cleanup
wandb_run.finish()
compute_cleanup()
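The MFU arithmetic in the logging loop above can be checked with round numbers. A standalone sketch, assuming the 65 TFLOPS-per-T4 figure used in this script and hypothetical throughput values:

```python
# Standalone check of the MFU computation, assuming 4 T4 GPUs at a promised
# 65 TFLOPS each. num_flops_per_token and dt below are hypothetical.
num_flops_per_token = 1.0e9      # hypothetical FLOPs/token for a small model
total_batch_size = 65536         # mid-training batch size in this script
dt = 2.0                         # hypothetical wall-clock seconds per step
ddp_world_size = 4

flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_t4 = 65e12 * ddp_world_size
mfu = 100 * flops_per_sec / promised_flops_per_sec_t4
print(f"{mfu:.2f}%")  # → 12.60%
```

MFU here is achieved FLOPs/sec as a percentage of the aggregate advertised peak, so it is sensitive to the peak figure chosen for the hardware.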

scripts/t4_train.py
"""
Training script tuned for 4x T4 GPUs.
Based on base_train.py, modified for the 16GB memory limit of the T4.
Run with:
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train
or, for single-GPU debugging:
python -m scripts.t4_train
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
from contextlib import nullcontext
import wandb
import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from scripts.base_eval import evaluate_model
print_banner()
# -----------------------------------------------------------------------------
# T4 GPU-optimized configuration
run = "t4_training" # wandb run name
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect)
# Model architecture - tuned for T4
depth = 12 # reduced depth to fit T4 memory limits (was 20-32)
max_seq_len = 1024 # reduced sequence length to save memory (was 2048)
# Training horizon
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20)
# Optimization - tuned for T4
device_batch_size = 4 # much smaller batch size to fit the T4's 16GB memory (was 32)
total_batch_size = 131072 # reduced total batch size (was 524288)
embedding_lr = 0.2
unembedding_lr = 0.004
weight_decay = 0.0
matrix_lr = 0.02
grad_clip = 1.0
warmup_ratio = 0.0
warmdown_ratio = 0.2
final_lr_frac = 0.0
# Evaluation - tuned for T4
eval_every = 100 # evaluate more frequently (was 250)
eval_tokens = 5*131072 # fewer evaluation tokens (was 20*524288)
core_metric_every = 1000 # more frequent CORE-metric evaluation (was 2000)
core_metric_max_per_task = 250 # fewer max examples per task (was 500)
sample_every = 1000 # sample more frequently (was 2000)
# Output
model_tag = "t4_d12" # model tag for T4 training runs
# now allow CLI to override the settings via the configurator
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-t4", name=run, config=user_config)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (kept consistent with the original config)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128
num_kv_heads = num_heads # default is 1:1 GQA ratio
print0("T4-optimized configuration:")
print0(f" num_layers: {num_layers}")
print0(f" model_dim: {model_dim}")
print0(f" num_heads: {num_heads}")
print0(f" num_kv_heads: {num_kv_heads}")
print0(f" max_seq_len: {max_seq_len}")
print0(f" device_batch_size: {device_batch_size}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
print0(f"DDP world size: {ddp_world_size}")
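The micro-batch and gradient-accumulation arithmetic printed above can be checked in isolation. A standalone sketch with this script's T4 defaults:

```python
# Standalone check of the gradient-accumulation math above, using the T4
# defaults from this script: device_batch_size=4, max_seq_len=1024,
# 4 GPUs, total_batch_size=131072.
device_batch_size = 4
max_seq_len = 1024
ddp_world_size = 4
total_batch_size = 131072

tokens_per_fwdbwd = device_batch_size * max_seq_len           # 4,096 tokens per rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size  # 16,384 tokens per micro-step
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print(grad_accum_steps)  # 8 micro-steps to reach the 131,072-token batch
```

Halving `device_batch_size` to fight OOMs simply doubles `grad_accum_steps`; the effective batch (and hence the optimization trajectory) stays the same at the cost of wall-clock time.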
# -----------------------------------------------------------------------------
# Initialize the Model
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False)
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
# Calculate number of iterations
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:
print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
target_tokens = target_param_data_ratio * num_params
num_iterations = target_tokens // total_batch_size
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}")
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
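The Chinchilla-style horizon computation above can be illustrated with round numbers. A hypothetical sketch; the real `num_params` comes from the instantiated model:

```python
# Worked example of the data:param horizon above, assuming a hypothetical
# 100M-parameter model and this script's 131,072-token batch.
num_params = 100_000_000
target_param_data_ratio = 20   # Chinchilla-style tokens-per-parameter ratio
total_batch_size = 131072

target_tokens = target_param_data_ratio * num_params  # 2.0B training tokens
num_iterations = target_tokens // total_batch_size
print(num_iterations)  # 15258 optimization steps
```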
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
# Learning rate scheduler
def get_lr_multiplier(it):
warmup_iters = round(warmup_ratio * num_iterations)
warmdown_iters = round(warmdown_ratio * num_iterations)
if it < warmup_iters:
return (it + 1) / warmup_iters
elif it <= num_iterations - warmdown_iters:
return 1.0
else:
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * final_lr_frac
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
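As a quick sanity check, the trapezoidal LR schedule and the Muon momentum ramp above can be evaluated standalone. A self-contained copy, assuming a hypothetical `num_iterations=1000` and the ratios set in this script:

```python
# Self-contained copy of the schedules above, assuming num_iterations=1000,
# warmup_ratio=0.0, warmdown_ratio=0.2, final_lr_frac=0.0.
num_iterations = 1000
warmup_ratio, warmdown_ratio, final_lr_frac = 0.0, 0.2, 0.0

def lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:                        # never taken when warmup_ratio=0
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:  # flat plateau
        return 1.0
    progress = (num_iterations - it) / warmdown_iters
    return progress + (1 - progress) * final_lr_frac

def muon_momentum(it):
    frac = min(it / 300, 1)                      # linear ramp over first 300 steps
    return (1 - frac) * 0.85 + frac * 0.95

assert lr_multiplier(0) == 1.0                   # no warmup configured
assert lr_multiplier(800) == 1.0                 # flat until the last 20%
assert abs(lr_multiplier(900) - 0.5) < 1e-9      # halfway through warmdown
assert muon_momentum(0) == 0.85 and muon_momentum(300) == 0.95
```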
# -----------------------------------------------------------------------------
# Training loop
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
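The smoothed loss logged in the training loop below uses an EMA with the standard Adam-style bias correction. A quick standalone illustration with a constant hypothetical loss of 2.0, showing that the debiased estimate recovers the true value immediately instead of ramping up from the zero init:

```python
# Standalone illustration of the debiased EMA used for the smoothed loss.
ema_beta = 0.9
smooth = 0.0                      # zero-initialized, hence biased low early on
for step in range(5):
    smooth = ema_beta * smooth + (1 - ema_beta) * 2.0
    debiased = smooth / (1 - ema_beta ** (step + 1))
    assert abs(debiased - 2.0) < 1e-9  # exact from the very first step
```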
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
last_step = step == num_iterations
flops_so_far = num_flops_per_token * total_batch_size * step
# once in a while: evaluate the val bpb (all ranks participate)
if last_step or step % eval_every == 0:
model.eval()
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
})
model.train()
# once in a while: estimate the CORE metric (all ranks participate)
results = {}
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
model.eval()
with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"core_metric": results["core_metric"],
"centered_results": results["centered_results"],
})
model.train()
# once in a while: sample from the model (only on master process)
if master_process and (last_step or (step > 0 and step % sample_every == 0)):
model.eval()
prompts = [
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
engine = Engine(orig_model, tokenizer)
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
print0(tokenizer.decode(sample[0]))
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step:
output_dirname = model_tag if model_tag else f"t4_d{depth}"
checkpoint_dir = os.path.join(base_dir, "t4_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(),
[opt.state_dict() for opt in optimizers],
{
"step": step,
"val_bpb": val_bpb,
"model_config": model_config_kwargs,
"user_config": user_config,
"device_batch_size": device_batch_size,
"max_seq_len": max_seq_len,
}
)
        print0(f"✅ T4 training complete. Model saved to: {checkpoint_dir}")
if last_step:
break
# -------------------------------------------------------------------------
# single training step
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach()
loss = loss / grad_accum_steps
loss.backward()
x, y = next(train_loader)
# gradient clipping
if grad_clip > 0.0:
torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
# step the optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
group["momentum"] = muon_momentum
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
dt = t1 - t0
# logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
pct_done = 100 * step / num_iterations
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
    # T4 peak tensor-core throughput is ~65 TFLOPS (the FP16 figure; T4 has no native bfloat16 units)
promised_flops_per_sec_t4 = 65e12 * ddp_world_size
mfu = 100 * flops_per_sec / promised_flops_per_sec_t4
if step > 10:
total_training_time += dt
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    if step % 50 == 0: # more frequent wandb logging
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"train/loss": debiased_smooth_loss,
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
})
# print final stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
from nanochat.report import get_report
get_report().log(section="T4 Base model training", data=[
user_config,
{
"Number of parameters": num_params,
"Number of FLOPs per token": f"{num_flops_per_token:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,
"DDP world size": ddp_world_size,
"warmup_ratio": warmup_ratio,
"warmdown_ratio": warmdown_ratio,
"final_lr_frac": final_lr_frac,
},
{
"Minimum validation bpb": min_val_bpb,
"Final validation bpb": val_bpb,
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
}
])
# cleanup
wandb_run.finish()
compute_cleanup()