diff --git a/scripts/base_train.py b/scripts/base_train.py
index 996b2ba..dca7c76 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -487,13 +487,15 @@ while True:
     # evaluate the gradient
     synchronize()
     t0 = time.time()
+    train_loss = 0.0
     for micro_step in range(grad_accum_steps):
         with autocast_ctx:
             loss = model(x, y)
-        train_loss = loss.detach() # for logging
+        train_loss += loss.detach() # accumulate for logging
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward()
         x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
+    train_loss = train_loss / grad_accum_steps # average across micro steps
     # step the optimizer
     lrm = get_lr_multiplier(step)
     muon_momentum = get_muon_momentum(step)