Fix train_loss to average all steps instead of keeping only the last

This commit is contained in:
Max Kruijs Voorberge 2026-02-08 17:55:38 +01:00
parent aeff095e97
commit 2ae28292aa

View File

@@ -485,13 +485,15 @@ while True:
# evaluate the gradient
synchronize()
t0 = time.time()
train_loss = 0.0
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach() # for logging
train_loss += loss.detach() # accumulate for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
train_loss = train_loss / grad_accum_steps # average across micro steps
# step the optimizer
lrm = get_lr_multiplier(step)
muon_momentum = get_muon_momentum(step)