From 35ec43822c3d0c2998ad34b1214832655cd67cb6 Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Tue, 30 Dec 2025 11:05:48 +0100
Subject: [PATCH 1/4] refactor out steps beforehand

---
 scripts/chat_sft.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index bbb33e0..3bf38ae 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -211,9 +211,15 @@ for step in range(num_iterations):
         break
 
     # evaluate the gradient
+    steps = [next(train_loader) for _ in range(grad_accum_steps)]
     num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
-    for micro_step in range(grad_accum_steps):
-        train_inputs, train_targets = next(train_loader)
+    num_tokens += sum((targets >= 0).sum() for _, targets in steps)
+
+    if ddp:
+        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
+
+    for micro_step, (train_inputs, train_targets) in enumerate(steps):
+        batch_num_tokens = (train_targets >= 0).sum()
         with autocast_ctx:
             loss = model(train_inputs, train_targets)
         train_loss = loss.detach() # for logging

From fc565d7294f804ef2b6d1947897b34659e017192 Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Tue, 30 Dec 2025 11:12:37 +0100
Subject: [PATCH 2/4] refactor part 2

---
 scripts/chat_sft.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index 3bf38ae..76ce822 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -225,9 +225,6 @@ for step in range(num_iterations):
         train_loss = loss.detach() # for logging
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward() # accumulate the gradient
-        num_tokens += (train_targets >= 0).sum()
-    if ddp:
-        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
 
     # learning rate scheduler
     lrm = get_lr_multiplier(step)

From 1ca9280328b60369e99c9a93728f2a00be25ae96 Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Tue, 30 Dec 2025 11:13:05 +0100
Subject: [PATCH 3/4] fix normalization

---
 scripts/chat_sft.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index 76ce822..65b75f2 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -221,9 +221,9 @@ for step in range(num_iterations):
     for micro_step, (train_inputs, train_targets) in enumerate(steps):
         batch_num_tokens = (train_targets >= 0).sum()
         with autocast_ctx:
-            loss = model(train_inputs, train_targets)
+            loss = model(train_inputs, train_targets, loss_reduction='sum')
+            loss = loss / num_tokens # normalize loss here
         train_loss = loss.detach() # for logging
-        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward() # accumulate the gradient
 
     # learning rate scheduler

From 32ce342c8848eeaf09d298690923a078f40ac20e Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Tue, 30 Dec 2025 11:32:47 +0100
Subject: [PATCH 4/4] remove batch_num_tokens definition which was only used
 for experiment logging

---
 scripts/chat_sft.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index 65b75f2..4afb96f 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -219,7 +219,6 @@ for step in range(num_iterations):
         dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
 
     for micro_step, (train_inputs, train_targets) in enumerate(steps):
-        batch_num_tokens = (train_targets >= 0).sum()
         with autocast_ctx:
             loss = model(train_inputs, train_targets, loss_reduction='sum')
             loss = loss / num_tokens # normalize loss here
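
Note (reviewer aside, not part of the patches): the net effect of this series is to replace the usual
mean-loss-divided-by-grad_accum_steps normalization with a sum-reduced loss divided by the total number of
supervised tokens, counted over all micro-batches up front (and over all ranks via all_reduce when DDP is
active). The sketch below illustrates that scheme in isolation; the toy linear model, random data, and
ignore_index=-1 masking are placeholders standing in for the real model, train_loader, and target masking in
scripts/chat_sft.py, and loss_reduction='sum' is assumed to correspond to a sum-reduced cross entropy as shown.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    vocab_size, seq_len, grad_accum_steps = 16, 8, 4
    model = torch.nn.Linear(vocab_size, vocab_size)  # toy stand-in for the transformer

    def fake_batch():
        # one-hot "embeddings" as inputs; targets use -1 where there is no supervision
        inputs = F.one_hot(torch.randint(0, vocab_size, (seq_len,)), vocab_size).float()
        targets = torch.randint(0, vocab_size, (seq_len,))
        targets[: seq_len // 2] = -1  # mask out the prompt region, as chat SFT does
        return inputs, targets

    # fetch all micro-batches up front so the token count is known before any backward()
    steps = [fake_batch() for _ in range(grad_accum_steps)]
    num_tokens = sum((targets >= 0).sum() for _, targets in steps)
    # under DDP, num_tokens would additionally be all_reduce'd across ranks here

    for inputs, targets in steps:
        logits = model(inputs)
        # sum the loss over supervised tokens, then divide by the global token count;
        # the accumulated gradient is then a true mean over every supervised token,
        # rather than a mean of per-micro-batch means divided by grad_accum_steps
        loss = F.cross_entropy(logits, targets, ignore_index=-1, reduction="sum")
        loss = loss / num_tokens
        loss.backward()

Materializing all grad_accum_steps micro-batches before any backward pass is what makes the global token count
available for normalization, at the cost of holding those batches in memory for the whole optimization step.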