typing: add Path type hints to function signatures and returns

Tsvika Shapira 2025-12-26 12:36:35 +02:00
parent 6d6651e2df
commit 2dd3380dbc
7 changed files with 19 additions and 18 deletions
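
The change is mechanical: every argument or return value that the code already treats as a filesystem path (the / operator, .mkdir(), .glob(), .iterdir()) gets a pathlib.Path annotation; runtime behavior is untouched. A minimal sketch of the pattern, using a hypothetical helper rather than code from this repo:

from pathlib import Path

def latest_file(log_dir: Path, pattern: str = "*.log") -> Path:
    # return the most recently modified file matching the pattern
    candidates = sorted(log_dir.glob(pattern), key=lambda p: p.stat().st_mtime)
    if not candidates:
        raise FileNotFoundError(f"no files matching {pattern!r} in {log_dir}")
    return candidates[-1]

With the annotations in place, a checker such as mypy or pyright can flag callers that pass a plain str, which would otherwise fail only at runtime on the first Path-only method call.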

View File

@@ -20,7 +20,7 @@ def log0(message):
if int(os.environ.get('RANK', 0)) == 0:
logger.info(message)
-def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
+def save_checkpoint(checkpoint_dir: Path, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
checkpoint_dir.mkdir(parents=True, exist_ok=True)
# Save the model state parameters
@@ -39,7 +39,7 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
torch.save(optimizer_data, optimizer_path)
logger.info(f"Saved optimizer state to: {optimizer_path}")
-def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
+def load_checkpoint(checkpoint_dir: Path, step, device, load_optimizer=False, rank=0):
# Load the model state
model_path = checkpoint_dir / f"model_{step:06d}.pt"
model_data = torch.load(model_path, map_location=device)
@@ -55,7 +55,7 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
return model_data, optimizer_data, meta_data
-def build_model(checkpoint_dir, step, device, phase):
+def build_model(checkpoint_dir: Path, step, device, phase):
"""
A bunch of repetitive code to build a model from a given checkpoint.
Returns:
@@ -94,7 +94,7 @@ def build_model(checkpoint_dir, step, device, phase):
return model, tokenizer, meta_data
-def find_largest_model(checkpoints_dir):
+def find_largest_model(checkpoints_dir: Path):
# attempt to guess the model tag: take the biggest model available
model_tags = [f.name for f in checkpoints_dir.iterdir() if f.is_dir()]
if not model_tags:
@@ -114,7 +114,7 @@ def find_largest_model(checkpoints_dir):
return model_tags[0]
-def find_last_step(checkpoint_dir):
+def find_last_step(checkpoint_dir: Path):
# Look into checkpoint_dir and find model_<step>.pt with the highest step
checkpoint_files = list(checkpoint_dir.glob("model_*.pt"))
if not checkpoint_files:
@@ -125,7 +125,7 @@ def find_last_step(checkpoint_dir):
# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure
-def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
+def load_model_from_dir(checkpoints_dir: Path, device, phase, model_tag=None, step=None):
if model_tag is None:
# guess the model tag by defaulting to the largest model
model_tag = find_largest_model(checkpoints_dir)
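
Callers are now expected to hand these helpers a Path rather than a string. A hedged usage sketch, assuming the module is importable as nanochat.checkpoint_manager (the file name is not shown in this view) and using a toy model in place of the real one:

from pathlib import Path
import torch
import torch.nn as nn
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint  # module path assumed

model = nn.Linear(4, 4)                                    # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
checkpoint_dir = Path("out") / "checkpoints" / "demo"      # hypothetical location

save_checkpoint(checkpoint_dir, 0, model.state_dict(), optimizer.state_dict(), {"step": 0})
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, 0, device="cpu", load_optimizer=True)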

View File

@@ -48,7 +48,7 @@ def setup_default_logging():
setup_default_logging()
logger = logging.getLogger(__name__)
-def get_base_dir():
+def get_base_dir() -> Path:
# co-locate nanochat intermediates with other cached data in ~/.cache (by default)
if os.environ.get("NANOCHAT_BASE_DIR"):
nanochat_dir = Path(os.environ.get("NANOCHAT_BASE_DIR"))
@@ -59,7 +59,7 @@ def get_base_dir():
nanochat_dir.mkdir(parents=True, exist_ok=True)
return nanochat_dir
-def download_file_with_lock(url, filename, postprocess_fn=None):
+def download_file_with_lock(url, filename, postprocess_fn=None) -> Path:
"""
Downloads a file from a URL to a local path in the base directory.
Uses a lock file to prevent concurrent downloads among multiple ranks.
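
Because get_base_dir() now declares its Path return, downstream code can compose locations with / and have it verified by the type checker. A small sketch; the NANOCHAT_BASE_DIR override comes from the hunk above, and the value used here is made up:

import os
from pathlib import Path
from nanochat.common import get_base_dir   # import path as used by the test file in this commit

os.environ["NANOCHAT_BASE_DIR"] = "/tmp/nanochat-demo"   # hypothetical override; default is under ~/.cache
base_dir: Path = get_base_dir()                          # the directory is created if it does not exist
tokenizer_dir = base_dir / "tokenizer"                   # Path composition, visible to mypy/pyright
print(tokenizer_dir)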

View File

@@ -29,7 +29,7 @@ DATA_DIR.mkdir(parents=True, exist_ok=True)
# -----------------------------------------------------------------------------
# These functions are useful utilities to other modules, can/should be imported
-def list_parquet_files(data_dir = DATA_DIR):
+def list_parquet_files(data_dir: Path = DATA_DIR) -> list[Path]:
""" Looks into a data dir and returns full paths to all parquet files. """
parquet_paths = sorted(data_dir.glob('*.parquet'))
return parquet_paths
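
The list[Path] annotation documents that callers get Path objects back and can use Path methods on them directly; note that built-in generics like list[Path] require Python 3.9+. A short sketch, with the import path assumed since the file name is not shown in this view:

from nanochat.dataset import list_parquet_files   # module path assumed

for parquet_path in list_parquet_files():          # defaults to DATA_DIR
    print(parquet_path.name, parquet_path.stat().st_size)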

View File

@@ -234,11 +234,11 @@ def extract_timestamp(content, prefix):
class Report:
"""Maintains a bunch of logs, generates a final markdown report."""
-def __init__(self, report_dir):
+def __init__(self, report_dir: Path):
report_dir.mkdir(parents=True, exist_ok=True)
self.report_dir = report_dir
-def log(self, section, data):
+def log(self, section, data) -> Path:
"""Log a section of data to the report."""
slug = slugify(section)
file_name = f"{slug}.md"
@@ -266,7 +266,7 @@ class Report:
f.write("\n")
return file_path
-def generate(self):
+def generate(self) -> Path:
"""Generate the final report."""
report_dir = self.report_dir
report_file = report_dir / "report.md"
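
Report.log() and Report.generate() now advertise the Path of the file they wrote, which is handy for printing or further processing. A hedged sketch; both the import path and the shape of the data argument are assumptions, since neither is visible in this diff:

from pathlib import Path
from nanochat.report import Report                 # module path assumed

report = Report(Path("out") / "report-demo")       # report_dir is created if missing
section_path = report.log("Tokenizer training", {"vocab_size": 65536})   # data shape assumed
report_path = report.generate()
print(section_path, report_path)                   # both are Path objects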

View File

@@ -48,7 +48,7 @@ class HuggingFaceTokenizer:
return cls(tokenizer)
@classmethod
-def from_directory(cls, tokenizer_dir):
+def from_directory(cls, tokenizer_dir: Path):
# init from a local directory on disk (e.g. "out/tokenizer")
tokenizer_path = tokenizer_dir / "tokenizer.json"
tokenizer = HFTokenizer.from_file(str(tokenizer_path))
@@ -138,7 +138,7 @@ class HuggingFaceTokenizer:
def decode(self, ids):
return self.tokenizer.decode(ids, skip_special_tokens=False)
-def save(self, tokenizer_dir):
+def save(self, tokenizer_dir: Path):
# save the tokenizer to disk
tokenizer_dir.mkdir(parents=True, exist_ok=True)
tokenizer_path = tokenizer_dir / "tokenizer.json"
@@ -181,7 +181,7 @@ class RustBPETokenizer:
return cls(enc, "<|bos|>")
@classmethod
-def from_directory(cls, tokenizer_dir):
+def from_directory(cls, tokenizer_dir: Path):
pickle_path = tokenizer_dir / "tokenizer.pkl"
with pickle_path.open("rb") as f:
enc = pickle.load(f)
@@ -246,7 +246,7 @@ class RustBPETokenizer:
def decode(self, ids):
return self.enc.decode(ids)
-def save(self, tokenizer_dir):
+def save(self, tokenizer_dir: Path):
# save the encoding object to disk
tokenizer_dir.mkdir(parents=True, exist_ok=True)
pickle_path = tokenizer_dir / "tokenizer.pkl"
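
Both tokenizer classes now take their directory as a Path; inside from_directory the HuggingFace variant still converts with str(tokenizer_path), since HFTokenizer.from_file expects a string. A usage sketch for the Rust BPE variant (locations are hypothetical, and tok_dir must already contain a tokenizer.pkl written by an earlier save):

from pathlib import Path
from nanochat.tokenizer import RustBPETokenizer        # module path assumed

tok_dir = Path("out") / "tokenizer"                    # hypothetical location of a saved tokenizer
tokenizer = RustBPETokenizer.from_directory(tok_dir)   # reads tok_dir / "tokenizer.pkl"
tokenizer.save(Path("/tmp") / "tokenizer-export")      # creates the directory and writes a new tokenizer.pkl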

View File

@@ -33,7 +33,7 @@ from nanochat.core_eval import evaluate_task
# ~162MB of data needed to evaluate the CORE metric
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
-def place_eval_bundle(file_path):
+def place_eval_bundle(file_path: Path):
# here file_path is the path to the eval_bundle.zip file
# we need to unzip it and place it in the base directory
base_dir = get_base_dir()
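
place_eval_bundle() receives the downloaded zip as a Path, which pairs naturally with download_file_with_lock() from the earlier hunk, assuming that function invokes postprocess_fn on the downloaded Path; the exact call site is not shown here, so treat this as an assumed pairing:

# download the ~162MB CORE eval bundle once (guarded by the cross-rank lock) and unpack it into the base dir
download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)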

View File

@@ -20,6 +20,7 @@ python -m pytest tests/test_rustbpe.py -v -s
import regex as re
from collections import Counter, defaultdict
+from pathlib import Path
import time
import rustbpe
import tiktoken
@@ -426,7 +427,7 @@ class HuggingFaceTokenizer:
# Test all of the above
@pytest.fixture(scope="module")
-def enwik8_path():
+def enwik8_path() -> Path:
"""Fixture to download and cache enwik8 dataset."""
import zipfile
from nanochat.common import get_base_dir
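
Annotating the fixture's return type also documents what consuming tests receive, since pytest injects the fixture by parameter name. An illustrative consumer (not one of the real tests in this file):

def test_enwik8_exists(enwik8_path: Path):
    # pytest passes in the module-scoped enwik8_path fixture by matching the parameter name
    assert enwik8_path.exists()
    assert enwik8_path.stat().st_size > 0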