mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-08 08:49:53 +00:00
Merge ab89d04dca into dc54a1a307
This commit is contained in:
commit
43113e563a
|
|
@ -234,6 +234,7 @@ def get_peak_flops(device_name: str) -> float:
|
|||
(["grace blackwell"], 2.5e15),
|
||||
(["b200"], 2.25e15),
|
||||
(["b100"], 1.8e15),
|
||||
(["gb10"], 209e12),
|
||||
# NVIDIA Hopper
|
||||
(["h200", "nvl"], 836e12),
|
||||
(["h200", "pcie"], 836e12),
|
||||
|
|
|
|||
|
|
@ -245,6 +245,29 @@ def disable_fp8(model):
|
|||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
|
||||
# Multi-node compile warmup. torch.compile is lazy, so if one node has cached kernels from a previous run and the other doesn't, the fast rank's first forward takes ms while the slow rank compiles for 30-120+s.
|
||||
# DDP's async all-reduce lets the fast rank race ahead, finish the loop, and the slow rank's NCCL ops then lose their peer and the watchdog kills the job (see Edward Yang "State of torch.compile for training" Aug 2025).
|
||||
# Fix: trigger compilation on all ranks here with a dummy fwd+bwd, then barrier. Only needed for multi-node (ddp_world_size > 1).
|
||||
if ddp and ddp_world_size > 1:
|
||||
warmup_x = torch.randint(0, model_config.vocab_size, (1, model_config.sequence_len), device=device)
|
||||
warmup_y = torch.randint(0, model_config.vocab_size, (1, model_config.sequence_len), device=device)
|
||||
# 1) Compile the FP8 training graph (forward + backward)
|
||||
warmup_loss = model(warmup_x, warmup_y)
|
||||
warmup_loss.backward()
|
||||
model.zero_grad(set_to_none=True)
|
||||
del warmup_loss
|
||||
# 2) Compile the BF16 eval graph (forward only, FP8 disabled). disable_fp8 swaps Float8Linear -> Linear, which invalidates torch.compile's guards and forces recompilation.
|
||||
# Without this warmup the recompile happens at the first eval step on top of full training memory - on UMA systems (DGX Spark) the spike can exceed physical DRAM and crash.
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
with disable_fp8(model):
|
||||
model(warmup_x)
|
||||
model.train()
|
||||
del warmup_x, warmup_y
|
||||
torch.cuda.empty_cache()
|
||||
dist.barrier()
|
||||
print0("Multi-node compile warmup complete")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
|
||||
|
||||
|
|
@ -402,6 +425,8 @@ else:
|
|||
min_val_bpb = loop_state["min_val_bpb"]
|
||||
smooth_train_loss = loop_state["smooth_train_loss"]
|
||||
total_training_time = loop_state["total_training_time"]
|
||||
# mfu is only assigned in the training loop. Initialize here so the report below doesn't NameError when --resume-from-step == num_iterations and the loop body never executes (can happen on a homelab run resuming from its final checkpoint).
|
||||
mfu = 0
|
||||
|
||||
# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
|
|
|
|||
|
|
@ -118,6 +118,19 @@ for name, fallback, source in [
|
|||
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
|
||||
# Multi-node compile warmup: same fix as in base_train.py, trigger kernel compilation on all ranks then barrier so nobody races ahead and trips the NCCL watchdog.
|
||||
if ddp and ddp_world_size > 1:
|
||||
warmup_x = torch.randint(0, model.config.vocab_size, (1, model.config.sequence_len), device=device)
|
||||
warmup_y = torch.randint(0, model.config.vocab_size, (1, model.config.sequence_len), device=device)
|
||||
warmup_loss = model(warmup_x, warmup_y)
|
||||
warmup_loss.backward()
|
||||
model.zero_grad(set_to_none=True)
|
||||
del warmup_x, warmup_y, warmup_loss
|
||||
torch.cuda.empty_cache()
|
||||
dist.barrier()
|
||||
print0("Multi-node compile warmup complete")
|
||||
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user