diff --git a/nanochat/dataset.py b/nanochat/dataset.py
index 602daed..1282cd0 100644
--- a/nanochat/dataset.py
+++ b/nanochat/dataset.py
@@ -11,6 +11,7 @@ import os
 import argparse
 import time
 import requests
+import pyarrow.fs as fs
 import pyarrow.parquet as pq
 from multiprocessing import Pool
 
@@ -20,7 +21,7 @@ from nanochat.common import get_base_dir
 
 # The specifics of the current pretraining dataset
 # The URL on the internet where the data is hosted and downloaded from on demand
-BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
+BASE_URI = "hf://datasets/karpathy/fineweb-edu-100b-shuffle"
 MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
 index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
 base_dir = get_base_dir()
@@ -68,45 +69,17 @@ def download_single_file(index):
         return True
 
     # Construct the remote URL for this file
-    url = f"{BASE_URL}/{filename}"
+    uri = f"{BASE_URI}/{filename}"
     print(f"Downloading {filename}...")
-
-    # Download with retries
-    max_attempts = 5
-    for attempt in range(1, max_attempts + 1):
-        try:
-            response = requests.get(url, stream=True, timeout=30)
-            response.raise_for_status()
-            # Write to temporary file first
-            temp_path = filepath + f".tmp"
-            with open(temp_path, 'wb') as f:
-                for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
-                    if chunk:
-                        f.write(chunk)
-            # Move temp file to final location
-            os.rename(temp_path, filepath)
-            print(f"Successfully downloaded {filename}")
-            return True
-
-        except (requests.RequestException, IOError) as e:
-            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
-            # Clean up any partial files
-            for path in [filepath + f".tmp", filepath]:
-                if os.path.exists(path):
-                    try:
-                        os.remove(path)
-                    except:
-                        pass
-            # Try a few times with exponential backoff: 2^attempt seconds
-            if attempt < max_attempts:
-                wait_time = 2 ** attempt
-                print(f"Waiting {wait_time} seconds before retry...")
-                time.sleep(wait_time)
-            else:
-                print(f"Failed to download {filename} after {max_attempts} attempts")
-                return False
-
-    return False
+    try:
+        # pyarrow.fs uses huggingface_hub with builtin exponential backoff
+        fs.copy_files(uri, filepath)
+    except (requests.RequestException, IOError) as e:
+        print(f"Failed to download {filename}: {e}")
+        return False
+    else:
+        print(f"Successfully downloaded {filename}")
+        return True
 
 
 if __name__ == "__main__":
diff --git a/pyproject.toml b/pyproject.toml
index 3d03c4b..862d949 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ dependencies = [
     "fastapi>=0.117.1",
     "files-to-prompt>=0.6",
     "psutil>=7.1.0",
+    "pyarrow>=21.0.0",
    "regex>=2025.9.1",
     "setuptools>=80.9.0",
     "tiktoken>=0.11.0",
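
For reference, a minimal standalone sketch of the new download path, assuming pyarrow>=21.0.0 resolves hf:// URIs (via huggingface_hub) the way this change relies on; the shard name and target directory below are illustrative, not taken from the patch:

    import os
    import pyarrow.fs as fs

    # Illustrative values; dataset.py derives these from BASE_URI,
    # index_to_filename and get_base_dir()
    uri = "hf://datasets/karpathy/fineweb-edu-100b-shuffle/shard_00000.parquet"
    filepath = os.path.join("/tmp/fineweb-edu", "shard_00000.parquet")
    os.makedirs(os.path.dirname(filepath), exist_ok=True)

    # copy_files infers the source filesystem from the URI and the destination
    # filesystem from the local path, then streams the file over in chunks
    fs.copy_files(uri, filepath)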