From 01ea71be39437ec076e33e80873cf4fa28f31ff2 Mon Sep 17 00:00:00 2001
From: sunyujun03
Date: Mon, 8 Dec 2025 00:10:19 -0600
Subject: [PATCH] Fix distributed Parquet dataloader resume for multi-epoch
 training

---
 nanochat/dataloader.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py
index 3271298..6be9820 100644
--- a/nanochat/dataloader.py
+++ b/nanochat/dataloader.py
@@ -29,17 +29,22 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
         parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
         resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
         resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
+        first_pass = True
         pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
         while True: # iterate infinitely (multi-epoch)
+            pq_idx = resume_pq_idx if first_pass else 0
             while pq_idx < len(parquet_paths): # iterate over all parquet files
                 filepath = parquet_paths[pq_idx]
                 pf = pq.ParquetFile(filepath)
                 # Start from resume point if resuming on same file, otherwise from DDP rank
                 # I know this state resumption is a little bit tricky and a little bit hacky... sigh.
-                if resume_rg_idx is not None:
+                if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
                     base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
                     base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
                     rg_idx = base_idx * ddp_world_size + ddp_rank
+                    if rg_idx >= pf.num_row_groups:
+                        pq_idx += 1
+                        continue
                     resume_rg_idx = None # set to None as we only want to do this a single time
                 else:
                     rg_idx = ddp_rank
@@ -51,6 +56,7 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
                         yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
                     rg_idx += ddp_world_size # advance to the next row group (in DDP)
                 pq_idx += 1 # advance to the next parquet file
+            first_pass = False
     batches = document_batches()
 
     # Now emit batches of tokens.