From 2dd3380dbc1340b8093db2f122775d423ef9cf1a Mon Sep 17 00:00:00 2001 From: Tsvika Shapira Date: Fri, 26 Dec 2025 12:36:35 +0200 Subject: [PATCH] typing: add Path type hints to function signatures and returns --- nanochat/checkpoint_manager.py | 12 ++++++------ nanochat/common.py | 4 ++-- nanochat/dataset.py | 2 +- nanochat/report.py | 6 +++--- nanochat/tokenizer.py | 8 ++++---- scripts/base_eval.py | 2 +- tests/test_rustbpe.py | 3 ++- 7 files changed, 19 insertions(+), 18 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 83256fe..2586c63 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -20,7 +20,7 @@ def log0(message): if int(os.environ.get('RANK', 0)) == 0: logger.info(message) -def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0): +def save_checkpoint(checkpoint_dir: Path, step, model_data, optimizer_data, meta_data, rank=0): if rank == 0: checkpoint_dir.mkdir(parents=True, exist_ok=True) # Save the model state parameters @@ -39,7 +39,7 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, torch.save(optimizer_data, optimizer_path) logger.info(f"Saved optimizer state to: {optimizer_path}") -def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): +def load_checkpoint(checkpoint_dir: Path, step, device, load_optimizer=False, rank=0): # Load the model state model_path = checkpoint_dir / f"model_{step:06d}.pt" model_data = torch.load(model_path, map_location=device) @@ -55,7 +55,7 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): return model_data, optimizer_data, meta_data -def build_model(checkpoint_dir, step, device, phase): +def build_model(checkpoint_dir: Path, step, device, phase): """ A bunch of repetitive code to build a model from a given checkpoint. Returns: @@ -94,7 +94,7 @@ def build_model(checkpoint_dir, step, device, phase): return model, tokenizer, meta_data -def find_largest_model(checkpoints_dir): +def find_largest_model(checkpoints_dir: Path): # attempt to guess the model tag: take the biggest model available model_tags = [f.name for f in checkpoints_dir.iterdir() if f.is_dir()] if not model_tags: @@ -114,7 +114,7 @@ def find_largest_model(checkpoints_dir): return model_tags[0] -def find_last_step(checkpoint_dir): +def find_last_step(checkpoint_dir: Path): # Look into checkpoint_dir and find model_.pt with the highest step checkpoint_files = list(checkpoint_dir.glob("model_*.pt")) if not checkpoint_files: @@ -125,7 +125,7 @@ def find_last_step(checkpoint_dir): # ----------------------------------------------------------------------------- # convenience functions that take into account nanochat's directory structure -def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None): +def load_model_from_dir(checkpoints_dir: Path, device, phase, model_tag=None, step=None): if model_tag is None: # guess the model tag by defaulting to the largest model model_tag = find_largest_model(checkpoints_dir) diff --git a/nanochat/common.py b/nanochat/common.py index d49be7c..7ae00b9 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -48,7 +48,7 @@ def setup_default_logging(): setup_default_logging() logger = logging.getLogger(__name__) -def get_base_dir(): +def get_base_dir() -> Path: # co-locate nanochat intermediates with other cached data in ~/.cache (by default) if os.environ.get("NANOCHAT_BASE_DIR"): nanochat_dir = Path(os.environ.get("NANOCHAT_BASE_DIR")) @@ -59,7 +59,7 @@ def get_base_dir(): nanochat_dir.mkdir(parents=True, exist_ok=True) return nanochat_dir -def download_file_with_lock(url, filename, postprocess_fn=None): +def download_file_with_lock(url, filename, postprocess_fn=None) -> Path: """ Downloads a file from a URL to a local path in the base directory. Uses a lock file to prevent concurrent downloads among multiple ranks. diff --git a/nanochat/dataset.py b/nanochat/dataset.py index 4c73aa5..d582dc4 100644 --- a/nanochat/dataset.py +++ b/nanochat/dataset.py @@ -29,7 +29,7 @@ DATA_DIR.mkdir(parents=True, exist_ok=True) # ----------------------------------------------------------------------------- # These functions are useful utilities to other modules, can/should be imported -def list_parquet_files(data_dir = DATA_DIR): +def list_parquet_files(data_dir: Path = DATA_DIR) -> list[Path]: """ Looks into a data dir and returns full paths to all parquet files. """ parquet_paths = sorted(data_dir.glob('*.parquet')) return parquet_paths diff --git a/nanochat/report.py b/nanochat/report.py index ef19704..139d0ad 100644 --- a/nanochat/report.py +++ b/nanochat/report.py @@ -234,11 +234,11 @@ def extract_timestamp(content, prefix): class Report: """Maintains a bunch of logs, generates a final markdown report.""" - def __init__(self, report_dir): + def __init__(self, report_dir: Path): report_dir.mkdir(parents=True, exist_ok=True) self.report_dir = report_dir - def log(self, section, data): + def log(self, section, data) -> Path: """Log a section of data to the report.""" slug = slugify(section) file_name = f"{slug}.md" @@ -266,7 +266,7 @@ class Report: f.write("\n") return file_path - def generate(self): + def generate(self) -> Path: """Generate the final report.""" report_dir = self.report_dir report_file = report_dir / "report.md" diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py index 80e1136..6f17bd7 100644 --- a/nanochat/tokenizer.py +++ b/nanochat/tokenizer.py @@ -48,7 +48,7 @@ class HuggingFaceTokenizer: return cls(tokenizer) @classmethod - def from_directory(cls, tokenizer_dir): + def from_directory(cls, tokenizer_dir: Path): # init from a local directory on disk (e.g. "out/tokenizer") tokenizer_path = tokenizer_dir / "tokenizer.json" tokenizer = HFTokenizer.from_file(str(tokenizer_path)) @@ -138,7 +138,7 @@ class HuggingFaceTokenizer: def decode(self, ids): return self.tokenizer.decode(ids, skip_special_tokens=False) - def save(self, tokenizer_dir): + def save(self, tokenizer_dir: Path): # save the tokenizer to disk tokenizer_dir.mkdir(parents=True, exist_ok=True) tokenizer_path = tokenizer_dir / "tokenizer.json" @@ -181,7 +181,7 @@ class RustBPETokenizer: return cls(enc, "<|bos|>") @classmethod - def from_directory(cls, tokenizer_dir): + def from_directory(cls, tokenizer_dir: Path): pickle_path = tokenizer_dir / "tokenizer.pkl" with pickle_path.open("rb") as f: enc = pickle.load(f) @@ -246,7 +246,7 @@ class RustBPETokenizer: def decode(self, ids): return self.enc.decode(ids) - def save(self, tokenizer_dir): + def save(self, tokenizer_dir: Path): # save the encoding object to disk tokenizer_dir.mkdir(parents=True, exist_ok=True) pickle_path = tokenizer_dir / "tokenizer.pkl" diff --git a/scripts/base_eval.py b/scripts/base_eval.py index 2790808..1dd5981 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -33,7 +33,7 @@ from nanochat.core_eval import evaluate_task # ~162MB of data needed to evaluate the CORE metric EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" -def place_eval_bundle(file_path): +def place_eval_bundle(file_path: Path): # here file_path is the path to the eval_bundle.zip file # we need to unzip it and place it in the base directory base_dir = get_base_dir() diff --git a/tests/test_rustbpe.py b/tests/test_rustbpe.py index 5d84a02..fca17a6 100644 --- a/tests/test_rustbpe.py +++ b/tests/test_rustbpe.py @@ -20,6 +20,7 @@ python -m pytest tests/test_rustbpe.py -v -s import regex as re from collections import Counter, defaultdict +from pathlib import Path import time import rustbpe import tiktoken @@ -426,7 +427,7 @@ class HuggingFaceTokenizer: # Test all of the above @pytest.fixture(scope="module") -def enwik8_path(): +def enwik8_path() -> Path: """Fixture to download and cache enwik8 dataset.""" import zipfile from nanochat.common import get_base_dir