fix: typo

This commit is contained in:
Kartik Vashishta 2026-02-20 02:29:19 +11:00
parent 2dffdc8cf6
commit 0b8fde9f4d
2 changed files with 46 additions and 13 deletions

View File

@ -20,7 +20,7 @@ import torch
import pyarrow.parquet as pq
from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files
from nanochat.dataset import get_parquet_paths
def _document_batches(split, resume_state_dict, tokenizer_batch_size):
"""
@ -32,10 +32,9 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
"""
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
parquet_paths = list_parquet_files()
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
parquet_paths = get_parquet_paths(split)
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1

View File

@ -23,6 +23,13 @@ from nanochat.common import get_base_dir
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
# Always use a fixed shard for val so that metrics don't depend on how many shards are downloaded
# Keep it pinned to shard_01822 so the val split is reproducible.
VAL_SHARD_INDEX = 1822
assert 0 <= VAL_SHARD_INDEX <= MAX_SHARD, "VAL_SHARD_INDEX must be within [0, MAX_SHARD]"
VAL_SHARD_FILENAME = index_to_filename(VAL_SHARD_INDEX)
base_dir = get_base_dir()
DATA_DIR = os.path.join(base_dir, "base_data")
os.makedirs(DATA_DIR, exist_ok=True)
@ -30,25 +37,44 @@ os.makedirs(DATA_DIR, exist_ok=True)
# -----------------------------------------------------------------------------
# These functions are useful utilities to other modules, can/should be imported
def list_parquet_files(data_dir=None, exclude_filenames=()):
    """
    Look into a data dir and return full paths to all parquet files.

    Args:
        data_dir: directory to scan; defaults to the module-level DATA_DIR.
        exclude_filenames: iterable of bare filenames to skip (e.g. the
            pinned validation shard, so train listings never include it).

    Returns:
        List of full paths, sorted by filename for a deterministic order.
    """
    data_dir = DATA_DIR if data_dir is None else data_dir
    exclude = set(exclude_filenames)  # O(1) membership test per directory entry
    # NOTE(review): the ".tmp" clause is redundant given the ".parquet" check
    # (a name cannot end with both); kept to preserve the original filter.
    parquet_files = sorted(
        f for f in os.listdir(data_dir)
        if f.endswith(".parquet") and not f.endswith(".tmp") and f not in exclude
    )
    return [os.path.join(data_dir, f) for f in parquet_files]
def get_parquet_paths(split, data_dir=None):
    """
    Return the parquet file paths for a dataset split.

    The "val" split is always the single pinned validation shard, so that
    evaluation metrics stay stable across partial downloads; "train" is every
    other parquet file in the data dir.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    data_dir = DATA_DIR if data_dir is None else data_dir
    if split == "train":
        # Train split: list files while excluding the val shard up front
        # (rather than listing everything and removing it afterwards).
        return list_parquet_files(data_dir, exclude_filenames=(VAL_SHARD_FILENAME,))
    # Val split: exactly one pinned shard, which must already be on disk.
    val_path = os.path.join(data_dir, VAL_SHARD_FILENAME)
    if not os.path.exists(val_path):
        raise FileNotFoundError(
            f"Validation shard {VAL_SHARD_FILENAME} not found in {data_dir}. "
            f"Run: python -m nanochat.dataset -n <N> (downloads the val shard too)."
        )
    return [val_path]
def parquets_iter_batched(split, start=0, step=1):
"""
Iterate through the dataset, in batches of underlying row_groups for efficiency.
- split can be "train" or "val". the last parquet file will be val.
- split can be "train" or "val". validation is always a fixed shard.
- start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
"""
assert split in ["train", "val"], "split must be 'train' or 'val'"
parquet_paths = list_parquet_files()
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
parquet_paths = get_parquet_paths(split)
for filepath in parquet_paths:
pf = pq.ParquetFile(filepath)
for rg_idx in range(start, pf.num_row_groups, step):
@ -111,13 +137,21 @@ def download_single_file(index):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of training shards to download (default: -1 = all). Validation shard is always included")
parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
args = parser.parse_args()
num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
ids_to_download = list(range(num))
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
if VAL_SHARD_INDEX not in ids_to_download:
ids_to_download.append(VAL_SHARD_INDEX)
ids_to_download = sorted(set(ids_to_download))
if args.num_files != -1 and args.num_files <= MAX_SHARD:
print(f"Downloading {len(ids_to_download)} shards ({num} train + 1 val) using {args.num_workers} workers...")
else:
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
print(f"Target directory: {DATA_DIR}")
print()
with Pool(processes=args.num_workers) as pool: