diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 9277cf9..0b661e9 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -172,8 +172,8 @@ def get_lr_multiplier(it): # Go! step = 0 -for step in range(num_iterations): - last_step = step == num_iterations - 1 +for step in range(num_iterations + 1): + last_step = step == num_iterations # evaluate the validation loss if last_step or step % args.eval_every == 0: