This commit is contained in:
Maxkrvo 2026-02-11 09:38:29 -04:00 committed by GitHub
commit d6ee33eda5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -487,13 +487,15 @@ while True:
# evaluate the gradient
synchronize()
t0 = time.time()
+train_loss = 0.0
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
-train_loss = loss.detach() # for logging
+train_loss += loss.detach() # accumulate for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
+train_loss = train_loss / grad_accum_steps # average across micro steps
# step the optimizer
lrm = get_lr_multiplier(step)
muon_momentum = get_muon_momentum(step)