diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py
index 4136802..20dd88f 100644
--- a/nanochat/dataloader.py
+++ b/nanochat/dataloader.py
@@ -5,9 +5,8 @@ import pyarrow.parquet as pq
 
 from nanochat.common import get_dist_info
 from nanochat.dataset import list_parquet_files
-from nanochat.tokenizer import get_tokenizer
 
-def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
+def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
     """
     Stream pretraining text from parquet files, tokenize, yield training batches.
 
@@ -62,8 +61,6 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
 
     # Now emit batches of tokens.
     needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
-    # get the tokenizer and the bos token
-    tokenizer = get_tokenizer()
     bos_token = tokenizer.get_bos_token_id()
     # scratch buffer holds the tokens for one iteration
     token_buffer = deque() # we stream tokens on the right and pop from the left
diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py
index 726fb2f..e8ccafa 100644
--- a/nanochat/tokenizer.py
+++ b/nanochat/tokenizer.py
@@ -103,9 +103,10 @@ class HuggingFaceTokenizer:
     def id_to_token(self, id):
         return self.tokenizer.id_to_token(id)
 
-    def _encode_one(self, text, prepend=None, append=None):
+    def _encode_one(self, text, prepend=None, append=None, num_threads=None):
         # encode a single string
         # prepend/append can be either a string of a special token or a token id directly.
+        # num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
         assert isinstance(text, str)
         ids = []
         if prepend is not None:
diff --git a/scripts/base_loss.py b/scripts/base_loss.py
index 3dbe68f..094299a 100644
--- a/scripts/base_loss.py
+++ b/scripts/base_loss.py
@@ -5,6 +5,9 @@ Loads a checkpoint, and:
 
 Example run as:
 torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
+
+To evaluate a HuggingFace model:
+python -m scripts.base_loss --hf_path openai-community/gpt2
 """
 import argparse
 from contextlib import nullcontext
@@ -12,42 +15,98 @@ import torch
 
 from nanochat.checkpoint_manager import load_model
 from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
 from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.tokenizer import get_token_bytes
+from nanochat.tokenizer import get_token_bytes, HuggingFaceTokenizer
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
 
+# -----------------------------------------------------------------------------
+# HuggingFace loading utilities, making the APIs match up to those of nanochat
+
+class ModelWrapper:
+    """Lightweight wrapper for a HuggingFace model"""
+    def __init__(self, model, max_seq_len=None):
+        self.model = model
+        self.max_seq_len = max_seq_len
+
+    def __call__(self, input_ids, targets=None, loss_reduction='mean'):
+        logits = self.model(input_ids).logits
+        if targets is None:
+            return logits
+        else:
+            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
+            return loss
+
+    def get_device(self):
+        return next(self.model.parameters()).device
+
+def load_hf_model(hf_path: str, device):
+    print0(f"Loading model from: {hf_path}")
+    from transformers import AutoModelForCausalLM
+    model = AutoModelForCausalLM.from_pretrained(hf_path)
+    model.to(device)
+    model.eval()
+    max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
+    model = ModelWrapper(model, max_seq_len=max_seq_len)
+    tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
+    return model, tokenizer
+
+def get_hf_token_bytes(tokenizer, device="cpu"):
+    """Compute token_bytes tensor for a HuggingFace tokenizer."""
+    vocab_size = tokenizer.tokenizer.get_vocab_size()
+    token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
+    for token_id in range(vocab_size):
+        token_str = tokenizer.tokenizer.decode([token_id])
+        token_bytes[token_id] = len(token_str.encode('utf-8')) # Count UTF-8 bytes
+    return token_bytes
+
 # CLI arguments
 parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
 parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--split_tokens", type=int, default=20*524288, help="number of tokens to evaluate per split")
+parser.add_argument("--split_tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
 parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory")
 parser.add_argument("--model_step", type=int, default=None, help="model step to load")
 parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--hf_path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
 args = parser.parse_args()
 
 # Load the base model and the tokenizer
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
-sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
+print0(f"Device: {device} | DDP rank: {ddp_rank} | DDP local rank: {ddp_local_rank} | DDP world size: {ddp_world_size}")
+
+if args.hf_path is not None:
+    # Load HuggingFace model
+    model, tokenizer = load_hf_model(args.hf_path, device)
+    sequence_len = model.max_seq_len if model.max_seq_len else 1024
+    token_bytes = get_hf_token_bytes(tokenizer, device=device)
+    model_name = args.hf_path
+else:
+    # Load local nanochat model
+    model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
+    sequence_len = meta["model_config"]["sequence_len"]
+    token_bytes = get_token_bytes(device=device)
+    model_name = f"base_model (step {meta['step']})"
+
 autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+
+print0(f"Evaluating model: {model_name}")
+
 # Evaluate the loss on each split
 tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
 assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
 steps = args.split_tokens // tokens_per_step
-token_bytes = get_token_bytes(device=device)
 bpb_results = {}
 for split_name in ["train", "val"]:
-    loader = tokenizing_distributed_data_loader(args.device_batch_size, sequence_len, split_name, device=device)
+    loader = tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
     with autocast_ctx:
         bpb = evaluate_bpb(model, loader, steps, token_bytes)
-    print0(f"{split_name} bpb: {bpb:.4f}")
     bpb_results[split_name] = bpb
+    print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
 
-# Master process also samples from the model
+# Master process also samples from the model (only for nanochat models)
 samples = []
-if ddp_rank == 0:
+if ddp_rank == 0 and args.hf_path is None:
     prompts = [
         "The capital of France is",
         "The chemical symbol of gold is",
@@ -70,6 +129,7 @@ if ddp_rank == 0:
     from nanochat.report import get_report
     get_report().log(section="Base model loss", data=[
         {
+            "model": model_name,
             "train bpb": bpb_results["train"],
             "val bpb": bpb_results["val"],
         },
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 7a16276..c7c5bba 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -210,8 +210,8 @@ if resuming:
 # Initialize the DataLoaders for train/val
 tokens_dir = os.path.join(base_dir, "tokenized_data")
 dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
-train_loader = tokenizing_distributed_data_loader_with_state(args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
-build_val_loader = lambda: tokenizing_distributed_data_loader(args.device_batch_size, args.max_seq_len, split="val", device=device)
+train_loader = tokenizing_distributed_data_loader_with_state(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
+build_val_loader = lambda: tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
 x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
 
 # -----------------------------------------------------------------------------