small bugfix: make the mid_train script work even with a tiny number of iterations

Andrej 2025-12-08 18:27:32 -08:00 committed by GitHub
commit 39cccc527f

@@ -112,7 +112,7 @@ val_dataset = TaskMixture([
 # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
 # A big problem is that we don't know the final num_iterations in advance. So we create
 # these two global variables and update them from within the data generator.
-last_step = False # we will toggle this to True when we reach the end of the dataset
+last_step = False # we will toggle this to True when we reach the end of the training dataset
 approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
 def mid_data_generator(split):
     global last_step, approx_progress
@@ -139,7 +139,7 @@ def mid_data_generator(split):
             last_step = True # toggle last_step to True, which will terminate the training loop
         # Stopping condition to respect num_iterations, if given
         it += 1
-        if num_iterations > 0 and it >= num_iterations:
+        if 0 < num_iterations <= it and split == "train":
             last_step = True # toggle last_step to True, which will terminate the training loop
         # Build up inputs/targets and yield
         for i in range(needed_tokens):
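
Why the guard matters: both the train and the val loaders are instances of mid_data_generator, so each keeps its own `it` counter while sharing the global `last_step` flag. With a tiny `num_iterations`, the val generator's counter could previously trip the stopping condition during evaluation and terminate training prematurely; the added `split == "train"` check ensures only the train generator can end the run. Below is a minimal, self-contained sketch of this pattern, not the actual mid_train script: the dataset, batch contents, and sizes are hypothetical stand-ins chosen for illustration.

```python
# Sketch of the generator/global-flag stopping pattern (stand-in data).
num_iterations = 3     # tiny on purpose, to exercise the fixed condition
last_step = False      # toggled to True when training should terminate
approx_progress = 0.0  # goes from 0 to 1 over the course of the epoch

def mid_data_generator(split):
    global last_step, approx_progress
    data = list(range(100))  # stand-in for the real token stream
    it = 0                   # this generator's own iteration counter
    while True:
        it += 1
        if split == "train":
            approx_progress = min(it / len(data), 1.0)
        # The fix: only the *train* generator may end the run. Before,
        # the val generator could also trip this and stop training early.
        if 0 < num_iterations <= it and split == "train":
            last_step = True
        yield f"{split} batch {it}"  # stand-in for (inputs, targets)

train_loader = mid_data_generator("train")
val_loader = mid_data_generator("val")

step = 0
while not last_step:
    _ = next(train_loader)  # one training step
    step += 1
    # Periodic eval: advancing the val generator must NOT set last_step;
    # that is exactly what the `split == "train"` guard guarantees.
    _ = next(val_loader)
print(f"stopped after {step} steps (num_iterations={num_iterations})")
```

Running the sketch stops after exactly `num_iterations` train steps no matter how often the val generator is advanced, which is the behavior the commit restores for very small iteration counts.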