refactor part 2

svlandeg 2025-12-30 11:12:37 +01:00
parent 35ec43822c
commit fc565d7294

@@ -225,9 +225,6 @@ for step in range(num_iterations):
     train_loss = loss.detach() # for logging
     loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
     loss.backward() # accumulate the gradient
-    num_tokens += (train_targets >= 0).sum()
-    if ddp:
-        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
     # learning rate scheduler
     lrm = get_lr_multiplier(step)
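
For context on the surviving lines: the hunk follows the standard gradient-accumulation pattern, where the loss is pre-divided by the number of micro-batches because .backward() sums gradients across calls. A minimal sketch of that pattern, with model, get_batch, optimizer, and grad_accum_steps as assumed stand-ins for the script's real names:

def train_step(model, optimizer, get_batch, grad_accum_steps=4):
    # assumption: model(inputs, targets) returns the mean loss for one micro-batch
    optimizer.zero_grad(set_to_none=True)
    for _ in range(grad_accum_steps):
        inputs, targets = get_batch()
        loss = model(inputs, targets)
        train_loss = loss.detach()      # unscaled copy, kept for logging
        loss = loss / grad_accum_steps  # .backward() sums grads => pre-scale the loss
        loss.backward()                 # accumulate gradients into .grad
    optimizer.step()
    return train_loss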
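
The three removed lines maintained a global token count under distributed training: each rank counts its supervised positions (targets >= 0, i.e. not an ignore index), and dist.all_reduce sums the per-rank counts in place so every rank ends up with the same total. A hedged sketch of that pattern in isolation, assuming the process group is already initialized when ddp is True:

import torch
import torch.distributed as dist

def global_token_count(train_targets: torch.Tensor, ddp: bool) -> torch.Tensor:
    num_tokens = (train_targets >= 0).sum()  # local count of non-ignored targets
    if ddp:
        # in-place sum across ranks; afterwards every rank holds the global total
        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM)
    return num_tokens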
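
get_lr_multiplier itself is not shown in this hunk; schedulers of this shape usually return a scalar multiplier that is then applied to each param group's base learning rate. A hypothetical warmup-then-linear-decay version, purely illustrative of how such a multiplier is often defined (warmup_steps and num_iterations are assumptions, not values from this repo):

def get_lr_multiplier(step: int, warmup_steps: int = 100, num_iterations: int = 1000) -> float:
    if step < warmup_steps:
        return (step + 1) / warmup_steps  # linear warmup from ~0 to 1
    # linear decay from 1 down to 0 over the remaining steps
    progress = (step - warmup_steps) / max(1, num_iterations - warmup_steps)
    return max(0.0, 1.0 - progress)

A typical (assumed) application of the multiplier would be: for group in optimizer.param_groups: group["lr"] = base_lr * lrm.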