diff --git a/dev/repackage_data_reference.py b/dev/repackage_data_reference.py index 32980a8..d0baeec 100644 --- a/dev/repackage_data_reference.py +++ b/dev/repackage_data_reference.py @@ -46,13 +46,14 @@ shard_characters = 0 total_docs_processed = 0 total_time_spent = 0 t0 = time.time() -for doc in ds: +for doc_idx, doc in enumerate(ds): text = doc['text'] shard_docs.append(text) shard_characters += len(text) collected_enough_chars = shard_characters >= chars_per_shard docs_multiple_of_row_group_size = len(shard_docs) % row_group_size == 0 - if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed) + last_doc = doc_idx >= ndocs - 1 + if last_doc or (collected_enough_chars and docs_multiple_of_row_group_size): # leads to ~100MB of text (compressed) shard_path = os.path.join(output_dir, f"shard_{shard_index:05d}.parquet") shard_table = pa.Table.from_pydict({"text": shard_docs}) pq.write_table(