diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index ebb6c42..60c7bbc 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -112,7 +112,7 @@ val_dataset = TaskMixture([
 # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
 # A big problem is that we don't know the final num_iterations in advance. So we create
 # these two global variables and update them from within the data generator.
-last_step = False # we will toggle this to True when we reach the end of the dataset
+last_step = False # we will toggle this to True when we reach the end of the training dataset
 approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
 def mid_data_generator(split):
     global last_step, approx_progress
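
For context (not part of the patch): a minimal sketch of the pattern the comments describe, where the generator itself flips `last_step` and updates `approx_progress` as a side effect so the training loop can run open-ended without knowing `num_iterations` in advance. Apart from `last_step`, `approx_progress`, `device_batch_size`, and `max_seq_len`, every name below is an illustrative assumption, not the repo's actual implementation.

```python
import torch

last_step = False      # toggled to True when the training split is exhausted
approx_progress = 0.0  # fraction of the epoch consumed, goes 0.0 -> 1.0

def sketch_data_generator(token_stream, total_tokens, device_batch_size, max_seq_len):
    """Yield (inputs, targets) batches and update the two globals as a side effect."""
    global last_step, approx_progress
    needed = device_batch_size * max_seq_len + 1  # +1 so targets are inputs shifted by one
    buffer = []
    consumed = 0
    for token in token_stream:  # token_stream: any iterable of token ids (assumption)
        buffer.append(token)
        consumed += 1
        if len(buffer) == needed:
            batch = torch.tensor(buffer, dtype=torch.long)
            inputs = batch[:-1].view(device_batch_size, max_seq_len)
            targets = batch[1:].view(device_batch_size, max_seq_len)
            buffer = [buffer[-1]]  # carry the overlap token into the next batch
            approx_progress = consumed / total_tokens
            if consumed >= total_tokens - needed:
                last_step = True  # not enough tokens left for another full batch
            yield inputs, targets
```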