Fix distributed Parquet dataloader resume for multi-epoch training

This commit is contained in:
Andrej 2025-12-08 18:15:02 -08:00 committed by GitHub
commit 72a7cf2bc4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -29,17 +29,22 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
first_pass = True
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
while True: # iterate infinitely (multi-epoch)
pq_idx = resume_pq_idx if first_pass else 0
while pq_idx < len(parquet_paths): # iterate over all parquet files
filepath = parquet_paths[pq_idx]
pf = pq.ParquetFile(filepath)
# Start from resume point if resuming on same file, otherwise from DDP rank
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
if resume_rg_idx is not None:
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
if rg_idx >= pf.num_row_groups:
pq_idx += 1
continue
resume_rg_idx = None # set to None as we only want to do this a single time
else:
rg_idx = ddp_rank
@ -51,6 +56,7 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
rg_idx += ddp_world_size # advance to the next row group (in DDP)
pq_idx += 1 # advance to the next parquet file
first_pass = False
batches = document_batches()
# Now emit batches of tokens.