Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)
Compare commits: 77658f75f2...35049a63b6 (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 35049a63b6 | |
| | bc1fca39f3 | |
| | f5d35391db | |
nanochat/dataset.py:

```diff
@@ -11,6 +11,7 @@ import os
 import argparse
 import time
 import requests
+import pyarrow.fs as fs
 import pyarrow.parquet as pq
 from multiprocessing import Pool
 
```
```diff
@@ -20,7 +21,7 @@ from nanochat.common import get_base_dir
 # The specifics of the current pretraining dataset
 
 # The URL on the internet where the data is hosted and downloaded from on demand
-BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
+BASE_URI = "hf://datasets/karpathy/fineweb-edu-100b-shuffle"
 MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
 index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
 base_dir = get_base_dir()
```
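For orientation: `hf://datasets/...` is the Hugging Face fsspec protocol implemented by `huggingface_hub.HfFileSystem`, and shard names follow `index_to_filename`. A minimal sketch of both, assuming `huggingface_hub` is installed and network access is available:

```python
# Sketch: the shard naming scheme from the diff, plus listing the dataset
# repo through the fsspec filesystem that backs hf:// URIs.
from huggingface_hub import HfFileSystem

index_to_filename = lambda index: f"shard_{index:05d}.parquet"
assert index_to_filename(1822) == "shard_01822.parquet"  # MAX_SHARD

hf = HfFileSystem()
# Paths under HfFileSystem omit the hf:// prefix
files = hf.ls("datasets/karpathy/fineweb-edu-100b-shuffle", detail=False)
print(files[:3])  # first few entries in the dataset repo
```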
```diff
@@ -68,45 +69,17 @@ def download_single_file(index)
         return True
 
-    # Construct the remote URL for this file
-    url = f"{BASE_URL}/{filename}"
+    uri = f"{BASE_URI}/{filename}"
     print(f"Downloading {filename}...")
 
-    # Download with retries
-    max_attempts = 5
-    for attempt in range(1, max_attempts + 1):
-        try:
-            response = requests.get(url, stream=True, timeout=30)
-            response.raise_for_status()
-            # Write to temporary file first
-            temp_path = filepath + f".tmp"
-            with open(temp_path, 'wb') as f:
-                for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
-                    if chunk:
-                        f.write(chunk)
-            # Move temp file to final location
-            os.rename(temp_path, filepath)
-            print(f"Successfully downloaded {filename}")
-            return True
-
-        except (requests.RequestException, IOError) as e:
-            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
-            # Clean up any partial files
-            for path in [filepath + f".tmp", filepath]:
-                if os.path.exists(path):
-                    try:
-                        os.remove(path)
-                    except:
-                        pass
-            # Try a few times with exponential backoff: 2^attempt seconds
-            if attempt < max_attempts:
-                wait_time = 2 ** attempt
-                print(f"Waiting {wait_time} seconds before retry...")
-                time.sleep(wait_time)
-            else:
-                print(f"Failed to download {filename} after {max_attempts} attempts")
-                return False
-
-    return False
+    try:
+        # pyarrow.fs uses huggingface_hub with builtin exponential backoff
+        fs.copy_files(uri, filepath)
+    except (requests.RequestException, IOError) as e:
+        print(f"Failed to download {filename}: {e}")
+        return False
+    else:
+        print(f"Successfully downloaded {filename}")
+        return True
 
 
 if __name__ == "__main__":
```
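A hedged sketch of the mechanism the new code relies on: `pyarrow.fs.copy_files` accepts a source and a local destination, with the `hf://` scheme resolving to the Hugging Face filesystem from `huggingface_hub`, whose HTTP layer already retries with exponential backoff; that is why the hand-rolled retry loop could be dropped. The explicit-filesystem spelling below is an illustrative equivalent, not necessarily the exact code path nanochat exercises, and the local directory is a hypothetical placeholder:

```python
# Sketch: download one shard with pyarrow.fs.copy_files, with the source
# and destination filesystems spelled out explicitly.
import os
import pyarrow.fs as fs
from huggingface_hub import HfFileSystem
from pyarrow.fs import FSSpecHandler, PyFileSystem

shard = "shard_00000.parquet"
src = f"datasets/karpathy/fineweb-edu-100b-shuffle/{shard}"
dst = os.path.join("/tmp/fineweb-edu", shard)  # hypothetical local path
os.makedirs(os.path.dirname(dst), exist_ok=True)

fs.copy_files(
    src,
    dst,
    source_filesystem=PyFileSystem(FSSpecHandler(HfFileSystem())),
    destination_filesystem=fs.LocalFileSystem(),
)
```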
nanochat/gpt.py:

```diff
@@ -8,7 +8,7 @@ Notable features:
 - norm after token embedding
 - no learnable params in rmsnorm
 - no bias in linear layers
-- Multi-Query Attention (MQA) support for more efficient inference
+- Group-Query Attention (GQA) support for more efficient inference
 """
 
 import math
```
```diff
@@ -29,7 +29,7 @@ class GPTConfig:
     vocab_size: int = 50304
     n_layer: int = 12
     n_head: int = 6 # number of query heads
-    n_kv_head: int = 6 # number of key/value heads (MQA)
+    n_kv_head: int = 6 # number of key/value heads (GQA)
     n_embd: int = 768
```
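The rename is accurate: with `n_head` query heads and `n_kv_head` key/value heads, each K/V head is shared by a group of `n_head // n_kv_head` query heads; MQA is the special case `n_kv_head == 1`, and `n_kv_head == n_head` recovers standard multi-head attention. A minimal PyTorch sketch of the grouping (illustrative shapes, not nanochat's actual attention code):

```python
# Sketch: Group-Query Attention by expanding K/V heads to match the
# number of query heads. Head counts here are illustrative.
import torch
import torch.nn.functional as F

B, T, n_head, n_kv_head, head_dim = 2, 8, 6, 2, 16
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)
v = torch.randn(B, n_kv_head, T, head_dim)

# Each of the 2 K/V heads serves 6 // 2 = 3 query heads
group = n_head // n_kv_head
k = k.repeat_interleave(group, dim=1)  # (B, n_head, T, head_dim)
v = v.repeat_interleave(group, dim=1)

y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(y.shape)  # torch.Size([2, 6, 8, 16])
```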
pyproject.toml:

```diff
@@ -9,6 +9,7 @@ dependencies = [
     "fastapi>=0.117.1",
     "files-to-prompt>=0.6",
     "psutil>=7.1.0",
+    "pyarrow>=21.0.0",
     "regex>=2025.9.1",
     "setuptools>=80.9.0",
     "tiktoken>=0.11.0",
```