Compare commits

10 Commits

Author SHA1 Message Date
Krisztián Szűcs        77658f75f2  Merge f5d35391db into f66a780f68                                    2025-11-14 18:22:12 +01:00
Andrej                 f66a780f68  Fix torch.dtype mismatching when running engine inline test.       2025-11-14 07:28:29 -08:00
Andrej                 4763ce612a  Small fixes to typos                                                2025-11-14 07:25:59 -08:00
Sofie Van Landeghem    c6f5bd67db  revert change of base to sft for quick inline test                  2025-11-14 12:20:03 +01:00
svlandeg               a2fb3c83a6  fix typos                                                           2025-11-14 11:20:25 +01:00
svlandeg               e5efb4b471  add test_engine.py to file structure                                2025-11-14 11:13:42 +01:00
howardgao@outlook.com  b399e43168  fix engine test bug                                                 2025-11-06 08:56:45 +08:00
svlandeg               52e85aaf80  Merge branch 'master' into fix/typo                                 2025-11-02 13:41:13 +01:00
svlandeg               70319851fc  fix typo                                                            2025-10-29 19:48:34 +01:00
Krisztian Szucs        f5d35391db  use pyarrow.fs to download parquet files from the huggingface hub  2025-10-14 13:28:36 +02:00
7 changed files with 34 additions and 53 deletions

View File

@@ -184,6 +184,7 @@ python -m pytest tests/test_rustbpe.py -v -s
 │   ├── smoltalk.py             # Conglomerate dataset of SmolTalk from HF
 │   └── spellingbee.py          # Task teaching model to spell/count letters
 ├── tests
+│   ├── test_engine.py
 │   └── test_rustbpe.py
 └── uv.lock
 ```

View File

@@ -11,6 +11,7 @@ import os
 import argparse
 import time
 import requests
+import pyarrow.fs as fs
 import pyarrow.parquet as pq
 from multiprocessing import Pool
@@ -20,7 +21,7 @@ from nanochat.common import get_base_dir
 # The specifics of the current pretraining dataset
 # The URL on the internet where the data is hosted and downloaded from on demand
-BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
+BASE_URI = "hf://datasets/karpathy/fineweb-edu-100b-shuffle"
 MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
 index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
 base_dir = get_base_dir()
@@ -68,45 +69,17 @@ def download_single_file(index):
         return True
     # Construct the remote URL for this file
-    url = f"{BASE_URL}/{filename}"
+    uri = f"{BASE_URI}/{filename}"
     print(f"Downloading {filename}...")
-    # Download with retries
-    max_attempts = 5
-    for attempt in range(1, max_attempts + 1):
-        try:
-            response = requests.get(url, stream=True, timeout=30)
-            response.raise_for_status()
-            # Write to temporary file first
-            temp_path = filepath + f".tmp"
-            with open(temp_path, 'wb') as f:
-                for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
-                    if chunk:
-                        f.write(chunk)
-            # Move temp file to final location
-            os.rename(temp_path, filepath)
-            print(f"Successfully downloaded {filename}")
-            return True
-        except (requests.RequestException, IOError) as e:
-            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
-            # Clean up any partial files
-            for path in [filepath + f".tmp", filepath]:
-                if os.path.exists(path):
-                    try:
-                        os.remove(path)
-                    except:
-                        pass
-            # Try a few times with exponential backoff: 2^attempt seconds
-            if attempt < max_attempts:
-                wait_time = 2 ** attempt
-                print(f"Waiting {wait_time} seconds before retry...")
-                time.sleep(wait_time)
-            else:
-                print(f"Failed to download {filename} after {max_attempts} attempts")
-                return False
-    return False
+    try:
+        # pyarrow.fs uses huggingface_hub with builtin exponential backoff
+        fs.copy_files(uri, filepath)
+    except (requests.RequestException, IOError) as e:
+        print(f"Failed to download {filename}: {e}")
+        return False
+    else:
+        print(f"Successfully downloaded {filename}")
+        return True
 
 if __name__ == "__main__":
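
For orientation, the replacement relies on pyarrow's filesystem layer resolving the hf:// scheme through huggingface_hub, which is where the retry/backoff now lives. A minimal standalone sketch of the new download path, assuming pyarrow>=21.0.0 and huggingface_hub are installed; the /tmp destination is hypothetical:

import pyarrow.fs as fs

BASE_URI = "hf://datasets/karpathy/fineweb-edu-100b-shuffle"
filename = "shard_00000.parquet"  # first shard, for illustration
# copy_files infers the source filesystem from the hf:// URI and the
# destination filesystem from the local path
fs.copy_files(f"{BASE_URI}/{filename}", f"/tmp/{filename}")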

View File

@@ -17,8 +17,9 @@ import signal
 import warnings
 from contextlib import contextmanager
 from collections import deque
-from nanochat.common import compute_init
+from nanochat.common import compute_init, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
+from contextlib import nullcontext
 
 # -----------------------------------------------------------------------------
 # Calculator tool helpers
@@ -328,6 +329,9 @@ if __name__ == "__main__":
     import time
     # init compute
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+    device_type = autodetect_device_type()
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+
     # load the model and tokenizer
     model, tokenizer, meta = load_model("base", device, phase="eval")
     bos_token_id = tokenizer.get_bos_token_id()
@@ -340,10 +344,11 @@ if __name__ == "__main__":
     torch.cuda.synchronize()
     t0 = time.time()
     stream = model.generate(prompt_tokens, **kwargs)
-    for token in stream:
-        generated_tokens.append(token)
-        chunk = tokenizer.decode([token])
-        print(chunk, end="", flush=True)
+    with autocast_ctx:
+        for token in stream:
+            generated_tokens.append(token)
+            chunk = tokenizer.decode([token])
+            print(chunk, end="", flush=True)
     print()
     torch.cuda.synchronize()
     t1 = time.time()
@@ -355,11 +360,12 @@ if __name__ == "__main__":
     stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
     torch.cuda.synchronize()
     t0 = time.time()
-    for token_column, token_masks in stream:
-        token = token_column[0] # only print out the first row
-        generated_tokens.append(token)
-        chunk = tokenizer.decode([token])
-        print(chunk, end="", flush=True)
+    with autocast_ctx:
+        for token_column, token_masks in stream:
+            token = token_column[0] # only print out the first row
+            generated_tokens.append(token)
+            chunk = tokenizer.decode([token])
+            print(chunk, end="", flush=True)
     print()
     torch.cuda.synchronize()
     t1 = time.time()
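
Since the diff only shows fragments, here is a self-contained sketch of the conditional-autocast pattern these hunks introduce; the CUDA availability check is a stand-in for nanochat's autodetect_device_type(), whose exact logic is not shown in this diff:

from contextlib import nullcontext
import torch

# stand-in for nanochat.common.autodetect_device_type()
device_type = "cuda" if torch.cuda.is_available() else "cpu"
# bf16 autocast on CUDA, a no-op context everywhere else
autocast_ctx = (
    torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
    if device_type == "cuda"
    else nullcontext()
)
with autocast_ctx:
    x = torch.ones(4, 4, device=device_type)
    y = x @ x  # matmul runs in bfloat16 when the CUDA context is active

Wrapping only the generation loops keeps their matmuls in bfloat16 on CUDA while leaving CPU runs untouched, which is what commit f66a780f68 ("Fix torch.dtype mismatching when running engine inline test.") targets.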

View File

@@ -9,9 +9,9 @@ import torch.distributed as dist
 def evaluate_bpb(model, batches, steps, token_bytes):
     """
     Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
-    which is a tokenization vocab size-indepedent metric, meaning you are still comparing
+    which is a tokenization vocab size-independent metric, meaning you are still comparing
     apples:apples if you change the vocab size. The way this works is that instead of just
-    calculating the average loss as usual, you calculate the sum loss, and indepependently
+    calculating the average loss as usual, you calculate the sum loss, and independently
     also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
     the number of bytes that the target tokens represent.
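
To make the docstring's recipe concrete, a hedged sketch of the computation it describes; the function name is illustrative, and it assumes the summed cross-entropy loss is in nats, so it is converted to bits before normalizing:

import math

def bits_per_byte(sum_loss_nats: float, sum_target_bytes: int) -> float:
    # nats -> bits via log(2), then normalize by the bytes of the target tokens
    return sum_loss_nats / (math.log(2) * sum_target_bytes)

# e.g. 1000 tokens at a mean loss of 2.0 nats, averaging 4 bytes per token:
print(bits_per_byte(2000.0, 4000))  # ~0.72 bpb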

View File

@@ -9,6 +9,7 @@ dependencies = [
     "fastapi>=0.117.1",
     "files-to-prompt>=0.6",
     "psutil>=7.1.0",
+    "pyarrow>=21.0.0",
     "regex>=2025.9.1",
     "setuptools>=80.9.0",
     "tiktoken>=0.11.0",

View File

@@ -1,6 +1,6 @@
 """
 Evaluate the Chat model.
-All the generic code lives here, and all the evlauation-specific
+All the generic code lives here, and all the evaluation-specific
 code lives in nanochat directory and is imported from here.
 
 Example runs:

View File

@@ -192,7 +192,7 @@ for step in range(num_iterations):
         })
         model.train()
 
-    # evlauate accuracy of the multiple choice tasks (which are quick to run)
+    # evaluate accuracy of the multiple choice tasks (which are quick to run)
    if last_step or (step > 0 and step % eval_metrics_every == 0):
         model.eval()
         metrics = {}
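
For readers skimming this hunk: the condition runs the multiple-choice evaluation on the final step and on every eval_metrics_every-th step after the first. A tiny standalone sketch with hypothetical constants:

num_iterations, eval_metrics_every = 700, 150  # hypothetical values

for step in range(num_iterations):
    last_step = step == num_iterations - 1
    if last_step or (step > 0 and step % eval_metrics_every == 0):
        print(f"evaluating at step {step}")  # fires at 150, 300, 450, 600, 699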