From b48d210795493a9e6c85dca33b31223cc154da7e Mon Sep 17 00:00:00 2001
From: kibitzing
Date: Wed, 15 Oct 2025 08:56:58 +0000
Subject: [PATCH] Fix gradient accumulation for variable length sequences

---
 scripts/chat_sft.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index 8389deb..6969b13 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -34,7 +34,7 @@ from tasks.smoltalk import SmolTalk
 # SFT Hyperparameters
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
 # input model options
-source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
+source = "base" # base|mid , which checkpoint to load the model from (base model or midtrained model)
 model_tag = None # model tag to load the model from (base model or midtrained model)
 step = None # step to load the model from (base model or midtrained model)
 # compute/precision
@@ -208,18 +208,24 @@ for step in range(num_iterations):
         break
 
     # evaluate the gradient
+    total_loss_sum = torch.tensor(0.0, device=device) # sum of losses
     num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
     for micro_step in range(grad_accum_steps):
         train_inputs, train_targets = next(train_iter)
         with autocast_ctx:
-            loss = model(train_inputs, train_targets)
-        train_loss = loss.detach() # for logging
-        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
+            loss = model(train_inputs, train_targets, loss_reduction='sum')
+        total_loss_sum += loss.detach() # for logging
         loss.backward() # accumulate the gradient
         num_tokens += (train_targets >= 0).sum()
     if ddp:
+        dist.all_reduce(total_loss_sum, op=dist.ReduceOp.SUM)
         dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
 
+    # Scale gradients by total number of tokens
+    for param in model.parameters():
+        if param.grad is not None:
+            param.grad.div_(num_tokens.item())
+
     # learning rate scheduler
     lrm = get_lr_multiplier(step)
     for opt in optimizers:
@@ -232,8 +238,8 @@ for step in range(num_iterations):
     model.zero_grad(set_to_none=True)
 
     # logging
-    train_loss_item = train_loss.item()
     num_tokens_item = num_tokens.item()
+    train_loss_item = total_loss_sum.item() / num_tokens_item
     print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
     wandb_run.log({
         "step": step,
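
Note on why the normalization changes (with illustrative numbers, not
from a real run): the old code took the mean loss of each micro-batch
and divided by grad_accum_steps, which gives every micro-batch equal
weight regardless of how many active tokens it holds. Say micro-batch
A has 10 active tokens with summed loss 30.0 (mean 3.0) and micro-batch
B has 90 active tokens with summed loss 90.0 (mean 1.0):

    old: (3.0 + 1.0) / 2            = 2.00   # per-micro-batch weighting
    new: (30.0 + 90.0) / (10 + 90)  = 1.20   # per-token weighting

Gradients scale the same way, so the old scheme over-weighted tokens
that landed in short micro-batches; summing losses and dividing the
accumulated gradient by the global token count weights every supervised
token equally.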
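The pattern the patch implements, as a minimal self-contained sketch:
a toy linear model stands in for the transformer, loss_sum plays the
role of loss_reduction='sum', ignore_index=-1 mirrors the
(train_targets >= 0) mask, and the DDP all_reduce calls are elided for
a single process:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    vocab, dim = 16, 8
    model = torch.nn.Linear(dim, vocab)  # toy stand-in for the transformer

    def loss_sum(inputs, targets):
        # Summed (not averaged) cross entropy over active tokens only.
        logits = model(inputs)
        return F.cross_entropy(logits.view(-1, vocab), targets.view(-1),
                               ignore_index=-1, reduction="sum")

    # Two micro-batches of different lengths; -1 marks inactive targets.
    short_targets = torch.randint(0, vocab, (1, 6))
    short_targets[0, :2] = -1  # first two positions carry no supervision
    long_targets = torch.randint(0, vocab, (1, 50))
    micro_batches = [
        (torch.randn(1, 6, dim), short_targets),   # 4 active tokens
        (torch.randn(1, 50, dim), long_targets),   # 50 active tokens
    ]

    total_loss_sum = torch.tensor(0.0)
    num_tokens = torch.tensor(0)
    for inputs, targets in micro_batches:
        loss = loss_sum(inputs, targets)
        total_loss_sum += loss.detach()
        loss.backward()                    # grads accumulate as raw sums
        num_tokens += (targets >= 0).sum()

    # (under DDP: all_reduce total_loss_sum and num_tokens here)
    # Normalize the accumulated gradient once by the global token count.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(num_tokens.item())

    print(f"loss/token: {total_loss_sum.item() / num_tokens.item():.4f}")

Because num_tokens is all-reduced with SUM before the division, every
rank divides by the same global count, so the normalization is
independent of world size and of how tokens are split across ranks.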