diff --git a/T4_TRAINING_README.md b/T4_TRAINING_README.md new file mode 100644 index 0000000..e6b6942 --- /dev/null +++ b/T4_TRAINING_README.md @@ -0,0 +1,172 @@ +# T4 GPU 训练指南 + +本文档说明如何在4个T4 GPU的服务器上运行nanochat训练。 + +## 硬件要求 + +- 4个NVIDIA T4 GPU (每个16GB显存) +- 至少64GB系统内存 +- 足够的存储空间用于数据集和模型检查点 + +## 文件说明 + +### 训练脚本 + +1. **`scripts/t4_train.py`** - 针对T4优化的基础模型训练脚本 +2. **`scripts/t4_mid_train.py`** - 针对T4优化的中期训练脚本 +3. **`scripts/t4_chat_sft.py`** - 针对T4优化的SFT训练脚本 + +### 启动脚本 + +1. **`run_t4_training.sh`** - 完整的T4训练流程 +2. **`run_t4_quick_test.sh`** - 快速测试脚本,用于验证配置 + +## T4优化配置 + +### 基础模型训练 (t4_train.py) + +- **模型深度**: 12层 (原来20-32层) +- **序列长度**: 1024 (原来2048) +- **设备批次大小**: 4 (原来32) +- **总批次大小**: 131,072 tokens (原来524,288) +- **评估频率**: 每100步 (原来250步) + +### 中期训练 (t4_mid_train.py) + +- **设备批次大小**: 2 (原来32) +- **总批次大小**: 65,536 tokens (原来524,288) +- **评估频率**: 每75步 (原来150步) + +### SFT训练 (t4_chat_sft.py) + +- **设备批次大小**: 1 (原来4) +- **目标样本数**: 8 (原来32) +- **评估频率**: 每50步 (原来100步) + +## 使用方法 + +### 1. 快速测试 + +首先运行快速测试来验证配置: + +```bash +./run_t4_quick_test.sh +``` + +这将运行: +- 深度8的模型 +- 100步基础训练 +- 50步中期训练 +- 20步SFT训练 + +### 2. 完整训练 + +如果快速测试成功,运行完整训练: + +```bash +./run_t4_training.sh +``` + +这将运行: +- 深度12的模型 +- 完整的基础训练 +- 完整的中期训练 +- 完整的SFT训练 + +### 3. 
单独运行训练步骤 + +你也可以单独运行每个训练步骤: + +```bash +# 基础训练 +torchrun --standalone --nproc_per_node=4 -m scripts.t4_train + +# 中期训练 +torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train + +# SFT训练 +torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft +``` + +## 参数调整 + +如果遇到显存不足的问题,可以进一步调整参数: + +### 减少批次大小 +```bash +torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --device_batch_size=2 +``` + +### 减少模型深度 +```bash +torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --depth=8 +``` + +### 减少序列长度 +```bash +torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --max_seq_len=512 +``` + +## 监控训练 + +### 查看GPU状态 +```bash +nvidia-smi +``` + +### 查看训练日志 +训练过程中会显示: +- 损失值 +- 学习率 +- 每步时间 +- 吞吐量 (tokens/sec) +- 模型利用率 (MFU) + +### 查看Wandb日志 +如果设置了Wandb,可以在Wandb界面查看详细的训练指标。 + +## 故障排除 + +### 显存不足 (OOM) +1. 减少 `device_batch_size` +2. 减少 `max_seq_len` +3. 减少 `depth` + +### 训练速度慢 +1. 检查GPU利用率 (`nvidia-smi`) +2. 确保数据加载不是瓶颈 +3. 考虑减少 `eval_every` 频率 + +### 分布式训练问题 +1. 确保所有4个GPU都可用 +2. 检查网络连接 +3. 确保端口没有被占用 + +## 预期性能 + +在4个T4 GPU上: +- **基础训练**: 约2-4小时 (取决于模型大小) +- **中期训练**: 约1-2小时 +- **SFT训练**: 约30分钟-1小时 + +## 输出文件 + +训练完成后,模型检查点将保存在: +- 基础模型: `~/.cache/nanochat/t4_checkpoints/` +- 中期模型: `~/.cache/nanochat/t4_mid_checkpoints/` +- SFT模型: `~/.cache/nanochat/t4_chatsft_checkpoints/` + +## 注意事项 + +1. T4 GPU的显存限制意味着需要使用较小的批次大小 +2. 模型深度和序列长度需要相应调整 +3. 训练时间会比H100等高端GPU更长 +4. 建议先运行快速测试验证配置 + +## 支持 + +如果遇到问题,请检查: +1. GPU驱动和CUDA版本 +2. PyTorch版本兼容性 +3. 显存使用情况 +4. 
训练日志中的错误信息 diff --git a/nanochat/dataset.py b/nanochat/dataset.py index 602daed..21608b4 100644 --- a/nanochat/dataset.py +++ b/nanochat/dataset.py @@ -20,7 +20,7 @@ from nanochat.common import get_base_dir # The specifics of the current pretraining dataset # The URL on the internet where the data is hosted and downloaded from on demand -BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main" +BASE_URL = "https://www.modelscope.cn/datasets/Thackeray/karpathy-fineweb-edu-100b-shuffle-240shard/resolve/master" MAX_SHARD = 1822 # the last datashard is shard_01822.parquet index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames base_dir = get_base_dir() diff --git a/run_t4_quick_test.sh b/run_t4_quick_test.sh new file mode 100644 index 0000000..6fabfb3 --- /dev/null +++ b/run_t4_quick_test.sh @@ -0,0 +1,99 @@ +#!/bin/bash + +# T4 GPU快速测试脚本 +# 用于验证T4配置是否正常工作,运行较少的训练步数 + +set -e + +echo "🧪 开始T4 GPU快速测试..." + +# 环境设置 +export OMP_NUM_THREADS=1 +export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" +mkdir -p $NANOCHAT_BASE_DIR + +# 检查并安装uv +command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh + +# 设置虚拟环境 +[ -d ".venv" ] || uv venv +uv sync --extra gpu +source .venv/bin/activate + +# 设置wandb运行名称 +if [ -z "$WANDB_RUN" ]; then + WANDB_RUN="t4_quick_test_$(date +%Y%m%d_%H%M%S)" +fi +echo "📊 Wandb运行名称: $WANDB_RUN" + +# 重置报告 +python -m nanochat.report reset + +# 安装Rust和编译tokenizer +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +source "$HOME/.cargo/env" +uv run maturin develop --release --manifest-path rustbpe/Cargo.toml + +# 下载评估数据 +EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip +if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then + echo "📥 下载评估数据包..." 
+ curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL + unzip -q eval_bundle.zip + rm eval_bundle.zip + mv eval_bundle $NANOCHAT_BASE_DIR +fi + +# 下载身份对话数据 +curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl + +echo "📊 开始数据准备..." + +# 训练tokenizer - 使用最少的数据 +echo "🔤 训练tokenizer..." +python -m nanochat.dataset -n 2 # 最少数据量 +python -m scripts.tok_train --max_chars=100000000 # 最少字符数 +python -m scripts.tok_eval + +echo "🏋️ 开始基础模型训练 (快速测试)..." + +# 基础模型训练 - 快速测试版本 +echo "📈 运行基础训练 (深度8, 批次大小2, 100步)..." +torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --run=$WANDB_RUN --depth=8 --device_batch_size=2 --num_iterations=100 + +echo "📊 运行基础损失评估..." +torchrun --standalone --nproc_per_node=4 -m scripts.base_loss + +echo "📊 运行基础模型评估..." +torchrun --standalone --nproc_per_node=4 -m scripts.base_eval + +echo "🎯 开始中期训练 (快速测试)..." + +# 中期训练 - 快速测试版本 +echo "📈 运行中期训练 (批次大小1, 50步)..." +torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train -- --run=$WANDB_RUN --device_batch_size=1 --num_iterations=50 + +echo "📊 运行中期训练评估..." +torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i mid + +echo "💬 开始SFT训练 (快速测试)..." + +# SFT训练 - 快速测试版本 +echo "📈 运行SFT训练 (批次大小1, 20步)..." +torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft -- --run=$WANDB_RUN --device_batch_size=1 --num_iterations=20 + +echo "📊 运行SFT评估..." +torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i sft + +echo "📋 生成最终报告..." +python -m nanochat.report generate + +echo "🎉 T4快速测试完成!" +echo "📊 查看报告: python -m nanochat.report show" +echo "💬 启动聊天界面: python -m scripts.chat_web" + +# 显示GPU使用情况 +echo "🔍 当前GPU状态:" +nvidia-smi --query-gpu=index,name,memory.used,memory.total,utilization.gpu --format=csv,noheader,nounits + +echo "✅ 快速测试已完成!" 
diff --git a/run_t4_training.sh b/run_t4_training.sh new file mode 100644 index 0000000..28f985c --- /dev/null +++ b/run_t4_training.sh @@ -0,0 +1,99 @@ +#!/bin/bash + +# 针对4个T4 GPU的完整训练流程脚本 +# 基于run1000.sh修改,专门为T4 GPU的16GB显存限制进行优化 + +set -e # 遇到错误时退出 + +echo "🚀 开始T4 GPU训练流程..." + +# 环境设置 +export OMP_NUM_THREADS=1 +export NANOCHAT_BASE_DIR=".cache/nanochat" +mkdir -p $NANOCHAT_BASE_DIR + +# 检查并安装uv +command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh + +# 设置虚拟环境 +[ -d ".venv" ] || uv venv +uv sync --extra gpu +source .venv/bin/activate + +# 设置wandb运行名称 +if [ -z "$WANDB_RUN" ]; then + WANDB_RUN="t4_training_$(date +%Y%m%d_%H%M%S)" +fi +echo "📊 Wandb运行名称: $WANDB_RUN" + +# 重置报告 +python -m nanochat.report reset + +# 安装Rust和编译tokenizer +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +source "$HOME/.cargo/env" +uv run maturin develop --release --manifest-path rustbpe/Cargo.toml + +# 下载评估数据 +EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip +if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then + echo "📥 下载评估数据包..." + curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL + unzip -q eval_bundle.zip + rm eval_bundle.zip + mv eval_bundle $NANOCHAT_BASE_DIR +fi + +# 下载身份对话数据 +curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl + +echo "📊 开始数据准备..." + +# 训练tokenizer - 使用较少的数据以适应T4 +echo "🔤 训练tokenizer..." +python -m nanochat.dataset -n 8 # 减少数据量 +python -m scripts.tok_train --max_chars=2000000000 # 减少字符数 +python -m scripts.tok_eval + +echo "🏋️ 开始基础模型训练..." + +# 基础模型训练 - 针对T4优化 +echo "📈 运行基础训练 (深度12, 批次大小4)..." +torchrun --standalone --nproc_per_node=4 -m scripts.t4_train -- --run=$WANDB_RUN + +echo "📊 运行基础损失评估..." +torchrun --standalone --nproc_per_node=4 -m scripts.base_loss + +echo "📊 运行基础模型评估..." +torchrun --standalone --nproc_per_node=4 -m scripts.base_eval + +echo "🎯 开始中期训练..." 
+ +# 中期训练 - 针对T4优化 +echo "📈 运行中期训练 (批次大小2)..." +torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train -- --run=$WANDB_RUN + +echo "📊 运行中期训练评估..." +torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i mid + +echo "💬 开始SFT训练..." + +# SFT训练 - 针对T4优化 +echo "📈 运行SFT训练 (批次大小1)..." +torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft -- --run=$WANDB_RUN + +echo "📊 运行SFT评估..." +torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i sft + +echo "📋 生成最终报告..." +python -m nanochat.report generate + +echo "🎉 T4训练流程完成!" +echo "📊 查看报告: python -m nanochat.report show" +echo "💬 启动聊天界面: python -m scripts.chat_web" + +# 显示GPU使用情况 +echo "🔍 当前GPU状态:" +nvidia-smi --query-gpu=index,name,memory.used,memory.total,utilization.gpu --format=csv,noheader,nounits + +echo "✅ 所有训练步骤已完成!" diff --git a/scripts/t4_chat_sft.py b/scripts/t4_chat_sft.py new file mode 100644 index 0000000..025dbbe --- /dev/null +++ b/scripts/t4_chat_sft.py @@ -0,0 +1,290 @@ +""" +针对4个T4 GPU优化的SFT脚本 +基于chat_sft.py修改,专门为T4 GPU的16GB显存限制进行优化 + +运行方式: +torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft + +或者单GPU调试: +python -m scripts.t4_chat_sft +""" + +import os +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + +import wandb +import torch +import torch.distributed as dist +from contextlib import nullcontext + +from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type +from nanochat.checkpoint_manager import load_model +from nanochat.checkpoint_manager import save_checkpoint +from nanochat.engine import Engine +from scripts.chat_eval import run_chat_eval + +from tasks.common import TaskMixture +from tasks.arc import ARC +from tasks.gsm8k import GSM8K +from tasks.smoltalk import SmolTalk +from tasks.customjson import CustomJSON +from tasks.spellingbee import SimpleSpelling, SpellingBee + +# ----------------------------------------------------------------------------- +# T4 GPU优化的SFT配置 
+run = "t4_sft" # wandb run name +# input model options +source = "mid" # base|mid , which checkpoint to load the model from +model_tag = None # model tag to load the model from +step = None # step to load the model from +# compute/precision +device_type = "" # cuda|cpu|mps (empty => autodetect) +dtype = "bfloat16" +device_batch_size = 1 # 进一步减少批次大小以适应SFT的更大模型 +# optimization +num_epochs = 1 +num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it) +target_examples_per_step = 8 # 减少目标样本数 (原来32) +unembedding_lr = 0.004 +embedding_lr = 0.2 +matrix_lr = 0.02 +weight_decay = 0.0 +init_lr_frac = 0.02 +# evaluation and logging +eval_every = 50 # 更频繁的评估 (原来100) +eval_steps = 50 # 减少评估步数 (原来100) +eval_metrics_every = 100 # 更频繁的指标评估 (原来200) +eval_metrics_max_problems = 512 # 减少最大问题数 (原来1024) +# now allow CLI to override the settings via the configurator +config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] +exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file +user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging +# ----------------------------------------------------------------------------- + +# Compute init +device_type = autodetect_device_type() if device_type == "" else device_type +ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) +master_process = ddp_rank == 0 +ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16 +autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() + +# wandb logging init +use_dummy_wandb = run == "dummy" or not master_process +wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-t4-sft", name=run, config=user_config, save_code=True) + +# Load the model and tokenizer +model, tokenizer, meta = load_model(source, device, phase="train", 
model_tag=model_tag, step=step) +orig_model = model # original, uncompiled model +# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs +engine = Engine(model, tokenizer) # will be used for inline model evaluation only + +print0(f"T4 SFT配置:") +print0(f" device_batch_size: {device_batch_size}") +print0(f" target_examples_per_step: {target_examples_per_step}") +print0(f" DDP world size: {ddp_world_size}") + +# ----------------------------------------------------------------------------- +# Task data mixture we'll train on +identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl") +train_ds = TaskMixture([ + ARC(subset="ARC-Easy", split="train"), # 2.3K rows + ARC(subset="ARC-Challenge", split="train"), # 1.1K rows + GSM8K(subset="main", split="train"), # 8K rows + SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk + CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations + SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple') + SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) 
+]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
+val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
+
+# -----------------------------------------------------------------------------
+# DataLoader
+
+def sft_data_generator(dataset, batch_size):
+    """Infinite generator of (inputs, targets) batches for SFT.
+
+    Cycles over `dataset` forever. Each DDP rank strides the dataset by
+    ddp_world_size starting at ddp_rank, so ranks process disjoint documents.
+    Conversations are tokenized via tokenizer.render_conversation and collated
+    into right-padded long tensors of shape (batch_size, max_len_in_batch - 1)
+    on `device`; target positions whose mask is 0 are set to -1 (ignore index)
+    so padding and non-assistant tokens contribute no loss.
+    """
+    pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
+    # prepares a list of tokenized conversations into a batch and yields
+    def collate_and_yield(batch):
+        nrows = len(batch)
+        ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
+        inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
+        targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
+        for i, (ids, mask) in enumerate(batch):
+            n = len(ids)
+            ids_tensor = torch.tensor(ids, dtype=torch.long)
+            inputs[i, :n-1] = ids_tensor[:-1]
+            # recall -1 is the ignore index, so mask out targets where mask is 0
+            row_targets = ids_tensor[1:]
+            # mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok
+            mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
+            row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0
+            targets[i, :n-1] = row_targets
+        inputs = inputs.to(device) # move to device
+        targets = targets.to(device)
+        return inputs, targets
+    # iterates over the dataset in epochs, tokenizes
+    batch = []
+    while True:
+        for i in range(ddp_rank, len(dataset), ddp_world_size):
+            doc = dataset[i]
+            ids, mask = tokenizer.render_conversation(doc)
+            batch.append((ids, mask))
+            if len(batch) == batch_size:
+                yield collate_and_yield(batch)
+                batch = []
+
+# Examples processed per optimizer step across all ranks (before grad accum).
+examples_per_step = device_batch_size * ddp_world_size
+print0(f"Target examples per step: {target_examples_per_step}")
+print0(f"Device batch size: {device_batch_size}")
+print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
+assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
+# Gradient accumulation bridges the gap between what fits on the GPUs per
+# micro-batch and the desired effective batch size (target_examples_per_step).
+grad_accum_steps = target_examples_per_step // examples_per_step
+print0(f"=> Setting grad accum steps: {grad_accum_steps}")
+
+if num_iterations == -1:
+    # derive num_iterations from num_epochs and the size of the dataset
+    assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
+    num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
+train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
+# val loader is rebuilt fresh at each evaluation so it always starts from the top
+build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
+
+# -----------------------------------------------------------------------------
+# Initialize the Optimizer
+
+optimizers = model.setup_optimizers(
+    unembedding_lr=unembedding_lr,
+    embedding_lr=embedding_lr,
+    matrix_lr=matrix_lr,
+    weight_decay=weight_decay,
+)
+# Set the initial learning rate as a fraction of the base learning rate
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["lr"] = group["lr"] * init_lr_frac
+        group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
+
+# -----------------------------------------------------------------------------
+# Training loop
+
+# Learning rate scheduler
+def get_lr_multiplier(it):
+    """Linear decay: multiplier goes from 1.0 at step 0 toward 0.0 at num_iterations."""
+    lrm = 1.0 - it / num_iterations
+    return lrm
+
+# Go!
+step = 0 +train_iter = iter(train_loader) +for step in range(num_iterations): + last_step = step == num_iterations - 1 + + # evaluate the validation loss + if last_step or step % eval_every == 0: + model.eval() + val_iter = iter(build_val_loader()) + losses = [] + for _ in range(eval_steps): + val_inputs, val_targets = next(val_iter) + with torch.no_grad(), autocast_ctx: + loss = model(val_inputs, val_targets) + losses.append(loss) + val_loss = torch.stack(losses).mean() # average over eval_steps + if ddp: + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks + val_loss = val_loss.item() + print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}") + wandb_run.log({ + "step": step, + "val_loss": val_loss, + }) + model.train() + + # evaluate accuracy of the multiple choice tasks (which are quick to run) + if last_step or (step > 0 and step % eval_metrics_every == 0): + model.eval() + metrics = {} + with torch.no_grad(), autocast_ctx: + # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size + metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems) + metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems) + metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items()) + print0(f"Step {step:05d} | {metrics_str}") + wandb_run.log({ + "step": step, + **metrics, + }) + model.train() + + if last_step: + break + + # evaluate the gradient + num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen + for micro_step in range(grad_accum_steps): + train_inputs, train_targets = next(train_iter) + with autocast_ctx: + loss = model(train_inputs, train_targets) + train_loss = loss.detach() # for logging + loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here + 
loss.backward() # accumulate the gradient + num_tokens += (train_targets >= 0).sum() + if ddp: + dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks + + # learning rate scheduler + lrm = get_lr_multiplier(step) + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * lrm + + # step the optimizers + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) + + # logging + train_loss_item = train_loss.item() + num_tokens_item = num_tokens.item() + print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}") + wandb_run.log({ + "step": step, + "lrm": lrm, + "train_loss": train_loss_item, + "num_tokens": num_tokens_item, + }) + step += 1 + +# Save the model at the end of the run +if master_process: + base_dir = get_base_dir() + depth = model.config.n_layer + model_tag = f"t4_d{depth}" # base the model tag on the depth of the base model + checkpoint_dir = os.path.join(base_dir, "t4_chatsft_checkpoints", model_tag) + model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer + save_checkpoint( + checkpoint_dir, + step, + model.state_dict(), + None, # note: we don't bother to save the optimizer state + { + "step": step, + "val_loss": val_loss, + **metrics, + "model_config": model_config_kwargs, + } + ) + print0(f"✅ T4 SFT完成,模型保存到: {checkpoint_dir}") + +# Log to report +from nanochat.report import get_report +get_report().log(section="T4 Chat SFT", data=[ + user_config, + { + "Training rows": len(train_ds), + "Number of iterations": num_iterations, + "Training loss": train_loss_item, + "Validation loss": val_loss, + }, +]) + +# Cleanup +wandb_run.finish() +compute_cleanup() diff --git a/scripts/t4_mid_train.py b/scripts/t4_mid_train.py new file mode 100644 index 0000000..bf36b36 --- /dev/null +++ b/scripts/t4_mid_train.py @@ -0,0 +1,314 @@ +""" +针对4个T4 GPU优化的midtrain脚本 
+基于mid_train.py修改,专门为T4 GPU的16GB显存限制进行优化 + +运行方式: +torchrun --standalone --nproc_per_node=4 -m scripts.t4_mid_train + +或者单GPU调试: +python -m scripts.t4_mid_train +""" + +from collections import deque +import os +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import time +import wandb +import torch +from contextlib import nullcontext +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type +from nanochat.tokenizer import get_token_bytes +from nanochat.checkpoint_manager import save_checkpoint +from nanochat.loss_eval import evaluate_bpb +from nanochat.checkpoint_manager import load_model +import torch.distributed as dist + +from tasks.common import TaskMixture +from tasks.gsm8k import GSM8K +from tasks.mmlu import MMLU +from tasks.smoltalk import SmolTalk +from tasks.customjson import CustomJSON +from tasks.spellingbee import SimpleSpelling, SpellingBee + +# ----------------------------------------------------------------------------- +# T4 GPU优化配置 +run = "t4_midtraining" # wandb run name +device_type = "" # cuda|cpu|mps (empty => autodetect) +model_tag = None # model tag to load the model from (base model or midtrained model) +step = None # step to load the model from (base model or midtrained model) +dtype = "bfloat16" +num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) +max_seq_len = 1024 # 减少序列长度以节省显存 +device_batch_size = 2 # 进一步减少批次大小以适应midtrain的更大模型 +unembedding_lr = 0.004 +embedding_lr = 0.2 +matrix_lr = 0.02 +init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate +weight_decay = 0.0 +eval_every = 75 # 更频繁的评估 (原来150) +eval_tokens = 5*131072 # 减少评估token数量 +total_batch_size = 65536 # 减少总批次大小 +dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report +config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 
+exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file +user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging +# ----------------------------------------------------------------------------- + +# Compute init +device_type = autodetect_device_type() if device_type == "" else device_type +ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) +master_process = ddp_rank == 0 +autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() +synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None +get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 + +# wandb logging init +use_dummy_wandb = run == "dummy" or not master_process +wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-t4-mid", name=run, config=user_config) + +# Load the model and tokenizer +model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step) +pretrain_batch_size = meta.get("device_batch_size", None) +if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size: + print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?") +orig_model = model +model = torch.compile(model, dynamic=False) +depth = model.config.n_layer +num_flops_per_token = model.estimate_flops() +tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank +world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks +assert total_batch_size % world_tokens_per_fwdbwd == 0 +grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd +print0(f"T4 Midtrain配置:") +print0(f" Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") 
+print0(f" Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") +print0(f" Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") +print0(f" DDP world size: {ddp_world_size}") +token_bytes = get_token_bytes(device=device) + +# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) +optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) +adamw_optimizer, muon_optimizer = optimizers +# Override the initial learning rate as a fraction of the base learning rate +for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["lr"] * init_lr_frac + group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later + +# Midtraining data mixture and DataLoader +base_dir = get_base_dir() +identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") +train_dataset = TaskMixture([ + SmolTalk(split="train"), # 460K rows of general conversations + MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE + GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use + CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations + CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these + SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') + SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) 
+]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows +val_dataset = TaskMixture([ + SmolTalk(split="test"), # 24K rows in test set + MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios + GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios +]) # total: 24K + 14K + 1.32K ~= 39K rows + +# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len) +last_step = False # we will toggle this to True when we reach the end of the dataset +approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch +def mid_data_generator(split): + global last_step, approx_progress + assert split in {"train", "val"}, "split must be 'train' or 'val'" + dataset = train_dataset if split == "train" else val_dataset + dataset_size = len(dataset) + assert dataset_size > 0 + needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets + token_buffer = deque() + # CUDA supports memory pinning for faster transfers between CPU and GPU: + scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda")) + cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents + it = 0 # iteration counter + while True: + # Accumulate enough tokens for one iteration before yielding + while len(token_buffer) < needed_tokens: + conversation = dataset[cursor] + ids, _ = tokenizer.render_conversation(conversation) + token_buffer.extend(ids) + cursor += ddp_world_size + if cursor >= dataset_size: + cursor -= dataset_size # wrap around for another epoch + if split == "train": + last_step = True # toggle last_step to True, which will terminate the training loop + # Stopping condition to respect num_iterations, if given + it += 1 + if num_iterations > 0 and it >= num_iterations: + last_step = True # toggle last_step to True, which will terminate 
the training loop + # Build up inputs/targets and yield + for i in range(needed_tokens): + scratch[i] = token_buffer.popleft() + inputs_cpu = scratch[:-1].to(dtype=torch.int32) + targets_cpu = scratch[1:] + inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True) + targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True) + if split == "train": + if num_iterations > 0: + approx_progress = it / num_iterations # calculate progress from the max number of iterations + else: + approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset + yield inputs, targets + +train_loader = mid_data_generator("train") +build_val_loader = lambda: mid_data_generator("val") +progress = 0 # will go from 0 to 1 over the course of the epoch + +# Learning rate scheduler +def get_lr_multiplier(progress): + # first 80% of training: no decay, then linearly ramp down to 0. + return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2 + +# Momentum scheduler for Muon optimizer +def get_muon_momentum(it): + frac = min(it / 300, 1) + momentum = (1 - frac) * 0.85 + frac * 0.95 + return momentum + +# ----------------------------------------------------------------------------- +# Training loop +x, y = next(train_loader) # prefetch the very first batch of data +min_val_bpb = float("inf") +smooth_train_loss = 0 # EMA of training loss +ema_beta = 0.9 # EMA decay factor +total_training_time = 0 # total wall-clock time of training +step = 0 +while True: + flops_so_far = num_flops_per_token * total_batch_size * step + + # Synchronize last_step across all ranks to avoid hangs in the distributed setting + if ddp: + last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device) + dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX) + last_step = bool(last_step_tensor.item()) + + # once in a while: evaluate the val bpb (all ranks participate) + if 
eval_every > 0 and (last_step or step % eval_every == 0): + model.eval() + val_loader = build_val_loader() + eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size) + with autocast_ctx: + val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) + print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") + if val_bpb < min_val_bpb: + min_val_bpb = val_bpb + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "val/bpb": val_bpb, + }) + model.train() + + # save checkpoint at the end of the run (only on master process) + if master_process and last_step and not dry_run: + output_dirname = f"t4_d{depth}" + checkpoint_dir = os.path.join(base_dir, "t4_mid_checkpoints", output_dirname) + save_checkpoint( + checkpoint_dir, + step, + orig_model.state_dict(), + [opt.state_dict() for opt in optimizers], + { + "step": step, + "val_bpb": val_bpb, + "model_config": { + "sequence_len": max_seq_len, + "vocab_size": tokenizer.get_vocab_size(), + "n_layer": depth, + "n_head": model.config.n_head, + "n_kv_head": model.config.n_kv_head, + "n_embd": model.config.n_embd, + }, + "user_config": user_config, + } + ) + print0(f"✅ T4 Midtrain完成,模型保存到: {checkpoint_dir}") + + if last_step: + break + + # ------------------------------------------------------------------------- + # single training step + synchronize() + t0 = time.time() + for micro_step in range(grad_accum_steps): + with autocast_ctx: + loss = model(x, y) + train_loss = loss.detach() + loss = loss / grad_accum_steps + loss.backward() + x, y = next(train_loader) + progress = max(progress, approx_progress) + + # step the optimizers + lrm = get_lr_multiplier(progress) + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * lrm + muon_momentum = get_muon_momentum(step) + for group in muon_optimizer.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + opt.step() + 
model.zero_grad(set_to_none=True) + synchronize() + t1 = time.time() + dt = t1 - t0 + + # State + step += 1 + + # logging + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) + pct_done = 100 * progress + tok_per_sec = int(world_tokens_per_fwdbwd / dt) + flops_per_sec = num_flops_per_token * total_batch_size / dt + # T4的峰值性能约为65 TFLOPS (bfloat16) + promised_flops_per_sec_t4 = 65e12 * ddp_world_size + mfu = 100 * flops_per_sec / promised_flops_per_sec_t4 + if step > 10: + total_training_time += dt + print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") + if step % 10 == 0: + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "train/loss": debiased_smooth_loss, + "train/lrm": lrm, + "train/dt": dt, + "train/tok_per_sec": tok_per_sec, + "train/mfu": mfu, + }) + +# print final stats +print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") +print0(f"Total training time: {total_training_time/60:.2f}m") +print0(f"Minimum validation bpb: {min_val_bpb:.4f}") + +# Log to report +if not dry_run: + from nanochat.report import get_report + get_report().log(section="T4 Midtraining", data=[ + user_config, + { + "Number of iterations": step, + "DDP world size": ddp_world_size, + }, + { + "Minimum validation bpb": min_val_bpb, + } + ]) + +# cleanup +wandb_run.finish() +compute_cleanup() diff --git a/scripts/t4_train.py b/scripts/t4_train.py new file mode 100644 index 0000000..987202f --- /dev/null +++ b/scripts/t4_train.py @@ -0,0 +1,350 @@ +""" +针对4个T4 GPU优化的训练脚本 +基于base_train.py修改,专门为T4 GPU的16GB显存限制进行优化 + +运行方式: +torchrun --standalone --nproc_per_node=4 -m scripts.t4_train + +或者单GPU调试: +python -m scripts.t4_train +""" + +import os 
# Ask the CUDA caching allocator for expandable segments to reduce fragmentation
# on the 16GB T4s; must be set before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
from contextlib import nullcontext

import wandb
import torch

from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from scripts.base_eval import evaluate_model
print_banner()

# -----------------------------------------------------------------------------
# T4 GPU optimized configuration
run = "t4_training" # wandb run name
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect)
# Model architecture - tuned for T4
depth = 12 # reduced depth to fit the T4 memory limit (originally 20-32)
max_seq_len = 1024 # reduced sequence length to save memory (originally 2048)
# Training horizon
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20)
# Optimization - tuned for T4
device_batch_size = 4 # greatly reduced batch size for the T4's 16GB memory (originally 32)
total_batch_size = 131072 # reduced total batch size (originally 524288)
embedding_lr = 0.2
unembedding_lr = 0.004
weight_decay = 0.0
matrix_lr = 0.02
grad_clip = 1.0
warmup_ratio = 0.0
warmdown_ratio = 0.2
final_lr_frac = 0.0
# Evaluation - tuned for T4
eval_every = 100 # more frequent evaluation (originally 250)
eval_tokens = 5*131072 # fewer evaluation tokens (originally 20*524288)
core_metric_every = 1000 # more frequent CORE metric evaluation (originally 2000)
core_metric_max_per_task = 250 # fewer max samples per task (originally 500)
sample_every = 1000 # more frequent sampling (originally 2000)
# Output
model_tag = "t4_d12" # model tag for the T4 run
# now allow CLI to override the settings via the configurator
# NOTE: the configurator mutates these module-level globals by name, so the
# variable names above are part of the CLI interface — do not rename them.
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------

# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
# Autocast / sync / memory helpers degrade to no-ops off CUDA so the script
# also runs on cpu/mps for debugging.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-t4", name=run, config=user_config)

# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")

# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (consistent with the original config)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128
num_kv_heads = num_heads # default is 1:1 GQA ratio
print0(f"T4优化配置:")
print0(f" num_layers: {num_layers}")
print0(f" model_dim: {model_dim}")
print0(f" num_heads: {num_heads}")
print0(f" num_kv_heads: {num_kv_heads}")
print0(f" max_seq_len: {max_seq_len}")
print0(f" device_batch_size: {device_batch_size}")

# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
print0(f"DDP world size: {ddp_world_size}")

# -----------------------------------------------------------------------------
# Initialize the Model
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
# Build on the meta device first so no memory is allocated until to_empty().
with torch.device("meta"):
    model_config = GPTConfig(**model_config_kwargs)
    model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False)
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# Calculate number of iterations: explicit count > FLOPs budget > data:param ratio
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
    num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
    target_tokens = target_param_data_ratio * num_params
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}")
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers

# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data") # NOTE(review): unused below — the loaders take no path argument; confirm whether it can be dropped
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data

# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers

# Learning rate scheduler: optional warmup, flat middle, linear warmdown to
# final_lr_frac. With warmup_ratio == 0 the first branch is never taken, so the
# division by warmup_iters == 0 is unreachable.
def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac

# Momentum scheduler for Muon optimizer: ramp 0.85 -> 0.95 over the first 300 steps
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum

# -----------------------------------------------------------------------------
# Training loop
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
    last_step = step == num_iterations
    flops_so_far = num_flops_per_token * total_batch_size * step

    # once in a while: evaluate the val bpb (all ranks participate)
    # NOTE: assumes eval_every > 0 (default 100); eval_every == 0 would raise
    # ZeroDivisionError here (the mid-train script guards this, this one does not).
    if last_step or step % eval_every == 0:
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # once in a while: estimate the CORE metric (all ranks participate)
    results = {}
    if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
        model.eval()
        with autocast_ctx:
            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "core_metric": results["core_metric"],
            "centered_results": results["centered_results"],
        })
        model.train()

    # once in a while: sample from the model (only on master process)
    # Greedy decoding (temperature=0) on a few fixed sanity-check prompts.
    if master_process and (last_step or (step > 0 and step % sample_every == 0)):
        model.eval()
        prompts = [
            "The capital of France is",
            "The chemical symbol of gold is",
            "If yesterday was Friday, then tomorrow will be",
            "The opposite of hot is",
            "The planets of the solar system are:",
            "My favorite color is",
            "If 5*x + 3 = 13, then x is",
        ]
        engine = Engine(orig_model, tokenizer)
        for prompt in prompts:
            tokens = tokenizer(prompt, prepend="<|bos|>")
            with autocast_ctx:
                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
            print0(tokenizer.decode(sample[0]))
        model.train()

    # save checkpoint at the end of the run (only on master process)
    # val_bpb is guaranteed to exist here: last_step also triggers the eval above.
    if master_process and last_step:
        output_dirname = model_tag if model_tag else f"t4_d{depth}"
        checkpoint_dir = os.path.join(base_dir, "t4_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            [opt.state_dict() for opt in optimizers],
            {
                "step": step,
                "val_bpb": val_bpb,
                "model_config": model_config_kwargs,
                "user_config": user_config,
                "device_batch_size": device_batch_size,
                "max_seq_len": max_seq_len,
            }
        )
        print0(f"✅ T4训练完成,模型保存到: {checkpoint_dir}")

    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step: grad_accum_steps micro-batches, then one optimizer step.
    # The next batch is fetched right after backward() so the load overlaps compute.
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()
        loss = loss / grad_accum_steps
        loss.backward()
        x, y = next(train_loader)

    # gradient clipping (on the uncompiled module; it shares parameters with `model`)
    if grad_clip > 0.0:
        torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)

    # step the optimizers
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    muon_momentum = get_muon_momentum(step)
    for group in muon_optimizer.param_groups:
        group["momentum"] = muon_momentum
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0

    # logging
    # EMA of the raw train loss; debias by (1 - beta^(n_updates)) where the EMA
    # has been updated step+1 times at this point (step not yet incremented).
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(world_tokens_per_fwdbwd / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    # T4 peak throughput is roughly 65 TFLOPS in bfloat16
    promised_flops_per_sec_t4 = 65e12 * ddp_world_size
    mfu = 100 * flops_per_sec / promised_flops_per_sec_t4
    if step > 10: # exclude the first steps (compile/warmup) from the wall-clock total
        total_training_time += dt
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    if step % 50 == 0: # more frequent wandb logging
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        })

# print final stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to report
# val_bpb / mfu / flops_so_far are defined because the loop always runs the
# last_step iteration before breaking.
from nanochat.report import get_report
get_report().log(section="T4 Base model training", data=[
    user_config,
    {
        "Number of parameters": num_params,
        "Number of FLOPs per token": f"{num_flops_per_token:e}",
        "Calculated number of iterations": num_iterations,
        "Number of training tokens": total_tokens,
        "Tokens : Params ratio": total_batch_size * num_iterations / num_params,
        "DDP world size": ddp_world_size,
        "warmup_ratio": warmup_ratio,
        "warmdown_ratio": warmdown_ratio,
        "final_lr_frac": final_lr_frac,
    },
    {
        "Minimum validation bpb": min_val_bpb,
        "Final validation bpb": val_bpb,
        "CORE metric estimate": results.get("core_metric", None),
        "MFU %": f"{mfu:.2f}%",
        "Total training flops": f"{flops_so_far:e}",
        "Total training time": f"{total_training_time/60:.2f}m",
        "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
    }
])

# cleanup
wandb_run.finish()
compute_cleanup()