diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 6c2b82f..ebb6c42 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -139,7 +139,7 @@ def mid_data_generator(split): last_step = True # toggle last_step to True, which will terminate the training loop # Stopping condition to respect num_iterations, if given it += 1 - if num_iterations > 0 and it >= num_iterations: + if num_iterations > 0 and it >= num_iterations and split == "train": last_step = True # toggle last_step to True, which will terminate the training loop # Build up inputs/targets and yield for i in range(needed_tokens):