From 35ec43822c3d0c2998ad34b1214832655cd67cb6 Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Tue, 30 Dec 2025 11:05:48 +0100
Subject: [PATCH] refactor out steps beforehand

---
 scripts/chat_sft.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index bbb33e06..3bf38ae3 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -211,9 +211,15 @@ for step in range(num_iterations):
         break
 
     # evaluate the gradient
+    steps = [next(train_loader) for _ in range(grad_accum_steps)]
     num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
-    for micro_step in range(grad_accum_steps):
-        train_inputs, train_targets = next(train_loader)
+    num_tokens += sum((targets >= 0).sum() for _, targets in steps)
+
+    if ddp:
+        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
+
+    for micro_step, (train_inputs, train_targets) in enumerate(steps):
+        batch_num_tokens = (train_targets >= 0).sum()
         with autocast_ctx:
             loss = model(train_inputs, train_targets)
         train_loss = loss.detach() # for logging