log grad norm during training

Nitish Pandey 2025-11-04 00:47:31 +05:30
parent a83646e098
commit 03939756bc

@@ -273,7 +273,7 @@ for step in range(num_iterations + 1):
     x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
     # gradient clipping (TODO possibly experiment with)
     if grad_clip > 0.0:
-        torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
+        grad_norm = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
     # step the optimizers
     lrm = get_lr_multiplier(step)
     for opt in optimizers:
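
The one-line change in this hunk relies on torch.nn.utils.clip_grad_norm_ returning the total gradient norm of the given parameters, computed before any clipping is applied, as a 0-dim tensor, so capturing the return value costs nothing extra. A minimal standalone sketch of that behavior (the Linear model and dummy batch are illustrative stand-ins, not part of this training script):

    import torch

    model = torch.nn.Linear(4, 2)                # stand-in for orig_model
    model(torch.randn(8, 4)).sum().backward()    # populate .grad on the parameters

    grad_clip = 1.0
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    print(f"grad_norm: {grad_norm.item():.5f}")  # pre-clip total norm, read out as a Python float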
@@ -300,13 +300,14 @@ for step in range(num_iterations + 1):
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
-    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
+    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | grad_norm: {grad_norm.item():.5f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
     if step % 100 == 0:
         wandb_run.log({
             "step": step,
             "total_training_flops": flops_so_far,
             "total_training_time": total_training_time,
             "train/loss": debiased_smooth_loss,
+            "train/grad_norm": grad_norm.item(),
             "train/lrm": lrm,
             "train/dt": dt,
             "train/tok_per_sec": tok_per_sec,