mirror of https://github.com/karpathy/nanochat.git
synced 2026-01-24 20:34:23 +00:00
allow base_loss to report the loss of any arbitrary HuggingFace model, similar to base_eval. This required changing the dataloader to simply take a tokenizer argument rather than loading the nanochat one itself; much better this way anyway.
This commit is contained in:
parent aa95fb2e03
commit 21608ec51e
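The core of the change is that the data loaders now receive a tokenizer from the caller instead of constructing the nanochat one internally. A minimal sketch of the new calling convention (the batch size, sequence length, and GPT-2 path below are illustrative values, not taken from this commit):

    from nanochat.tokenizer import get_tokenizer, HuggingFaceTokenizer
    from nanochat.dataloader import tokenizing_distributed_data_loader_with_state

    # either the nanochat tokenizer or a HuggingFace one can now be passed in
    tokenizer = get_tokenizer()
    # tokenizer = HuggingFaceTokenizer.from_pretrained("openai-community/gpt2")

    train_loader = tokenizing_distributed_data_loader_with_state(
        tokenizer, B=32, T=2048, split="train", device="cuda",
    )
    x, y, loader_state = next(train_loader)  # (B, T) inputs, targets, and resumable loader state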
@@ -5,9 +5,8 @@ import pyarrow.parquet as pq
 from nanochat.common import get_dist_info
 from nanochat.dataset import list_parquet_files
-from nanochat.tokenizer import get_tokenizer

-def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
+def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
     """
     Stream pretraining text from parquet files, tokenize, yield training batches.

@@ -62,8 +61,6 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
     # Now emit batches of tokens.
     needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
-    # get the tokenizer and the bos token
-    tokenizer = get_tokenizer()
     bos_token = tokenizer.get_bos_token_id()
     # scratch buffer holds the tokens for one iteration
     token_buffer = deque() # we stream tokens on the right and pop from the left
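For context on the hunk above (apparently nanochat/dataloader.py): the loader accumulates tokens in a deque and cuts off B*T+1 of them per iteration, the extra token supplying the target for the last input position. A rough, self-contained sketch of that batching idea (function and variable names here are illustrative, not nanochat's actual implementation):

    from collections import deque
    import torch

    def batches_from_token_stream(token_stream, B, T, device="cpu"):
        needed_tokens = B * T + 1  # +1 because we also need the target at the last token
        token_buffer = deque()     # stream tokens in on the right, pop from the left
        for doc_tokens in token_stream:      # token_stream yields lists of token ids
            token_buffer.extend(doc_tokens)
            while len(token_buffer) >= needed_tokens:
                chunk = [token_buffer.popleft() for _ in range(needed_tokens)]
                buf = torch.tensor(chunk, dtype=torch.long, device=device)
                x = buf[:-1].view(B, T)  # inputs
                y = buf[1:].view(B, T)   # targets, shifted by one
                yield x, y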
@@ -103,9 +103,10 @@ class HuggingFaceTokenizer:
     def id_to_token(self, id):
         return self.tokenizer.id_to_token(id)

-    def _encode_one(self, text, prepend=None, append=None):
+    def _encode_one(self, text, prepend=None, append=None, num_threads=None):
         # encode a single string
         # prepend/append can be either a string of a special token or a token id directly.
+        # num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
         assert isinstance(text, str)
         ids = []
         if prepend is not None:
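The extra num_threads parameter exists purely for interface compatibility: the nanochat Tokenizer uses it for parallel encoding, while the HuggingFace wrapper accepts and ignores it, so callers such as the dataloader can invoke either tokenizer the same way. A small sketch of that duck typing (the helper below and its exact call shape are assumed for illustration, not taken from this diff):

    def encode_documents(tokenizer, texts, num_threads=4):
        # works with either tokenizer class: both expose get_bos_token_id()
        # and _encode_one(..., num_threads=...) after this change
        bos = tokenizer.get_bos_token_id()
        return [tokenizer._encode_one(t, prepend=bos, num_threads=num_threads) for t in texts]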
@@ -5,6 +5,9 @@ Loads a checkpoint, and:

 Example run as:
 torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
+
+To evaluate a HuggingFace model:
+python -m scripts.base_loss --hf_path openai-community/gpt2
 """
 import argparse
 from contextlib import nullcontext
@@ -12,42 +15,98 @@ import torch
 from nanochat.checkpoint_manager import load_model
 from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
 from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.tokenizer import get_token_bytes
+from nanochat.tokenizer import get_token_bytes, HuggingFaceTokenizer
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine

+# -----------------------------------------------------------------------------
+# HuggingFace loading utilities, making the APIs match up to those of nanochat
+
+class ModelWrapper:
+    """Lightweight wrapper for a HuggingFace model"""
+    def __init__(self, model, max_seq_len=None):
+        self.model = model
+        self.max_seq_len = max_seq_len
+
+    def __call__(self, input_ids, targets=None, loss_reduction='mean'):
+        logits = self.model(input_ids).logits
+        if targets is None:
+            return logits
+        else:
+            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
+            return loss
+
+    def get_device(self):
+        return next(self.model.parameters()).device
+
+def load_hf_model(hf_path: str, device):
+    print0(f"Loading model from: {hf_path}")
+    from transformers import AutoModelForCausalLM
+    model = AutoModelForCausalLM.from_pretrained(hf_path)
+    model.to(device)
+    model.eval()
+    max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
+    model = ModelWrapper(model, max_seq_len=max_seq_len)
+    tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
+    return model, tokenizer
+
+def get_hf_token_bytes(tokenizer, device="cpu"):
+    """Compute token_bytes tensor for a HuggingFace tokenizer."""
+    vocab_size = tokenizer.tokenizer.get_vocab_size()
+    token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
+    for token_id in range(vocab_size):
+        token_str = tokenizer.tokenizer.decode([token_id])
+        token_bytes[token_id] = len(token_str.encode('utf-8')) # Count UTF-8 bytes
+    return token_bytes
+
 # CLI arguments
 parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
 parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--split_tokens", type=int, default=20*524288, help="number of tokens to evaluate per split")
+parser.add_argument("--split_tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
 parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory")
 parser.add_argument("--model_step", type=int, default=None, help="model step to load")
 parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--hf_path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
 args = parser.parse_args()

 # Load the base model and the tokenizer
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
-sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
 print0(f"Device: {device} | DDP rank: {ddp_rank} | DDP local rank: {ddp_local_rank} | DDP world size: {ddp_world_size}")

+if args.hf_path is not None:
+    # Load HuggingFace model
+    model, tokenizer = load_hf_model(args.hf_path, device)
+    sequence_len = model.max_seq_len if model.max_seq_len else 1024
+    token_bytes = get_hf_token_bytes(tokenizer, device=device)
+    model_name = args.hf_path
+else:
+    # Load local nanochat model
+    model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
+    sequence_len = meta["model_config"]["sequence_len"]
+    token_bytes = get_token_bytes(device=device)
+    model_name = f"base_model (step {meta['step']})"
+
 autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

+print0(f"Evaluating model: {model_name}")
+
 # Evaluate the loss on each split
 tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
 assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
 steps = args.split_tokens // tokens_per_step
-token_bytes = get_token_bytes(device=device)
 bpb_results = {}
 for split_name in ["train", "val"]:
-    loader = tokenizing_distributed_data_loader(args.device_batch_size, sequence_len, split_name, device=device)
+    loader = tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
     with autocast_ctx:
         bpb = evaluate_bpb(model, loader, steps, token_bytes)
-    print0(f"{split_name} bpb: {bpb:.4f}")
     bpb_results[split_name] = bpb
+    print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")

-# Master process also samples from the model
+# Master process also samples from the model (only for nanochat models)
 samples = []
-if ddp_rank == 0:
+if ddp_rank == 0 and args.hf_path is None:
     prompts = [
         "The capital of France is",
         "The chemical symbol of gold is",
@@ -70,6 +129,7 @@ if ddp_rank == 0:
     from nanochat.report import get_report
     get_report().log(section="Base model loss", data=[
         {
+            "model": model_name,
             "train bpb": bpb_results["train"],
             "val bpb": bpb_results["val"],
         },
@@ -210,8 +210,8 @@ if resuming:
 # Initialize the DataLoaders for train/val
 tokens_dir = os.path.join(base_dir, "tokenized_data")
 dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
-train_loader = tokenizing_distributed_data_loader_with_state(args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
-build_val_loader = lambda: tokenizing_distributed_data_loader(args.device_batch_size, args.max_seq_len, split="val", device=device)
+train_loader = tokenizing_distributed_data_loader_with_state(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
+build_val_loader = lambda: tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
 x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data

 # -----------------------------------------------------------------------------
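In the training script's hunk above, the stateful loader yields its own state dict alongside every batch, which is what makes resumption work: the trainer checkpoints that state and feeds it back via resume_state_dict on restart. A minimal sketch of that loop (save_checkpoint and model_step_fn are placeholders, not this script's actual code):

    def train_step_and_checkpoint(train_loader, model_step_fn, save_checkpoint):
        x, y, dataloader_state_dict = next(train_loader)
        model_step_fn(x, y)  # forward/backward/optimizer update on this batch
        save_checkpoint({"dataloader_state_dict": dataloader_state_dict})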