diff --git a/scripts/base_train.py b/scripts/base_train.py
index 4e02556..e28c9c7 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -274,6 +274,7 @@ for step in range(num_iterations + 1):
     # gradient clipping (TODO possibly experiment with)
     if grad_clip > 0.0:
         grad_norm = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
+        grad_norm = grad_norm.item() # convert the 0-dim tensor to a Python float once for logging below
     # step the optimizers
     lrm = get_lr_multiplier(step)
     for opt in optimizers:
@@ -300,19 +301,28 @@ for step in range(num_iterations + 1):
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
-    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | grad_norm: {grad_norm.item():.5f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
+    print0(
+        f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | "
+        f"loss: {debiased_smooth_loss:.6f} | "
+        + (f"grad_norm: {grad_norm:.5f} | " if grad_clip > 0.0 else "")
+        + f"lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | "
+        f"tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | "
+        f"total time: {total_training_time/60:.2f}m"
+    )
     if step % 100 == 0:
-        wandb_run.log({
+        log_data = {
             "step": step,
             "total_training_flops": flops_so_far,
             "total_training_time": total_training_time,
             "train/loss": debiased_smooth_loss,
-            "train/grad_norm": grad_norm.item(),
             "train/lrm": lrm,
             "train/dt": dt,
             "train/tok_per_sec": tok_per_sec,
             "train/mfu": mfu,
-        })
+        }
+        if grad_clip > 0.0:
+            log_data["train/grad_norm"] = grad_norm
+        wandb_run.log(log_data)

 # print a few more stats
 print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
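
For context, a minimal standalone sketch of the pattern this diff introduces: convert the clipped grad norm to a Python float once, then include it in the console line and the metrics dict only when clipping is enabled. The toy Linear model and the grad_clip value below are placeholders, not the script's real configuration.

import torch

# toy model and backward pass, only to produce some gradients (placeholder, not the real training setup)
model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

grad_clip = 1.0
if grad_clip > 0.0:
    # clip_grad_norm_ returns a 0-dim tensor; .item() converts it to a Python float once,
    # so later f-strings and the log dict can use it directly without another .item() call
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip).item()

log_data = {"train/loss": 0.0}  # stand-in for the real metrics dict
if grad_clip > 0.0:
    log_data["train/grad_norm"] = grad_norm
print(" | ".join(f"{k}: {v:.5f}" for k, v in log_data.items()))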