Merge karpathy/cpu-mps-dev , adding the ability to run on CPU, on MPS, or on CUDA, with autodetect. Gnarly PR, nonzero chance I broke something.

add cpu|mps support
This commit is contained in:
Andrej 2025-10-21 10:26:04 -07:00 committed by GitHub
commit 33e8a27f91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 266 additions and 93 deletions

View File

@ -95,7 +95,7 @@ And a bit more about computing environments that will run nanochat:
## Running on CPU / MPS
If you'd like to tinker with nanochat on your Macbook or a CPU machine, there is a work in progress [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) up here. If you're on Macbook, use `--device_type=mps` when running `base_train.py`. See the PR and its diff for more. You're not going to get too far without GPU nodes, but at least you'll be able to run the code and maybe train a very tiny LLM with some patience.
nanochat cn be run on CPU or on MPS (if you're on Macbook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for shorter number of iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
## Customization

84
dev/runcpu.sh Normal file
View File

@ -0,0 +1,84 @@
#!/bin/bash
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# Run as:
# bash dev/cpu_demo_run.sh
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/
# all the setup stuff
export OMP_NUM_THREADS=1
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
unzip -q eval_bundle.zip
rm eval_bundle.zip
mv eval_bundle $NANOCHAT_BASE_DIR
fi
# wipe the report
python -m nanochat.report reset
# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max_chars=1000000000
python -m scripts.tok_eval
# train a very small 4 layer model on the CPU
# each optimization step processes a single sequence of 1024 tokens
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
--depth=4 \
--max_seq_len=1024 \
--device_batch_size=1 \
--total_batch_size=1024 \
--eval_every=50 \
--eval_tokens=4096 \
--core_metric_every=50 \
--core_metric_max_per_task=12 \
--sample_every=50 \
--num_iterations=50
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
python -m scripts.base_eval --max-per-task=5
# midtraining
python -m scripts.mid_train \
--max_seq_len=1024 \
--device_batch_size=1 \
--eval_every=50 \
--eval_tokens=4096 \
--total_batch_size=1024 \
--num_iterations=100
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
# SFT
python -m scripts.chat_sft \
--device_batch_size=1 \
--target_examples_per_step=4 \
--num_iterations=100 \
--eval_steps=4 \
--eval_metrics_max_problems=16
# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"
# Chat Web
# python -m scripts.chat_web
python -m nanochat.report generate

View File

@ -89,32 +89,46 @@ def get_dist_info():
else:
return False, 0, 0, 1
def compute_init():
def autodetect_device_type():
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
if torch.cuda.is_available():
device_type = "cuda"
elif torch.backends.mps.is_available():
device_type = "mps"
else:
device_type = "cpu"
print0(f"Autodetected device type: {device_type}")
return device_type
def compute_init(device_type="cuda"): # cuda|cpu|mps
"""Basic initialization that we keep doing over and over, so make common."""
# CUDA is currently required
assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
if device_type == "cuda":
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
if device_type == "mps":
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
# Reproducibility
torch.manual_seed(42)
if device_type == "cuda":
torch.cuda.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later
# torch.use_deterministic_algorithms(True)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# Precision
if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp:
if ddp and device_type == "cuda":
device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
else:
device = torch.device("cuda")
device = torch.device(device_type) # mps|cpu
if ddp_rank == 0:
logger.info(f"Distributed world size: {ddp_world_size}")

View File

@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
assert split in ["train", "val"], "split must be 'train' or 'val'"
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@ -43,6 +43,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
yield inputs, targets

View File

@ -146,12 +146,11 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
with caution.
"""
if maximum_memory_bytes is not None:
if platform.uname().system != "Darwin":
# These resource limit calls seem to fail on macOS (Darwin), skip?
import resource
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
if not platform.uname().system == "Darwin":
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
faulthandler.disable()
@ -225,6 +224,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
rmtree = shutil.rmtree
rmdir = os.rmdir
chdir = os.chdir
unlink = os.unlink
# Disable functionalities that can make destructive changes to the test.
reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
@ -282,6 +282,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
shutil.rmtree = rmtree
os.rmdir = rmdir
os.chdir = chdir
os.unlink = unlink
def execute_code(

View File

@ -169,8 +169,6 @@ class GPT(nn.Module):
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
self.transformer.wte.to(dtype=torch.bfloat16)
def init_weights(self):
self.apply(self._init_weights)
@ -184,6 +182,9 @@ class GPT(nn.Module):
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
if self.transformer.wte.weight.device.type == "cuda":
self.transformer.wte.to(dtype=torch.bfloat16)
def _init_weights(self, module):
if isinstance(module, nn.Linear):

View File

@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
if (y < 0).any():
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
valid = y >= 0

View File

@ -283,6 +283,10 @@ class Report:
# capture bloat data for summary later (the stuff after Bloat header and until \n\n)
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
bloat_data = bloat_data.group(1) if bloat_data else ""
else:
start_time = None # will cause us to not write the total wall clock time
bloat_data = "[bloat data missing]"
print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
# process all the individual sections
for file_name in EXPECTED_FILES:
section_file = os.path.join(report_dir, file_name)

View File

@ -11,6 +11,7 @@ dependencies = [
"numpy==1.26.4",
"psutil>=7.1.0",
"regex>=2025.9.1",
"setuptools>=80.9.0",
"tiktoken>=0.11.0",
"tokenizers>=0.22.0",
"torch>=2.8.0",
@ -22,17 +23,6 @@ dependencies = [
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"
# target torch to cuda 12.8
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128" },
]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[tool.maturin]
module-name = "rustbpe"
bindings = "pyo3"
@ -53,3 +43,20 @@ testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
# target torch to cuda 12.8
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

View File

@ -1,3 +1,5 @@
#!/bin/bash
# The $1000 tier of nanochat
# Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
# A bit sparser on comments, see speedrun.sh for more detail

View File

@ -15,11 +15,12 @@ import time
import json
import random
import yaml
from contextlib import nullcontext
import pandas as pd
import torch
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
@ -118,16 +119,21 @@ def load_hf_model(hf_path: str, device):
# -----------------------------------------------------------------------------
def main():
assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
args = parser.parse_args()
# distributed / precision setup
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# Load model and tokenizer from command line or from file system
if len(sys.argv) >= 2:
if args.hf_path is not None:
# atm assume that if a path is given, it's a huggingface model path
hf_path = sys.argv[1]
hf_path = args.hf_path
print0(f"Loading huggingface model from: {hf_path}")
model, tokenizer = load_hf_model(hf_path, device)
model_name = hf_path # just for logging
@ -140,7 +146,7 @@ def main():
# Evaluate the model
with autocast_ctx:
out = evaluate_model(model, tokenizer, device)
out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
# Write out the results to a csv file
core_metric = None

View File

@ -7,9 +7,10 @@ Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
"""
import os
from contextlib import nullcontext
import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.tokenizer import get_token_bytes
from nanochat.loss_eval import evaluate_bpb
@ -20,15 +21,15 @@ device_batch_size = 32
split_tokens = 20*524288 # number of tokens to evaluate per split
model_tag = None # optional model tag for the output directory name
model_step = None # optional model step for the output directory name
device_type = "" # cuda|cpu|mps (empty => autodetect)
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
# Load the base model and the tokenizer
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
# Set up the precision we'll run with
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
@ -37,7 +38,7 @@ steps = split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
print0(f"{split_name} bpb: {bpb:.4f}")

View File

@ -6,17 +6,22 @@ python base_train.py
or distributed as:
torchrun --nproc_per_node=8 base_train.py
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
from contextlib import nullcontext
import wandb
import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@ -27,6 +32,8 @@ print_banner()
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
@ -45,7 +52,7 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
# Output
@ -57,9 +64,12 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
# -----------------------------------------------------------------------------
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@ -96,7 +106,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device="cuda")
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
@ -133,8 +143,8 @@ adamw_optimizer, muon_optimizer = optimizers
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
@ -193,7 +203,8 @@ for step in range(num_iterations + 1):
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
if last_step or (step > 0 and step % core_metric_every == 0):
results = {}
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
model.eval()
with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
@ -252,7 +263,7 @@ for step in range(num_iterations + 1):
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
torch.cuda.synchronize()
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
@ -275,7 +286,7 @@ for step in range(num_iterations + 1):
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
torch.cuda.synchronize()
synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
@ -304,7 +315,7 @@ for step in range(num_iterations + 1):
})
# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
@ -326,11 +337,11 @@ get_report().log(section="Base model training", data=[
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
"Final validation bpb": val_bpb,
"CORE metric estimate": results["core_metric"],
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
}
])

