nanochat/scripts/t4_chat_sft.py

291 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
SFT script optimized for 4x T4 GPUs.
Adapted from chat_sft.py and tuned specifically for the 16GB memory limit of the T4 GPU.
How to run:
torchrun --standalone --nproc_per_node=4 -m scripts.t4_chat_sft
Or, for single-GPU debugging:
python -m scripts.t4_chat_sft
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
from tasks.common import TaskMixture
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
# SFT configuration tuned for T4 GPUs.
# Every module-level int/float/bool/str that exists when `config_keys` is built
# becomes a setting the configurator is allowed to override.
run = "t4_sft" # wandb run name
# input model options
source = "mid" # base|mid , which checkpoint to load the model from
model_tag = None # model tag to load the model from
step = None # step to load the model from
# compute/precision
device_type = "" # cuda|cpu|mps (empty => autodetect)
# NOTE(review): T4 is a pre-Ampere GPU; confirm that bfloat16 autocast is
# actually supported / desirable on the target PyTorch build.
dtype = "bfloat16"
device_batch_size = 1 # batch size reduced further so the larger SFT model fits
# optimization
num_epochs = 1
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
target_examples_per_step = 8 # fewer target examples per step (originally 32)
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.02
# evaluation and logging
eval_every = 50 # evaluate more often (originally 100)
eval_steps = 50 # fewer evaluation steps (originally 100)
eval_metrics_every = 100 # evaluate metrics more often (originally 200)
eval_metrics_max_problems = 512 # lower cap on the number of problems (originally 1024)
# now allow CLI to override the settings via the configurator
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
# Read the configurator source inside a `with` block so the file handle is
# closed deterministically, then execute it in this module's namespace.
# The helper names start with '_' and are bound after config_keys is built,
# so they never show up as config keys.
with open(os.path.join('nanochat', 'configurator.py')) as _configurator_file:
    _configurator_source = _configurator_file.read()
exec(_configurator_source) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# -----------------------------------------------------------------------------
# Runtime setup: device, distributed state, precision, logging, model
if device_type == "":
    # an empty device_type means "detect it for me"
    device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
# only the literal 'float32' selects full precision; any other value means bfloat16
if dtype == 'float32':
    ptdtype = torch.float32
else:
    ptdtype = torch.bfloat16
# autocast is only enabled on CUDA; every other device type gets a no-op context
if device_type == "cuda":
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
else:
    autocast_ctx = nullcontext()
# wandb: a real run is created only on the master process and only when run != "dummy"
use_dummy_wandb = run == "dummy" or not master_process
if use_dummy_wandb:
    wandb_run = DummyWandb()
else:
    wandb_run = wandb.init(project="nanochat-t4-sft", name=run, config=user_config, save_code=True)
# Model and tokenizer come from the checkpoint selected by source / model_tag / step
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
orig_model = model # handle to the original, uncompiled model
# torch.compile(model, dynamic=True) is deliberately not applied here:
# it doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # used only for inline model evaluation
# Echo the key batch settings once at startup
print0(f"T4 SFT配置:")
print0(f" device_batch_size: {device_batch_size}")
print0(f" target_examples_per_step: {target_examples_per_step}")
print0(f" DDP world size: {ddp_world_size}")
# -----------------------------------------------------------------------------
# Datasets: a mixture of tasks for training, SmolTalk test split for validation
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
_train_tasks = [
    ARC(subset="ARC-Easy", split="train"), # 2.3K rows
    ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
    GSM8K(subset="main", split="train"), # 8K rows
    SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
    CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
    SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]
# row total: 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K
train_ds = TaskMixture(_train_tasks)
# general conversations, 24K rows (only part of it is actually consumed)
val_ds = SmolTalk(split="test")
# -----------------------------------------------------------------------------
# DataLoader
def sft_data_generator(dataset, batch_size):
    """Yield (inputs, targets) batches of tokenized conversations, forever.

    Each rank walks the dataset with stride `ddp_world_size`, starting at
    `ddp_rank`, and restarts from the beginning once it runs off the end.
    A partially filled batch is carried over into the next pass rather than
    dropped.

    Args:
        dataset: indexable collection of conversations that supports `len()`.
        batch_size: number of conversations per yielded batch.

    Yields:
        A pair of long tensors moved to `device`, both shaped
        (batch_size, longest_conversation - 1). `inputs` is padded with the
        <|assistant_end|> token id; `targets` holds -1 (the ignore index)
        wherever the position is padding or the conversation mask is 0.
    """
    # padding with <|assistant_end|> is fine: padded positions are masked out of the loss
    pad_token_id = tokenizer.encode_special("<|assistant_end|>")

    def collate_and_yield(batch):
        # turn a list of (ids, mask) pairs into one padded inputs/targets pair
        longest = max(len(ids) for ids, mask in batch)
        shape = (len(batch), longest - 1) # a sequence of n tokens gives n-1 input/target pairs
        inputs = torch.full(shape, pad_token_id, dtype=torch.long)
        targets = torch.full(shape, -1, dtype=torch.long) # -1 is the ignore index
        for row, (ids, mask) in enumerate(batch):
            width = len(ids) - 1
            tokens = torch.tensor(ids, dtype=torch.long)
            inputs[row, :width] = tokens[:-1]
            # next-token labels; mask[1:] skips the BOS entry, which is never a target
            labels = tokens[1:].clone()
            keep = torch.tensor(mask[1:], dtype=torch.long)
            labels[keep == 0] = -1 # unsupervised positions are ignored by the loss
            targets[row, :width] = labels
        return inputs.to(device), targets.to(device)

    # endless epochs over this rank's shard of the dataset
    pending = []
    while True:
        for idx in range(ddp_rank, len(dataset), ddp_world_size):
            ids, mask = tokenizer.render_conversation(dataset[idx])
            pending.append((ids, mask))
            if len(pending) == batch_size:
                yield collate_and_yield(pending)
                pending = []
