diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index bbeb1f9..a39fa0d 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -212,18 +212,30 @@ for step in range(num_iterations):
         break
 
     # evaluate the gradient
+    total_loss_sum = torch.tensor(0.0, device=device) # sum of the micro-step losses of this step; only read for logging
     num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
     for micro_step in range(grad_accum_steps):
         train_inputs, train_targets = next(train_iter)
         with autocast_ctx:
-            loss = model(train_inputs, train_targets)
-        train_loss = loss.detach() # for logging
-        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
+            loss = model(train_inputs, train_targets, loss_reduction='sum') # presumably an un-averaged (summed) loss - confirm in the model's forward; normalization now happens on the gradients after accumulation
+        total_loss_sum += loss.detach() # for logging
         loss.backward() # accumulate the gradient
         num_tokens += (train_targets >= 0).sum()
     if ddp:
+        dist.all_reduce(total_loss_sum, op=dist.ReduceOp.SUM) # sum over ranks
         dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
 
+    # scale gradients by total number of tokens
+    num_tokens_item = num_tokens.item() # Python number; the logging code further down reuses it
+    if num_tokens_item == 0:
+        print0(f"Warning: the number of valid tokens in train targets is 0 at step {step}, skipping model update")
+        model.zero_grad(set_to_none=True)
+        continue # next step; the scheduler, optimizer and logging code below is not reached for this step
+
+    for param in model.parameters():
+        if param.grad is not None:
+            param.grad.div_(num_tokens_item) # in-place divide by the token count; NOTE(review): under ddp the count is summed over ranks - confirm no other cross-rank gradient averaging is applied on top
+
+
     # learning rate scheduler
     lrm = get_lr_multiplier(step)
     for opt in optimizers:
@@ -236,8 +248,7 @@ for step in range(num_iterations):
     model.zero_grad(set_to_none=True)
 
     # logging
-    train_loss_item = train_loss.item()
-    num_tokens_item = num_tokens.item()
+    train_loss_item = total_loss_sum.item() / num_tokens_item # accumulated loss divided by the count of targets >= 0; the zero-count case never reaches this line
     print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
     wandb_run.log({
         "step": step,