mirror of
https://github.com/karpathy/nanochat.git
synced 2026-02-10 04:29:50 +00:00
typing: add Path type hints to function signatures and returns
This commit is contained in:
parent
6d6651e2df
commit
2dd3380dbc
|
|
@ -20,7 +20,7 @@ def log0(message):
|
|||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
logger.info(message)
|
||||
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
def save_checkpoint(checkpoint_dir: Path, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
if rank == 0:
|
||||
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
# Save the model state parameters
|
||||
|
|
@ -39,7 +39,7 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
|
|||
torch.save(optimizer_data, optimizer_path)
|
||||
logger.info(f"Saved optimizer state to: {optimizer_path}")
|
||||
|
||||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
|
||||
def load_checkpoint(checkpoint_dir: Path, step, device, load_optimizer=False, rank=0):
|
||||
# Load the model state
|
||||
model_path = checkpoint_dir / f"model_{step:06d}.pt"
|
||||
model_data = torch.load(model_path, map_location=device)
|
||||
|
|
@ -55,7 +55,7 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
|
|||
return model_data, optimizer_data, meta_data
|
||||
|
||||
|
||||
def build_model(checkpoint_dir, step, device, phase):
|
||||
def build_model(checkpoint_dir: Path, step, device, phase):
|
||||
"""
|
||||
A bunch of repetitive code to build a model from a given checkpoint.
|
||||
Returns:
|
||||
|
|
@ -94,7 +94,7 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
return model, tokenizer, meta_data
|
||||
|
||||
|
||||
def find_largest_model(checkpoints_dir):
|
||||
def find_largest_model(checkpoints_dir: Path):
|
||||
# attempt to guess the model tag: take the biggest model available
|
||||
model_tags = [f.name for f in checkpoints_dir.iterdir() if f.is_dir()]
|
||||
if not model_tags:
|
||||
|
|
@ -114,7 +114,7 @@ def find_largest_model(checkpoints_dir):
|
|||
return model_tags[0]
|
||||
|
||||
|
||||
def find_last_step(checkpoint_dir):
|
||||
def find_last_step(checkpoint_dir: Path):
|
||||
# Look into checkpoint_dir and find model_<step>.pt with the highest step
|
||||
checkpoint_files = list(checkpoint_dir.glob("model_*.pt"))
|
||||
if not checkpoint_files:
|
||||
|
|
@ -125,7 +125,7 @@ def find_last_step(checkpoint_dir):
|
|||
# -----------------------------------------------------------------------------
|
||||
# convenience functions that take into account nanochat's directory structure
|
||||
|
||||
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
|
||||
def load_model_from_dir(checkpoints_dir: Path, device, phase, model_tag=None, step=None):
|
||||
if model_tag is None:
|
||||
# guess the model tag by defaulting to the largest model
|
||||
model_tag = find_largest_model(checkpoints_dir)
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ def setup_default_logging():
|
|||
setup_default_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_base_dir():
|
||||
def get_base_dir() -> Path:
|
||||
# co-locate nanochat intermediates with other cached data in ~/.cache (by default)
|
||||
if os.environ.get("NANOCHAT_BASE_DIR"):
|
||||
nanochat_dir = Path(os.environ.get("NANOCHAT_BASE_DIR"))
|
||||
|
|
@ -59,7 +59,7 @@ def get_base_dir():
|
|||
nanochat_dir.mkdir(parents=True, exist_ok=True)
|
||||
return nanochat_dir
|
||||
|
||||
def download_file_with_lock(url, filename, postprocess_fn=None):
|
||||
def download_file_with_lock(url, filename, postprocess_fn=None) -> Path:
|
||||
"""
|
||||
Downloads a file from a URL to a local path in the base directory.
|
||||
Uses a lock file to prevent concurrent downloads among multiple ranks.
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ DATA_DIR.mkdir(parents=True, exist_ok=True)
|
|||
# -----------------------------------------------------------------------------
|
||||
# These functions are useful utilities to other modules, can/should be imported
|
||||
|
||||
def list_parquet_files(data_dir = DATA_DIR):
|
||||
def list_parquet_files(data_dir: Path = DATA_DIR) -> list[Path]:
|
||||
""" Looks into a data dir and returns full paths to all parquet files. """
|
||||
parquet_paths = sorted(data_dir.glob('*.parquet'))
|
||||
return parquet_paths
|
||||
|
|
|
|||
|
|
@ -234,11 +234,11 @@ def extract_timestamp(content, prefix):
|
|||
class Report:
|
||||
"""Maintains a bunch of logs, generates a final markdown report."""
|
||||
|
||||
def __init__(self, report_dir):
|
||||
def __init__(self, report_dir: Path):
|
||||
report_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.report_dir = report_dir
|
||||
|
||||
def log(self, section, data):
|
||||
def log(self, section, data) -> Path:
|
||||
"""Log a section of data to the report."""
|
||||
slug = slugify(section)
|
||||
file_name = f"{slug}.md"
|
||||
|
|
@ -266,7 +266,7 @@ class Report:
|
|||
f.write("\n")
|
||||
return file_path
|
||||
|
||||
def generate(self):
|
||||
def generate(self) -> Path:
|
||||
"""Generate the final report."""
|
||||
report_dir = self.report_dir
|
||||
report_file = report_dir / "report.md"
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class HuggingFaceTokenizer:
|
|||
return cls(tokenizer)
|
||||
|
||||
@classmethod
|
||||
def from_directory(cls, tokenizer_dir):
|
||||
def from_directory(cls, tokenizer_dir: Path):
|
||||
# init from a local directory on disk (e.g. "out/tokenizer")
|
||||
tokenizer_path = tokenizer_dir / "tokenizer.json"
|
||||
tokenizer = HFTokenizer.from_file(str(tokenizer_path))
|
||||
|
|
@ -138,7 +138,7 @@ class HuggingFaceTokenizer:
|
|||
def decode(self, ids):
|
||||
return self.tokenizer.decode(ids, skip_special_tokens=False)
|
||||
|
||||
def save(self, tokenizer_dir):
|
||||
def save(self, tokenizer_dir: Path):
|
||||
# save the tokenizer to disk
|
||||
tokenizer_dir.mkdir(parents=True, exist_ok=True)
|
||||
tokenizer_path = tokenizer_dir / "tokenizer.json"
|
||||
|
|
@ -181,7 +181,7 @@ class RustBPETokenizer:
|
|||
return cls(enc, "<|bos|>")
|
||||
|
||||
@classmethod
|
||||
def from_directory(cls, tokenizer_dir):
|
||||
def from_directory(cls, tokenizer_dir: Path):
|
||||
pickle_path = tokenizer_dir / "tokenizer.pkl"
|
||||
with pickle_path.open("rb") as f:
|
||||
enc = pickle.load(f)
|
||||
|
|
@ -246,7 +246,7 @@ class RustBPETokenizer:
|
|||
def decode(self, ids):
|
||||
return self.enc.decode(ids)
|
||||
|
||||
def save(self, tokenizer_dir):
|
||||
def save(self, tokenizer_dir: Path):
|
||||
# save the encoding object to disk
|
||||
tokenizer_dir.mkdir(parents=True, exist_ok=True)
|
||||
pickle_path = tokenizer_dir / "tokenizer.pkl"
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from nanochat.core_eval import evaluate_task
|
|||
# ~162MB of data needed to evaluate the CORE metric
|
||||
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
|
||||
|
||||
def place_eval_bundle(file_path):
|
||||
def place_eval_bundle(file_path: Path):
|
||||
# here file_path is the path to the eval_bundle.zip file
|
||||
# we need to unzip it and place it in the base directory
|
||||
base_dir = get_base_dir()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ python -m pytest tests/test_rustbpe.py -v -s
|
|||
|
||||
import regex as re
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
import time
|
||||
import rustbpe
|
||||
import tiktoken
|
||||
|
|
@ -426,7 +427,7 @@ class HuggingFaceTokenizer:
|
|||
# Test all of the above
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def enwik8_path():
|
||||
def enwik8_path() -> Path:
|
||||
"""Fixture to download and cache enwik8 dataset."""
|
||||
import zipfile
|
||||
from nanochat.common import get_base_dir
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user