diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index bbb33e06..3bf38ae3 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -211,9 +211,15 @@ for step in range(num_iterations): break # evaluate the gradient + steps = [next(train_loader) for _ in range(grad_accum_steps)] num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen - for micro_step in range(grad_accum_steps): - train_inputs, train_targets = next(train_loader) + num_tokens += sum((targets >= 0).sum() for _, targets in steps) + + if ddp: + dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks + + for micro_step, (train_inputs, train_targets) in enumerate(steps): + batch_num_tokens = (train_targets >= 0).sum() with autocast_ctx: loss = model(train_inputs, train_targets) train_loss = loss.detach() # for logging