mirror of https://github.com/karpathy/nanochat.git
handle case when grad_clip is 0.0, call .item() once only
commit 3c43ef370c
parent 03939756bc
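In short: when grad_clip is 0.0 the clipping branch never runs and grad_norm is never assigned, so the old unconditional print/log lines would raise a NameError. A stripped-down sketch of that failure mode (the stand-in value below is illustrative, not from the repo):

grad_clip = 0.0

if grad_clip > 0.0:
    grad_norm = 0.5  # stand-in for clip_grad_norm_'s return value; never runs here

try:
    # the old log line referenced grad_norm unconditionally
    print(f"grad_norm: {grad_norm:.5f}")
except NameError as e:
    print(e)  # -> name 'grad_norm' is not defined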
@@ -274,6 +274,7 @@ for step in range(num_iterations + 1):
     # gradient clipping (TODO possibly experiment with)
     if grad_clip > 0.0:
         grad_norm = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
+        grad_norm = grad_norm.item()
     # step the optimizers
     lrm = get_lr_multiplier(step)
     for opt in optimizers:
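For context, clip_grad_norm_ returns the total gradient norm as a 0-dim tensor, and each .item() on a CUDA tensor forces a device-to-host sync; converting once and reusing the Python float is what "call .item() once only" buys. A minimal runnable sketch, assuming a toy Linear model in place of orig_model:

import torch

model = torch.nn.Linear(8, 8)  # toy stand-in for orig_model
grad_clip = 1.0

model(torch.randn(4, 8)).sum().backward()

if grad_clip > 0.0:
    # returns the total gradient norm as a 0-dim tensor
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # one .item() here; on CUDA every call would be a separate GPU->CPU sync
    grad_norm = grad_norm.item()
    print(f"grad_norm: {grad_norm:.5f}")  # reuse the cached float everywhere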
@@ -300,19 +301,28 @@ for step in range(num_iterations + 1):
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
-    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | grad_norm: {grad_norm.item():.5f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
+    print0(
+        f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | "
+        f"loss: {debiased_smooth_loss:.6f} | "
+        + (f"grad_norm: {grad_norm:.5f} | " if grad_clip > 0.0 else "")
+        + f"lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | "
+        f"tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | "
+        f"total time: {total_training_time/60:.2f}m"
+    )
     if step % 100 == 0:
-        wandb_run.log({
+        log_data = {
             "step": step,
             "total_training_flops": flops_so_far,
             "total_training_time": total_training_time,
             "train/loss": debiased_smooth_loss,
-            "train/grad_norm": grad_norm.item(),
             "train/lrm": lrm,
             "train/dt": dt,
             "train/tok_per_sec": tok_per_sec,
             "train/mfu": mfu,
-        })
+        }
+        if grad_clip > 0.0:
+            log_data["train/grad_norm"] = grad_norm
+        wandb_run.log(log_data)
 
 # print a few more stats
 print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
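One detail worth noting in the new print0 call: adjacent f-string literals concatenate implicitly, but the conditional grad_norm segment is an expression, so it must be joined with explicit +. Because a conditional expression only evaluates its taken branch, the undefined grad_norm is never touched when grad_clip == 0.0. A self-contained sketch with hypothetical values:

grad_clip = 0.0  # clipping disabled; grad_norm is deliberately never assigned

# only the taken branch of the conditional expression is evaluated,
# so referencing grad_norm here is safe when grad_clip == 0.0
msg = (
    f"step {10:05d}/{100:05d} | "
    + (f"grad_norm: {grad_norm:.5f} | " if grad_clip > 0.0 else "")
    + f"lrm: {1.0:.2f}"
)
print(msg)  # -> step 00010/00100 | lrm: 1.00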