diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c0471c43..3559ada6 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -171,8 +171,8 @@ def get_lr_multiplier(it): # Go! step = 0 -for step in range(num_iterations): - last_step = step == num_iterations - 1 +for step in range(num_iterations + 1): + last_step = step == num_iterations # evaluate the validation loss if last_step or step % args.eval_every == 0: