grad clip logging and printing and cosmetics

Andrej Karpathy 2025-11-05 21:08:30 +00:00
parent 885a4f25e7
commit c6b7ab7440


@@ -271,9 +271,11 @@ for step in range(num_iterations + 1):
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward()
         x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
-    # gradient clipping (TODO possibly experiment with)
-    if grad_clip > 0.0:
-        torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
+    # gradient clipping
+    grad_clip_enabled = grad_clip > 0.0
+    if grad_clip_enabled:
+        grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
+        grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
     # step the optimizers
     lrm = get_lr_multiplier(step)
     for opt in optimizers:
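
Note on the hunk above: torch.nn.utils.clip_grad_norm_ returns the total gradient norm computed before clipping, so the same call that clips also yields the value worth logging, and .item() on that tensor forces a GPU-to-CPU sync. A minimal standalone sketch of the pattern (not the nanochat training loop; model, optimizer, and max_grad_norm are illustrative placeholders):

import torch
import torch.nn as nn

# toy setup, placeholder names only
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
max_grad_norm = 1.0

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# clip_grad_norm_ scales gradients in place and returns the pre-clip total norm
# as a tensor; .item() converts it to a Python float (a CPU-GPU sync point when
# gradients live on the GPU).
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm).item()

optimizer.step()
optimizer.zero_grad(set_to_none=True)
print(f"grad norm (pre-clip): {grad_norm:.4f}")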
@@ -300,9 +302,10 @@ for step in range(num_iterations + 1):
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
-    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
+    print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
+    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
     if step % 100 == 0:
-        wandb_run.log({
+        log_data = {
             "step": step,
             "total_training_flops": flops_so_far,
             "total_training_time": total_training_time,
@@ -311,7 +314,10 @@ for step in range(num_iterations + 1):
             "train/dt": dt,
             "train/tok_per_sec": tok_per_sec,
             "train/mfu": mfu,
-        })
+        }
+        if grad_clip_enabled:
+            log_data["train/grad_norm"] = grad_norm
+        wandb_run.log(log_data)
 
 # print a few more stats
 print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")