This commit is contained in:
Matt Langston 2026-05-05 01:34:07 -04:00 committed by GitHub
commit 43113e563a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 0 deletions

View File

@@ -234,6 +234,7 @@ def get_peak_flops(device_name: str) -> float:
(["grace blackwell"], 2.5e15),
(["b200"], 2.25e15),
(["b100"], 1.8e15),
(["gb10"], 209e12),
# NVIDIA Hopper
(["h200", "nvl"], 836e12),
(["h200", "pcie"], 836e12),

View File

@@ -245,6 +245,29 @@ def disable_fp8(model):
# Keep a handle to the uncompiled module: checkpoints save its raw state_dict, and
# inference/evaluation go through it directly (because the shapes may change there).
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
# Multi-node compile warmup. torch.compile is lazy, so if one node has cached kernels from a previous run and the other doesn't, the fast rank's first forward takes ms while the slow rank compiles for 30-120+s.
# DDP's async all-reduce lets the fast rank race ahead, finish the loop, and the slow rank's NCCL ops then lose their peer and the watchdog kills the job (see Edward Yang "State of torch.compile for training" Aug 2025).
# Fix: trigger compilation on all ranks here with a dummy fwd+bwd, then barrier. Only needed for multi-node (ddp_world_size > 1).
if ddp and ddp_world_size > 1:
# Dummy token batch of shape (1, sequence_len) drawn uniformly over the vocab.
# NOTE(review): warmup uses batch size 1, but the real training step presumably uses
# args.device_batch_size (see the grad-accum setup below). With dynamic=False,
# torch.compile guards on exact input shapes, so if device_batch_size != 1 the first
# real step will still trigger a recompile — confirm the intended warmup shape.
warmup_x = torch.randint(0, model_config.vocab_size, (1, model_config.sequence_len), device=device)
warmup_y = torch.randint(0, model_config.vocab_size, (1, model_config.sequence_len), device=device)
# 1) Compile the FP8 training graph (forward + backward)
warmup_loss = model(warmup_x, warmup_y)
warmup_loss.backward()
# Drop the dummy gradients so the warmup leaves no trace in the optimizer state.
model.zero_grad(set_to_none=True)
del warmup_loss
# 2) Compile the BF16 eval graph (forward only, FP8 disabled). disable_fp8 swaps Float8Linear -> Linear, which invalidates torch.compile's guards and forces recompilation.
# Without this warmup the recompile happens at the first eval step on top of full training memory - on UMA systems (DGX Spark) the spike can exceed physical DRAM and crash.
model.eval()
with torch.no_grad():
with disable_fp8(model):
model(warmup_x)
model.train()
del warmup_x, warmup_y
# Return the warmup allocations to the CUDA caching allocator before training starts.
torch.cuda.empty_cache()
# NOTE(review): barrier() assumes the default process group is initialized on all
# ranks by this point — holds as long as this runs after the usual DDP init.
dist.barrier()
print0("Multi-node compile warmup complete")
# -----------------------------------------------------------------------------
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
@@ -402,6 +425,8 @@ else:
# Restore running training metrics from the checkpoint's loop_state dict
# (these lines sit inside the resume branch; its header is outside this hunk).
min_val_bpb = loop_state["min_val_bpb"]
smooth_train_loss = loop_state["smooth_train_loss"]
total_training_time = loop_state["total_training_time"]
# mfu is only assigned in the training loop. Initialize here so the report below doesn't NameError when --resume-from-step == num_iterations and the loop body never executes (can happen on a homelab run resuming from its final checkpoint).
mfu = 0
# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank

View File

@@ -118,6 +118,19 @@ for name, fallback, source in [
# Keep the uncompiled module for checkpointing/eval; compile the training copy.
orig_model = model
model = torch.compile(model, dynamic=False)
# Multi-node compile warmup: same fix as in base_train.py, trigger kernel compilation on all ranks then barrier so nobody races ahead and trips the NCCL watchdog.
if ddp and ddp_world_size > 1:
# NOTE(review): as in base_train.py, the warmup batch is shape (1, sequence_len);
# with dynamic=False this graph will not match a device_batch_size > 1 training
# batch and a recompile will still occur at the first real step — confirm.
# NOTE(review): model.config is read through the torch.compile wrapper here —
# presumably attribute access is forwarded to the wrapped module; verify.
warmup_x = torch.randint(0, model.config.vocab_size, (1, model.config.sequence_len), device=device)
warmup_y = torch.randint(0, model.config.vocab_size, (1, model.config.sequence_len), device=device)
# Dummy fwd+bwd compiles the training graph; gradients are discarded afterwards.
warmup_loss = model(warmup_x, warmup_y)
warmup_loss.backward()
model.zero_grad(set_to_none=True)
del warmup_x, warmup_y, warmup_loss
torch.cuda.empty_cache()
# NOTE(review): unlike base_train.py, there is no second warmup for an FP8-disabled
# eval graph here — confirm this script never evaluates under disable_fp8, or the
# first eval will recompile on top of full training memory.
dist.barrier()
print0("Multi-node compile warmup complete")
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank