mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-13 08:23:12 +00:00
- Added `dev/runmps_evals.sh` for evaluating checkpoints and logging results to W&B.
- Introduced `dev/runmps.sh` for orchestrating training stages with W&B support.
- Updated `.gitignore` to include `wandb/` and `.runmps_wandb_ids`.
- Changed the permissions of `dev/runcpu.sh` to add the executable flag.
- Enhanced existing scripts to log metrics to W&B during training and evaluation.
103 lines
3.8 KiB
Python
103 lines
3.8 KiB
Python
"""
Loads a base-model checkpoint, and:
- Evaluates the loss (bpb, via `evaluate_bpb`) on a larger chunk of the train/val splits
- Samples short completions from the model for a few fixed prompts (rank 0 only)
- Logs the results to the report and, when the WANDB_RUN_ID environment
  variable is set, to Weights & Biases

Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss

The configuration globals (device_batch_size, split_tokens, model_tag,
model_step, device_type) can be overridden from the command line or a config
file through nanochat/configurator.py.
"""
|
|
import os
|
|
from contextlib import nullcontext
|
|
import torch
|
|
import wandb
|
|
|
|
from nanochat.common import DummyWandb
|
|
from nanochat.checkpoint_manager import load_model
|
|
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
|
|
from nanochat.dataloader import tokenizing_distributed_data_loader
|
|
from nanochat.tokenizer import get_token_bytes
|
|
from nanochat.loss_eval import evaluate_bpb
|
|
from nanochat.engine import Engine
|
|
|
|
# -----------------------------------------------------------------------------
# Configuration. Each of these globals may be overridden from the command line
# or a config file by the configurator that is exec'd right after this section.
device_batch_size = 32  # batch size handed to the eval data loader on each device
split_tokens = 20 * 2**19  # tokens to evaluate per split (20 * 524288)
model_tag = None  # optional model tag; passed to load_model to pick the checkpoint
model_step = None  # optional model step; passed to load_model as `step`
device_type = ""  # cuda|cpu|mps (empty => autodetect)
|
|
# Apply overrides from the command line or a config file. The configurator
# source is exec'd at module scope so it can rebind the configuration globals
# above. Reading it inside `with` closes the file handle deterministically
# instead of leaking it (the original `open(...).read()` never closed it).
with open(os.path.join("nanochat", "configurator.py"), encoding="utf-8") as _configurator_file:
    _configurator_source = _configurator_file.read()
exec(_configurator_source)
|
|
|
|
# -----------------------------------------------------------------------------
# Resolve the device, initialize compute (DDP rank/world size come back from
# compute_init), and load the base model, its tokenizer and checkpoint metadata.
if device_type == "":
    device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model(
    "base",
    device,
    phase="eval",
    model_tag=model_tag,
    step=model_step,
)
# Evaluate at the sequence length stored in the checkpoint's model config
# (any length would do in principle).
sequence_len = meta["model_config"]["sequence_len"]
# bfloat16 autocast is used on CUDA only; other devices get a no-op context.
if device_type == "cuda":
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
else:
    autocast_ctx = nullcontext()
|
|
|
|
# -----------------------------------------------------------------------------
# Work out how many evaluation steps cover `split_tokens` tokens per split.
# tokens_per_step counts device_batch_size * sequence_len tokens for each of
# the ddp_world_size ranks.
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
# Validate with explicit exceptions rather than `assert`: asserts are stripped
# under `python -O`, which would let a bad configuration through silently.
if split_tokens % tokens_per_step != 0:
    raise ValueError(
        "split_tokens must be divisible by tokens_per_step "
        f"(split_tokens={split_tokens}, tokens_per_step={tokens_per_step})"
    )
steps = split_tokens // tokens_per_step
if steps <= 0:
    # e.g. split_tokens=0 passes the divisibility check but would evaluate nothing
    raise ValueError(f"split_tokens must be positive, got {split_tokens}")
token_bytes = get_token_bytes(device=device)
|
|
# W&B logging is opt-in: it is enabled only when WANDB_RUN_ID is set, and only
# on the master process (ddp_rank == 0). Under torchrun every rank runs this
# script, and letting all of them wandb.init() the same run id with
# resume="allow" would make N processes attach to and write into one run.
# Non-master ranks keep the DummyWandb stand-in and never touch W&B.
use_wandb = bool(os.environ.get("WANDB_RUN_ID")) and ddp_rank == 0
wandb_run = DummyWandb()
if use_wandb:
    # Every value below is non-None (WANDB_RUN_ID is known to be set here and the
    # other two have defaults), so no None-filtering of the kwargs is needed.
    wandb_run = wandb.init(
        project=os.environ.get("WANDB_PROJECT", "nanochat"),
        name=os.environ.get("WANDB_EVAL_RUN", "base-eval"),
        id=os.environ["WANDB_RUN_ID"],
        resume="allow",  # resume the run with this id if it exists, else create it
        reinit=True,
    )
|
|
|
|
# -----------------------------------------------------------------------------
# Evaluate bpb on the train and val splits. evaluate_bpb gets `steps` steps of
# a fresh distributed loader per split (split_tokens tokens across all ranks,
# assuming one batch per rank per step).
bpb_results = {}
for split_name in ("train", "val"):
    split_loader = tokenizing_distributed_data_loader(
        device_batch_size, sequence_len, split_name, device=device
    )
    with autocast_ctx:
        bpb = evaluate_bpb(model, split_loader, steps, token_bytes)
    bpb_results[split_name] = bpb
    print0(f"{split_name} bpb: {bpb:.4f}")
|
|
|
|
# -----------------------------------------------------------------------------
# The master process (rank 0) also generates one completion per prompt below
# (max_tokens=16, temperature=0) and keeps the decoded text in `samples`.
# On every other rank `samples` stays empty.
samples = []
if ddp_rank == 0:
    prompts = [
        "The capital of France is",
        "The chemical symbol of gold is",
        "If yesterday was Friday, then tomorrow will be",
        "The opposite of hot is",
        "The planets of the solar system are:",
        "My favorite color is",
        "If 5*x + 3 = 13, then x is",
    ]
    engine = Engine(model, tokenizer)
    for prompt in prompts:
        prompt_ids = tokenizer(prompt, prepend="<|bos|>")
        with autocast_ctx:
            generations, _ = engine.generate_batch(
                prompt_ids, num_samples=1, max_tokens=16, temperature=0
            )
        decoded = tokenizer.decode(generations[0])
        print0(decoded)
        samples.append(decoded)
|
|
|
|
# -----------------------------------------------------------------------------
# Record the bpb numbers and the generated samples in the report
from nanochat.report import get_report

report = get_report()
loss_row = {"train bpb": bpb_results["train"], "val bpb": bpb_results["val"]}
sample_row = {f"sample {i}": sample for i, sample in enumerate(samples)}
report.log(section="Base model loss", data=[loss_row, sample_row])
|
|
|
|
# Push the final bpb numbers to W&B, then close the run. Nothing is sent when
# use_wandb is False.
if use_wandb:
    bpb_metrics = {
        "base_loss/train_bpb": bpb_results["train"],
        "base_loss/val_bpb": bpb_results["val"],
    }
    # meta.get(...) yields None when the checkpoint metadata has no "step" entry
    checkpoint_step = meta.get("step")
    wandb_run.log(bpb_metrics, step=checkpoint_step)
    wandb_run.finish()

# -----------------------------------------------------------------------------
# Cleanup
compute_cleanup()
|