diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py
index 4cb2279..7f3c3d1 100644
--- a/nanochat/dataloader.py
+++ b/nanochat/dataloader.py
@@ -35,6 +35,18 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
     warn_on_legacy = ddp_rank == 0 and split == "train" # rank 0 on train split will warn on legacy
     parquet_paths = list_parquet_files(warn_on_legacy=warn_on_legacy)
     assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
+
+    # Split parquet files: last file for validation, rest for training
+    # Handle edge case: single file scenario. Without the duplication below,
+    # the train split ([:-1] of a 1-element list) would be EMPTY, contradicting
+    # the warning and leaving training with no data.
+    if len(parquet_paths) == 1:
+        import warnings
+        warnings.warn(
+            "Only 1 parquet file found. "
+            "Both train and val will use the same data, which may cause overfitting. "
+            "Consider splitting your dataset into multiple parquet files.",
+            UserWarning
+        )
+        parquet_paths = parquet_paths * 2  # duplicate so train ([:-1]) and val ([-1:]) each get the file
     parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
     resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0