Normalize the SFT loss by the number of supervised tokens in the whole step
(all micro-batches, all ranks) instead of averaging per micro-batch, and log
the resulting token-weighted mean loss of the step.

diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index 9277cf9..1272b01 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -216,18 +216,27 @@ for step in range(num_iterations):
         break
 
     # evaluate the gradient
+    # Fetch all micro-batches up front: the loss is normalized by the step's total
+    # count of supervised tokens, which must be known before the first backward pass.
+    steps = [next(train_loader) for _ in range(grad_accum_steps)]
     num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
-    for micro_step in range(grad_accum_steps):
-        train_inputs, train_targets = next(train_loader)
-        with autocast_ctx:
-            loss = model(train_inputs, train_targets)
-        train_loss = loss.detach() # for logging
-        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
-        loss.backward() # accumulate the gradient
-        num_tokens += (train_targets >= 0).sum()
+    num_tokens += sum((targets >= 0).sum() for _, targets in steps)
     if ddp:
         dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
+    # Guard the division: a step whose targets are all masked (< 0) would otherwise divide by zero.
+    loss_denom = num_tokens.clamp(min=1)
+    # NOTE(review): num_tokens is summed over ranks; if gradients are averaged (rather than summed)
+    # across ranks downstream, this under-scales them by the world size - confirm against the reduction in use.
+    train_loss = torch.zeros((), device=device) # token-weighted mean loss of the whole step, for logging
+    for micro_step, (train_inputs, train_targets) in enumerate(steps):
+        with autocast_ctx:
+            loss = model(train_inputs, train_targets, loss_reduction='sum')
+        loss = loss / loss_denom # each .backward() is a grad sum => normalize by the token count here
+        train_loss += loss.detach() # accumulate, so logging covers every micro-step rather than only the last
+        loss.backward() # accumulate the gradient
+    if ddp:
+        dist.all_reduce(train_loss, op=dist.ReduceOp.SUM) # per-rank shares add up to the global mean
 
     # learning rate scheduler
     lrm = get_lr_multiplier(step)
     for opt in optimizers: