mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-23 05:13:22 +00:00
fix: typo
This commit is contained in:
parent
2dffdc8cf6
commit
0b8fde9f4d
|
|
@ -20,7 +20,7 @@ import torch
|
|||
import pyarrow.parquet as pq
|
||||
|
||||
from nanochat.common import get_dist_info
|
||||
from nanochat.dataset import list_parquet_files
|
||||
from nanochat.dataset import get_parquet_paths
|
||||
|
||||
def _document_batches(split, resume_state_dict, tokenizer_batch_size):
|
||||
"""
|
||||
|
|
@ -32,10 +32,9 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
|
|||
"""
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
|
||||
parquet_paths = list_parquet_files()
|
||||
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
|
||||
parquet_paths = get_parquet_paths(split)
|
||||
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
|
||||
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
|
||||
|
|
|
|||
|
|
@ -23,6 +23,13 @@ from nanochat.common import get_base_dir
|
|||
# Remote location and naming scheme of the FineWeb-Edu 100BT shards.
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames

# Always use a fixed shard for val so that metrics don't depend on how many shards are downloaded
# Keeping pinned to shard_01822.
VAL_SHARD_INDEX = 1822
assert 0 <= VAL_SHARD_INDEX <= MAX_SHARD, "VAL_SHARD_INDEX must be within [0, MAX_SHARD]"
VAL_SHARD_FILENAME = index_to_filename(VAL_SHARD_INDEX)

# Local directory where shards are downloaded; created eagerly at import time.
base_dir = get_base_dir()
DATA_DIR = os.path.join(base_dir, "base_data")
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
|
|
@ -30,25 +37,44 @@ os.makedirs(DATA_DIR, exist_ok=True)
|
|||
# -----------------------------------------------------------------------------
|
||||
# These functions are useful utilities to other modules, can/should be imported
|
||||
|
||||
def list_parquet_files(data_dir=None, exclude_filenames=()):
    """
    Look into a data dir and return sorted full paths to all parquet files.

    Args:
        data_dir: directory to scan; defaults to the module-level DATA_DIR.
        exclude_filenames: iterable of bare filenames to skip (e.g. the pinned
            validation shard), so callers don't have to add-then-remove paths.

    Returns:
        list of joined paths, sorted by filename (shard order).
    """
    data_dir = DATA_DIR if data_dir is None else data_dir
    exclude = set(exclude_filenames)
    parquet_files = sorted([
        f for f in os.listdir(data_dir)
        # the .tmp guard skips in-progress downloads; exclude drops e.g. the val shard
        if f.endswith(".parquet") and not f.endswith(".tmp") and f not in exclude
    ])
    parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
    return parquet_paths
|
||||
|
||||
def get_parquet_paths(split, data_dir=None):
    """
    Return the parquet file paths for a dataset split.

    Validation is always the single pinned shard (VAL_SHARD_FILENAME) so that
    val metrics are stable regardless of how many shards have been downloaded.

    Args:
        split: "train" or "val".
        data_dir: directory to scan; defaults to the module-level DATA_DIR.

    Returns:
        list of paths for the split (exactly one path for "val").

    Raises:
        FileNotFoundError: if split == "val" and the pinned shard is missing.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    data_dir = DATA_DIR if data_dir is None else data_dir
    val_path = os.path.join(data_dir, VAL_SHARD_FILENAME)
    if split == "val":
        if not os.path.exists(val_path):
            raise FileNotFoundError(
                f"Validation shard {VAL_SHARD_FILENAME} not found in {data_dir}. "
                f"Run: python -m nanochat.dataset -n <N> (downloads the val shard too)."
            )
        return [val_path]
    else:
        # train split: list files while excluding val (don't add then remove).
        return list_parquet_files(data_dir, exclude_filenames=(VAL_SHARD_FILENAME,))
|
||||
|
||||
def parquets_iter_batched(split, start=0, step=1):
|
||||
"""
|
||||
Iterate through the dataset, in batches of underlying row_groups for efficiency.
|
||||
- split can be "train" or "val". the last parquet file will be val.
|
||||
- split can be "train" or "val". validation is always a fixed shard.
|
||||
- start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
parquet_paths = list_parquet_files()
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
parquet_paths = get_parquet_paths(split)
|
||||
for filepath in parquet_paths:
|
||||
pf = pq.ParquetFile(filepath)
|
||||
for rg_idx in range(start, pf.num_row_groups, step):
|
||||
|
|
@ -111,13 +137,21 @@ def download_single_file(index):
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
|
||||
parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
|
||||
parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of training shards to download (default: -1 = all). Validation shard is always included")
|
||||
parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
|
||||
args = parser.parse_args()
|
||||
|
||||
num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
|
||||
ids_to_download = list(range(num))
|
||||
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
|
||||
if VAL_SHARD_INDEX not in ids_to_download:
|
||||
ids_to_download.append(VAL_SHARD_INDEX)
|
||||
ids_to_download = sorted(set(ids_to_download))
|
||||
|
||||
if args.num_files != -1 and args.num_files <= MAX_SHARD:
|
||||
print(f"Downloading {len(ids_to_download)} shards ({num} train + 1 val) using {args.num_workers} workers...")
|
||||
else:
|
||||
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
|
||||
|
||||
print(f"Target directory: {DATA_DIR}")
|
||||
print()
|
||||
with Pool(processes=args.num_workers) as pool:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user