multi-node torch.compile warmup, fixes NCCL watchdog timeouts

trigger compilation on every rank with a dummy fwd+bwd after torch.compile, then a barrier, before the training loop begins. guarded by ddp_world_size > 1. without this, if one node has cached kernels from a prior run and another does not, DDP's async all-reduce lets the fast rank race ahead and the slow rank's NCCL ops lose their peer and the watchdog kills the job (see Edward Yang, "State of torch.compile for training", Aug 2025).

the warmup also pre-compiles the BF16 eval graph (FP8 disabled) so the recompile triggered by disable_fp8 does not happen lazily under full training-memory pressure at the first eval step (can crash UMA systems like DGX Spark; relates to #446).

small drive-bys: GB10 added to the peak-FLOPS table for MFU reporting, and mfu=0 initialized before the loop to avoid NameError on the edge case where --resume-from-step == num_iterations.

context: https://github.com/karpathy/nanochat/discussions/710 (the writeup was produced from my dgx-spark branch at https://github.com/matt-langston/nanochat/tree/dgx-spark, which carries these two PRs plus a DGX-Spark-Bundle-specific speedrun script I kept separate)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Matt Langston 2026-04-17 19:06:14 -07:00
parent 0aaca56805
commit ab89d04dca
No known key found for this signature in database
GPG Key ID: 181CABA5854FEEC2
3 changed files with 39 additions and 0 deletions

View File

@ -234,6 +234,7 @@ def get_peak_flops(device_name: str) -> float:
(["grace blackwell"], 2.5e15),
(["b200"], 2.25e15),
(["b100"], 1.8e15),
(["gb10"], 209e12),
# NVIDIA Hopper
(["h200", "nvl"], 836e12),
(["h200", "pcie"], 836e12),

View File

@ -245,6 +245,29 @@ def disable_fp8(model):
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
# Multi-node compile warmup. torch.compile is lazy, so if one node has cached kernels from a previous run and the other doesn't, the fast rank's first forward takes ms while the slow rank compiles for 30-120+s.
# DDP's async all-reduce lets the fast rank race ahead, finish the loop, and the slow rank's NCCL ops then lose their peer and the watchdog kills the job (see Edward Yang "State of torch.compile for training" Aug 2025).
# Fix: trigger compilation on all ranks here with a dummy fwd+bwd, then barrier. Only needed for multi-node (ddp_world_size > 1).
if ddp and ddp_world_size > 1:
warmup_x = torch.randint(0, model_config.vocab_size, (1, model_config.sequence_len), device=device)
warmup_y = torch.randint(0, model_config.vocab_size, (1, model_config.sequence_len), device=device)
# 1) Compile the FP8 training graph (forward + backward)
warmup_loss = model(warmup_x, warmup_y)
warmup_loss.backward()
model.zero_grad(set_to_none=True)
del warmup_loss
# 2) Compile the BF16 eval graph (forward only, FP8 disabled). disable_fp8 swaps Float8Linear -> Linear, which invalidates torch.compile's guards and forces recompilation.
# Without this warmup the recompile happens at the first eval step on top of full training memory - on UMA systems (DGX Spark) the spike can exceed physical DRAM and crash.
model.eval()
with torch.no_grad():
with disable_fp8(model):
model(warmup_x)
model.train()
del warmup_x, warmup_y
torch.cuda.empty_cache()
dist.barrier()
print0("Multi-node compile warmup complete")
# -----------------------------------------------------------------------------
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
@ -402,6 +425,8 @@ else:
min_val_bpb = loop_state["min_val_bpb"]
smooth_train_loss = loop_state["smooth_train_loss"]
total_training_time = loop_state["total_training_time"]
# mfu is only assigned in the training loop. Initialize here so the report below doesn't NameError when --resume-from-step == num_iterations and the loop body never executes (can happen on a homelab run resuming from its final checkpoint).
mfu = 0
# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank

View File

@ -118,6 +118,19 @@ for name, fallback, source in [
orig_model = model
model = torch.compile(model, dynamic=False)
# Multi-node compile warmup: same fix as in base_train.py, trigger kernel compilation on all ranks then barrier so nobody races ahead and trips the NCCL watchdog.
if ddp and ddp_world_size > 1:
warmup_x = torch.randint(0, model.config.vocab_size, (1, model.config.sequence_len), device=device)
warmup_y = torch.randint(0, model.config.vocab_size, (1, model.config.sequence_len), device=device)
warmup_loss = model(warmup_x, warmup_y)
warmup_loss.backward()
model.zero_grad(set_to_none=True)
del warmup_x, warmup_y, warmup_loss
torch.cuda.empty_cache()
dist.barrier()
print0("Multi-node compile warmup complete")
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank