Sofie Van Landeghem 2026-01-18 22:04:17 +05:00 committed by GitHub
commit d70e15083c

@@ -216,18 +216,20 @@ for step in range(num_iterations):
         break
     # evaluate the gradient
+    steps = [next(train_loader) for _ in range(grad_accum_steps)]
     num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
-    for micro_step in range(grad_accum_steps):
-        train_inputs, train_targets = next(train_loader)
-        with autocast_ctx:
-            loss = model(train_inputs, train_targets)
-        train_loss = loss.detach() # for logging
-        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
-        loss.backward() # accumulate the gradient
-        num_tokens += (train_targets >= 0).sum()
+    num_tokens += sum((targets >= 0).sum() for _, targets in steps)
     if ddp:
         dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
+    for micro_step, (train_inputs, train_targets) in enumerate(steps):
+        with autocast_ctx:
+            loss = model(train_inputs, train_targets, loss_reduction='sum')
+        loss = loss / num_tokens # normalize loss here
+        train_loss = loss.detach() # for logging
+        loss.backward() # accumulate the gradient
     # learning rate scheduler
     lrm = get_lr_multiplier(step)
     for opt in optimizers:
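
The idea behind the change: each micro-batch's loss is now summed (reduction='sum') and divided by the global count of active tokens, so every supervised token gets equal weight regardless of how tokens are distributed across gradient-accumulation micro-steps, and after the all_reduce the same global denominator is shared across DDP ranks. Below is a minimal, self-contained sketch of why that normalization makes the accumulated gradient independent of how the batch is split; it is not the repository's code, and the names (toy_model, vocab, dim) are made up for illustration, with a plain nn.Linear standing in for the language model.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, dim = 11, 8                       # hypothetical sizes, not from the repo
toy_model = torch.nn.Linear(dim, vocab)  # stands in for the language model

x = torch.randn(6, dim)                  # one "full" batch of 6 positions
y = torch.randint(0, vocab, (6,))
y[4] = -1                                # mark one position as inactive (ignored target)

def accumulated_grad(splits):
    # Sum the loss per micro-batch, divide by the global active-token count,
    # and let .backward() accumulate -- the normalization scheme of the new code.
    toy_model.zero_grad()
    num_tokens = sum((y[s] >= 0).sum() for s in splits)
    for s in splits:
        loss = F.cross_entropy(toy_model(x[s]), y[s], ignore_index=-1, reduction='sum')
        (loss / num_tokens).backward()
    return toy_model.weight.grad.clone()

g_single = accumulated_grad([slice(0, 6)])               # no accumulation
g_split  = accumulated_grad([slice(0, 2), slice(2, 6)])  # uneven micro-batches
print(torch.allclose(g_single, g_split, atol=1e-6))      # True: same gradient either way

Dividing a summed loss by a token count computed before the micro-step loop is what makes the two passes above produce identical gradients up to floating-point error; dividing each micro-batch's mean loss by grad_accum_steps, as the removed lines did, weights micro-batches equally even when they contain different numbers of active tokens.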