This commit is contained in:
Maxkrvo 2026-02-10 14:41:19 -05:00 committed by GitHub
commit dbd1e7e70b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -487,13 +487,15 @@ while True:
# evaluate the gradient
synchronize()
t0 = time.time()
train_loss = 0.0
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach() # for logging
train_loss += loss.detach() # accumulate for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
train_loss = train_loss / grad_accum_steps # average across micro steps
# step the optimizer
lrm = get_lr_multiplier(step)
muon_momentum = get_muon_momentum(step)