View File

@ -6,7 +6,8 @@ python -m scripts.chat_cli -i mid
"""
import argparse
import torch
from nanochat.common import compute_init
from nanochat.common import compute_init, autodetect_device_type
from contextlib import nullcontext
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
@ -17,11 +18,16 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
args = parser.parse_args()
# Init the model and tokenizer
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
# Special tokens for the chat state machine

View File

@ -10,11 +10,12 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
import argparse
from functools import partial
from contextlib import nullcontext
import torch
import torch.distributed as dist
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@ -191,11 +192,13 @@ if __name__ == "__main__":
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
args = parser.parse_args()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer)

View File

@ -15,8 +15,9 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.engine import Engine
@ -36,11 +37,12 @@ source = "mid" # base|mid , which checkpoint to load the model from (base model
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
# compute/precision
device_type = "" # cuda|cpu|mps (empty => autodetect)
dtype = "bfloat16"
device_batch_size = 4 # max to avoid OOM
# optimization
num_epochs = 1
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
target_examples_per_step = 32
unembedding_lr = 0.004
embedding_lr = 0.2
@ -51,6 +53,7 @@ init_lr_frac = 0.02
eval_every = 100
eval_steps = 100
eval_metrics_every = 200
eval_metrics_max_problems = 1024
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
@ -58,10 +61,11 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
# -----------------------------------------------------------------------------
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@ -128,10 +132,10 @@ assert target_examples_per_step % examples_per_step == 0, "Target examples per s
grad_accum_steps = target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
if max_iterations >= 0 and num_iterations > max_iterations:
print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
num_iterations = max_iterations
if num_iterations == -1:
# derive num_iterations from num_epochs and the size of the dataset
assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
@ -191,8 +195,8 @@ for step in range(num_iterations):
metrics = {}
with torch.no_grad(), autocast_ctx:
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log({

View File

@ -44,8 +44,8 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
from nanochat.common import compute_init
from contextlib import nullcontext
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@ -69,6 +69,8 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
@ -80,7 +82,9 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
@dataclass
class Worker:
@ -95,21 +99,33 @@ class WorkerPool:
"""Pool of workers, each with a model replica on a different GPU."""
def __init__(self, num_gpus: Optional[int] = None):
self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
if num_gpus is None:
if device_type == "cuda":
num_gpus = torch.cuda.device_count()
else:
num_gpus = 1 # e.g. cpu|mps
self.num_gpus = num_gpus
self.workers: List[Worker] = []
self.available_workers: asyncio.Queue = asyncio.Queue()
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
"""Load model on each GPU."""
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
if self.num_gpus > 1:
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
for gpu_id in range(self.num_gpus):
if device_type == "cuda":
device = torch.device(f"cuda:{gpu_id}")
print(f"Loading model on GPU {gpu_id}...")
else:
device = torch.device(device_type) # e.g. cpu|mps
print(f"Loading model on {device_type}...")
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer)
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
worker = Worker(
gpu_id=gpu_id,

View File

@ -15,8 +15,8 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@ -31,9 +31,11 @@ from tasks.customjson import CustomJSON
# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
max_seq_len = 2048
device_batch_size = 32
unembedding_lr = 0.004
@ -41,7 +43,7 @@ embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0
eval_every = 150
eval_every = 150 # -1 = disable
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
@ -51,10 +53,12 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
# -----------------------------------------------------------------------------
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@ -117,6 +121,7 @@ def mid_data_generator(split):
token_buffer = deque()
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
it = 0 # iteration counter
while True:
# Accumulate enough tokens for one iteration before yielding
while len(token_buffer) < needed_tokens:
@ -128,6 +133,10 @@ def mid_data_generator(split):
cursor -= dataset_size # wrap around for another epoch
if split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop
# Stopping condition to respect num_iterations, if given
it += 1
if num_iterations > 0 and it >= num_iterations:
last_step = True # toggle last_step to True, which will terminate the training loop
# Build up inputs/targets and yield
for i in range(needed_tokens):
scratch[i] = token_buffer.popleft()
@ -136,6 +145,9 @@ def mid_data_generator(split):
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
if split == "train":
if num_iterations > 0:
approx_progress = it / num_iterations # calculate progress from the max number of iterations
else:
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
yield inputs, targets
@ -172,7 +184,7 @@ while True:
last_step = bool(last_step_tensor.item())
# once in a while: evaluate the val bpb (all ranks participate)
if last_step or step % eval_every == 0:
if eval_every > 0 and (last_step or step % eval_every == 0):
model.eval()
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
@ -219,7 +231,7 @@ while True:
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
torch.cuda.synchronize()
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
@ -240,7 +252,7 @@ while True:
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
torch.cuda.synchronize()
synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
@ -272,7 +284,7 @@ while True:
})
# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

View File

@ -2002,3 +2002,4 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
{ url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
]