nanochat/scripts/t4_train.py

351 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Training script optimized for 4x T4 GPUs.
Adapted from base_train.py and tuned specifically for the 16GB memory limit of the T4 GPU.
To run:
torchrun --standalone --nproc_per_node=4 -m scripts.t4_train
Or, for single-GPU debugging:
python -m scripts.t4_train
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
from contextlib import nullcontext
import wandb
import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from scripts.base_eval import evaluate_model
print_banner()
# -----------------------------------------------------------------------------
# T4 GPU-optimized configuration
# NOTE: config_keys (computed at the end of this section) picks up every module-level
# int/float/bool/str defined up to that point whose name does not start with '_', so any
# such global added in this section becomes part of config_keys and user_config.
run = "t4_training" # wandb run name
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect)
# Model architecture - tuned for T4
depth = 12 # reduced depth to fit the T4 memory limit (originally 20-32)
max_seq_len = 1024 # reduced sequence length to save memory (originally 2048)
# Training horizon
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20)
# Optimization - tuned for T4
device_batch_size = 4 # greatly reduced batch size to fit the T4's 16GB of memory (originally 32)
total_batch_size = 131072 # reduced total batch size (originally 524288)
embedding_lr = 0.2
unembedding_lr = 0.004
weight_decay = 0.0
matrix_lr = 0.02
grad_clip = 1.0
warmup_ratio = 0.0
warmdown_ratio = 0.2
final_lr_frac = 0.0
# Evaluation - tuned for T4
eval_every = 100 # evaluate more frequently (originally 250)
eval_tokens = 5*131072 # fewer evaluation tokens (originally 20*524288)
core_metric_every = 1000 # evaluate the CORE metric more frequently (originally 2000)
core_metric_max_per_task = 250 # fewer max examples per task (originally 500)
sample_every = 1000 # sample more frequently (originally 2000)
# Output
model_tag = "t4_d12" # model tag for the T4 training run
# now allow CLI to override the settings via the configurator
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
# NOTE(review): the path is relative, so it is resolved against the current working directory,
# and the file object returned by open() is never explicitly closed.
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
# Compute init
# An empty device_type means "autodetect"; compute_init then returns the DDP topology and device.
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
# Autocast to bfloat16 on CUDA only; every other device type runs without autocast.
# NOTE(review): T4 (Turing) is generally documented as lacking native bf16 tensor-core
# support -- confirm that bfloat16 (rather than float16) is intended here.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# No-op fallbacks so the training loop can call these unconditionally on non-CUDA devices.
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
# A real wandb run is created only on the master process and only when run != "dummy".
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-t4", name=run, config=user_config)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (kept consistent with the original config)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128
num_kv_heads = num_heads # default is 1:1 GQA ratio
print0(f"T4优化配置:")
print0(f" num_layers: {num_layers}")
print0(f" model_dim: {model_dim}")
print0(f" num_heads: {num_heads}")
print0(f" num_kv_heads: {num_kv_heads}")
print0(f" max_seq_len: {max_seq_len}")
print0(f" device_batch_size: {device_batch_size}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
# total_batch_size must be an exact multiple of one world-wide micro-batch
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
print0(f"DDP world size: {ddp_world_size}")
# -----------------------------------------------------------------------------
# Build the model
# Architecture hyperparameters handed to GPTConfig
model_config_kwargs = {
    "sequence_len": max_seq_len,
    "vocab_size": vocab_size,
    "n_layer": num_layers,
    "n_head": num_heads,
    "n_kv_head": num_kv_heads,
    "n_embd": model_dim,
}
# Construct the module under the "meta" device, then allocate its storage on the
# real device and run the model's own weight initialization.
with torch.device("meta"):
    model_config = GPTConfig(**model_config_kwargs)
    model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
# original, uncompiled model, for saving raw model state_dict
orig_model = model
model = torch.compile(model, dynamic=False)
# Report model size and the per-token FLOPs estimate
num_params = sum(param.numel() for param in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
# Decide how many optimization steps to run.
# Precedence: an explicit num_iterations wins, then target_flops, then target_param_data_ratio.
assert any(horizon > 0 for horizon in (num_iterations, target_param_data_ratio, target_flops))
if num_iterations > 0:
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
    # steps = FLOPs budget / FLOPs spent per optimizer step
    num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
    # steps = (tokens-per-parameter ratio * parameter count) / tokens per optimizer step
    target_tokens = target_param_data_ratio * num_params
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
# Summarize the resulting training budget
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_tokens / num_params:.2f}")
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
# Learning rate scheduler
def get_lr_multiplier(it):
    """Return the learning-rate multiplier for optimization step `it`.

    Phase lengths come from the module-level `warmup_ratio`, `warmdown_ratio`
    and `num_iterations`:
      * warmup: linear ramp, (it + 1) / warmup steps
      * plateau: constant 1.0
      * warmdown: linear interpolation from 1.0 toward `final_lr_frac`
    """
    n_warmup = round(warmup_ratio * num_iterations)
    n_warmdown = round(warmdown_ratio * num_iterations)
    if it < n_warmup:
        return (it + 1) / n_warmup
    if it <= num_iterations - n_warmdown:
        return 1.0
    # fraction of the warmdown phase that is still ahead (1.0 at its start, 0.0 at the end)
    remaining = (num_iterations - it) / n_warmdown
    return remaining * 1.0 + (1 - remaining) * final_lr_frac
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    """Return the Muon momentum for step `it`.

    Linearly blends from 0.85 at step 0 to 0.95 at step 300 and stays at 0.95 afterwards.
    """
    ramp = min(it / 300, 1)  # blend weight, capped at 1
    return (1 - ramp) * 0.85 + ramp * 0.95
# -----------------------------------------------------------------------------
# Training loop
# State carried across training steps
min_val_bpb = float("inf") # lowest validation bpb seen so far
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
    last_step = step == num_iterations
    flops_so_far = num_flops_per_token * total_batch_size * step

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or step % eval_every == 0:
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # once in a while: estimate the CORE metric (all ranks participate)
    # (runs on the uncompiled orig_model; `results` is reset every step so that the
    # final report only sees a CORE value if one was computed on the last step)
    results = {}
    if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
        model.eval()
        with autocast_ctx:
            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "core_metric": results["core_metric"],
            "centered_results": results["centered_results"],
        })
        model.train()

    # once in a while: sample from the model (only on master process)
    if master_process and (last_step or (step > 0 and step % sample_every == 0)):
        model.eval()
        prompts = [
            "The capital of France is",
            "The chemical symbol of gold is",
            "If yesterday was Friday, then tomorrow will be",
            "The opposite of hot is",
            "The planets of the solar system are:",
            "My favorite color is",
            "If 5*x + 3 = 13, then x is",
        ]
        engine = Engine(orig_model, tokenizer)
        for prompt in prompts:
            tokens = tokenizer(prompt, prepend="<|bos|>")
            with autocast_ctx:
                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
            print0(tokenizer.decode(sample[0]))
        model.train()

    # save checkpoint at the end of the run (only on master process)
    if master_process and last_step:
        output_dirname = model_tag if model_tag else f"t4_d{depth}"
        checkpoint_dir = os.path.join(base_dir, "t4_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            [opt.state_dict() for opt in optimizers],
            {
                "step": step,
                "val_bpb": val_bpb,
                "model_config": model_config_kwargs,
                "user_config": user_config,
                "device_batch_size": device_batch_size,
                "max_seq_len": max_seq_len,
            }
        )
        print0(f"✅ T4训练完成模型保存到: {checkpoint_dir}")

    # the extra (+1) iteration only evaluates / samples / saves; it does not train
    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach() # kept for logging (value of the last micro-batch)
        loss = loss / grad_accum_steps # scale so the accumulated gradient is an average over micro-batches
        loss.backward()
        x, y = next(train_loader) # fetch the next batch for the following micro-step
    # gradient clipping
    if grad_clip > 0.0:
        torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
    # step the optimizers
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    muon_momentum = get_muon_momentum(step)
    for group in muon_optimizer.param_groups:
        group["momentum"] = muon_momentum
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # correct the EMA's bias toward its 0 init
    pct_done = 100 * step / num_iterations
    # dt covers all grad_accum_steps micro-batches of this optimizer step, so the tokens
    # processed in dt are total_batch_size (= world_tokens_per_fwdbwd * grad_accum_steps),
    # matching the token count used for flops_per_sec below.
    tok_per_sec = int(total_batch_size / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    # T4 peak tensor-core throughput is about 65 TFLOPS, which is NVIDIA's FP16 figure
    # (not bfloat16; T4 is not documented as having native bf16 tensor cores)
    promised_flops_per_sec_t4 = 65e12 * ddp_world_size
    mfu = 100 * flops_per_sec / promised_flops_per_sec_t4
    # steps 0..10 are left out of the timing total (presumably to skip compile/warmup overhead)
    if step > 10:
        total_training_time += dt
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    if step % 50 == 0: # more frequent wandb logging
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        })
# Final summary on stdout
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Record the run in the project report: the user config, the derived training
# setup, and the final outcome, in that order.
from nanochat.report import get_report
report = get_report()
setup_summary = {
    "Number of parameters": num_params,
    "Number of FLOPs per token": f"{num_flops_per_token:e}",
    "Calculated number of iterations": num_iterations,
    "Number of training tokens": total_tokens,
    "Tokens : Params ratio": total_batch_size * num_iterations / num_params,
    "DDP world size": ddp_world_size,
    "warmup_ratio": warmup_ratio,
    "warmdown_ratio": warmdown_ratio,
    "final_lr_frac": final_lr_frac,
}
outcome_summary = {
    "Minimum validation bpb": min_val_bpb,
    "Final validation bpb": val_bpb,
    # None when no CORE evaluation ran on the final step
    "CORE metric estimate": results.get("core_metric"),
    "MFU %": f"{mfu:.2f}%",
    "Total training flops": f"{flops_so_far:e}",
    "Total training time": f"{total_training_time/60:.2f}m",
    "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
}
report.log(
    section="T4 Base model training",
    data=[user_config, setup_summary, outcome_summary],
)

# Shut down logging and the compute backend
wandb_run.finish()
compute_cleanup()