# Batch bookkeeping: one optimizer step consumes target_examples_per_step
# examples, made up of grad_accum_steps micro-batches on every rank.
examples_per_step = device_batch_size * ddp_world_size
print0(f"Target examples per step: {target_examples_per_step}")
print0(f"Device batch size: {device_batch_size}")
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
grad_accum_steps, _leftover_examples = divmod(target_examples_per_step, examples_per_step)
assert _leftover_examples == 0, "Target examples per step must be divisible by examples per step"
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
if num_iterations == -1:
    # no explicit override: derive the iteration count from epochs and dataset size
    assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
    _steps_per_epoch = len(train_ds) // target_examples_per_step
    num_iterations = _steps_per_epoch * num_epochs
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)


def build_val_loader():
    """Return a fresh validation generator that starts at the beginning of val_ds."""
    return sft_data_generator(val_ds, batch_size=device_batch_size)
# -----------------------------------------------------------------------------
# Build the optimizers from the configured learning rates
_optimizer_settings = dict(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay,
)
optimizers = model.setup_optimizers(**_optimizer_settings)
# Start every param group at init_lr_frac of its base learning rate and
# remember that value as "initial_lr" so the schedule can scale from it later
for optimizer in optimizers:
    for param_group in optimizer.param_groups:
        starting_lr = param_group["lr"] * init_lr_frac
        param_group["lr"] = starting_lr
        param_group["initial_lr"] = starting_lr
# -----------------------------------------------------------------------------
# Training loop
# Learning rate scheduler
def get_lr_multiplier(it):
    """Return the linear-decay LR multiplier for iteration `it`.

    The value is 1.0 at it == 0 and falls linearly, reaching 0.0 at
    it == num_iterations (read from module scope).
    """
    progress = it / num_iterations
    return 1.0 - progress
# Go!
# Pre-bind every name the post-loop save/report code reads. Without this, a
# run with num_iterations == 0 (loop body never executes) or == 1 (the loop
# breaks before the first optimizer step) dies with a NameError after the loop.
step = 0
val_loss = None
metrics = {}
train_loss_item = None
train_iter = iter(train_loader)
for step in range(num_iterations):
    last_step = step == num_iterations - 1

    # evaluate the validation loss
    if last_step or step % eval_every == 0:
        model.eval()
        val_iter = iter(build_val_loader()) # fresh generator => same validation batches every time
        losses = []
        for _ in range(eval_steps):
            val_inputs, val_targets = next(val_iter)
            with torch.no_grad(), autocast_ctx:
                loss = model(val_inputs, val_targets)
            losses.append(loss)
        val_loss = torch.stack(losses).mean() # average over eval_steps
        if ddp:
            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
        val_loss = val_loss.item()
        print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
        wandb_run.log({
            "step": step,
            "val_loss": val_loss,
        })
        model.train()

    # evaluate accuracy of the multiple choice tasks (which are quick to run)
    if last_step or (step > 0 and step % eval_metrics_every == 0):
        model.eval()
        metrics = {}
        with torch.no_grad(), autocast_ctx:
            # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
        metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
        print0(f"Step {step:05d} | {metrics_str}")
        wandb_run.log({
            "step": step,
            **metrics,
        })
        model.train()

    # the final iteration only evaluates; no parameter update is performed
    if last_step:
        break

    # evaluate the gradient, accumulated over grad_accum_steps micro-batches
    num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
    for micro_step in range(grad_accum_steps):
        train_inputs, train_targets = next(train_iter)
        with autocast_ctx:
            loss = model(train_inputs, train_targets)
        train_loss = loss.detach() # for logging (holds the last micro-batch's loss)
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        loss.backward() # accumulate the gradient
        num_tokens += (train_targets >= 0).sum() # targets of -1 are ignored positions
    if ddp:
        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks

    # learning rate scheduler: scale each group's stored initial_lr
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm

    # step the optimizers
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)

    # logging
    train_loss_item = train_loss.item()
    num_tokens_item = num_tokens.item()
    print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
    wandb_run.log({
        "step": step,
        "lrm": lrm,
        "train_loss": train_loss_item,
        "num_tokens": num_tokens_item,
    })
    # the for statement rebinds `step` on the next pass, so this does not skip iterations
    step += 1
# Persist the final weights; only the master process writes the checkpoint
if master_process:
    base_dir = get_base_dir()
    depth = model.config.n_layer
    model_tag = f"t4_d{depth}" # the tag is derived from the depth of the base model
    checkpoint_dir = os.path.join(base_dir, "t4_chatsft_checkpoints", model_tag)
    # shortcut: the config object's __dict__ is stored as the model config
    # (leans on GPTConfig being simple; TODO nicer)
    model_config_kwargs = model.config.__dict__
    checkpoint_meta = {
        "step": step,
        "val_loss": val_loss,
        **metrics,
        "model_config": model_config_kwargs,
    }
    optimizer_state = None # the optimizer state is deliberately not saved
    save_checkpoint(checkpoint_dir, step, model.state_dict(), optimizer_state, checkpoint_meta)
    print0(f"✅ T4 SFT完成模型保存到: {checkpoint_dir}")
# Record the run's settings and headline numbers in the project report
from nanochat.report import get_report
_report = get_report()
_run_summary = {
    "Training rows": len(train_ds),
    "Number of iterations": num_iterations,
    "Training loss": train_loss_item,
    "Validation loss": val_loss,
}
_report.log(section="T4 Chat SFT", data=[user_config, _run_summary])
# Tear down: close the wandb run first, then release the compute setup
wandb_run.finish()
compute_cleanup()