From 70319851fc960bc472ac7cfe9518c9478ada402e Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 29 Oct 2025 19:48:34 +0100 Subject: [PATCH 01/12] fix typo --- scripts/base_eval.py | 2 +- scripts/chat_sft.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/base_eval.py b/scripts/base_eval.py index 8efde4f..3d403cc 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -1,5 +1,5 @@ """ -Evlauate the CORE metric for a given model. +Evaluate the CORE metric for a given model. Run on a single GPU: python base_eval.py diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index e6e4565..bbeb1f9 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -192,7 +192,7 @@ for step in range(num_iterations): }) model.train() - # evlauate accuracy of the multiple choice tasks (which are quick to run) + # evaluate accuracy of the multiple choice tasks (which are quick to run) if last_step or (step > 0 and step % eval_metrics_every == 0): model.eval() metrics = {} From b399e431681d61dcced768c062b13a9089c0c21c Mon Sep 17 00:00:00 2001 From: "howardgao@outlook.com" Date: Thu, 6 Nov 2025 08:56:45 +0800 Subject: [PATCH 02/12] fix engine test bug --- nanochat/engine.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/nanochat/engine.py b/nanochat/engine.py index 916a9cf..da85085 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -17,8 +17,9 @@ import signal import warnings from contextlib import contextmanager from collections import deque -from nanochat.common import compute_init +from nanochat.common import compute_init, autodetect_device_type from nanochat.checkpoint_manager import load_model +from contextlib import nullcontext # ----------------------------------------------------------------------------- # Calculator tool helpers @@ -327,8 +328,11 @@ if __name__ == "__main__": import time # init compute ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() + device_type = autodetect_device_type() + autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() + # load the model and tokenizer - model, tokenizer, meta = load_model("base", device, phase="eval") + model, tokenizer, meta = load_model("sft", device, phase="eval") bos_token_id = tokenizer.get_bos_token_id() # common hyperparameters kwargs = dict(max_tokens=64, temperature=0.0) @@ -339,10 +343,11 @@ if __name__ == "__main__": torch.cuda.synchronize() t0 = time.time() stream = model.generate(prompt_tokens, **kwargs) - for token in stream: - generated_tokens.append(token) - chunk = tokenizer.decode([token]) - print(chunk, end="", flush=True) + with autocast_ctx: + for token in stream: + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) print() torch.cuda.synchronize() t1 = time.time() @@ -354,11 +359,12 @@ if __name__ == "__main__": stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 torch.cuda.synchronize() t0 = time.time() - for token_column, token_masks in stream: - token = token_column[0] # only print out the first row - generated_tokens.append(token) - chunk = tokenizer.decode([token]) - print(chunk, end="", flush=True) + with autocast_ctx: + for token_column, token_masks in stream: + token = token_column[0] # only print out the first row + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) print() torch.cuda.synchronize() t1 = time.time() From adb5d4a16c0a8dd9d50e05176a2cac08931562bc Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 13 Nov 2025 15:16:27 +0000 Subject: [PATCH 03/12] uv lock has to change when we removed numpy the other commit --- uv.lock | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/uv.lock b/uv.lock index f01bba3..4e9b0bd 100644 --- a/uv.lock +++ b/uv.lock @@ -311,7 +311,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -777,7 +777,6 @@ dependencies = [ { name = "datasets" }, { name = "fastapi" }, { name = "files-to-prompt" }, - { name = "numpy" }, { name = "psutil" }, { name = "regex" }, { name = "setuptools" }, @@ -811,7 +810,6 @@ requires-dist = [ { name = "datasets", specifier = ">=4.0.0" }, { name = "fastapi", specifier = ">=0.117.1" }, { name = "files-to-prompt", specifier = ">=0.6" }, - { name = "numpy", specifier = "==1.26.4" }, { name = "psutil", specifier = ">=7.1.0" }, { name = "regex", specifier = ">=2025.9.1" }, { name = "setuptools", specifier = ">=80.9.0" }, @@ -951,7 +949,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" }, @@ -964,7 +962,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, @@ -996,9 +994,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "nvidia-cusparse-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, @@ -1011,7 +1009,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, @@ -1955,7 +1953,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "setuptools", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" }, From 91f09ccd0d48daf89eee6ef7fcec05977fd87068 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 13 Nov 2025 15:28:18 +0000 Subject: [PATCH 04/12] minor fix comment in engine --- nanochat/engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nanochat/engine.py b/nanochat/engine.py index 916a9cf..1d541c7 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -107,8 +107,9 @@ class KVCache: assert self.kv_cache is None, "Cannot prefill a non-empty KV cache" assert other.kv_cache is not None, "Cannot prefill with a None KV cache" for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)): + # ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim if ix in [0, 1, 3, 5]: - # num_layers, batch_size, num_heads, head_dim must match + # num_layers, k/v, num_heads, head_dim must match assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}" elif ix == 2: # batch_size can be expanded From c6abcdfe3a23f3cc3656e4132a606e8753415fca Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 13 Nov 2025 15:34:40 +0000 Subject: [PATCH 05/12] big change: add pretraining resumption logic so that checkpoints can now be approximately resumed and training can continue. this is useful for very long runs when you don't want the anxiety of your run crashing for some reason. alternatively, it's a way to recover training in the event of loss spikes. i mean, this should have been there in v0 but it's ok. the resumption is approximate to control complexity and bloat, but it's possible we want to change that in the future. to use, set --save_every to a step interval to write checkpoints with, and then use --resume_from_step to resume optimization from a given step. only base model training (pretraining) supports this atm, but it's ok because midtraining is comparably quite a bit faster. --- nanochat/checkpoint_manager.py | 35 +++++++------ nanochat/common.py | 2 + nanochat/dataloader.py | 82 ++++++++++++++++++++++--------- scripts/base_train.py | 89 +++++++++++++++++++++++++--------- 4 files changed, 145 insertions(+), 63 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index b7d2191..63f257f 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -20,33 +20,32 @@ def log0(message): if int(os.environ.get('RANK', 0)) == 0: logger.info(message) -def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data): - assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now - os.makedirs(checkpoint_dir, exist_ok=True) - # Save the model state (parameters) - model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") - torch.save(model_data, model_path) - log0(f"Saved model file to: {model_path}") - # Save the optimizer state (useful for SFT or any other fine-tuning) +def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0): + if rank == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + # Save the model state parameters + model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") + torch.save(model_data, model_path) + logger.info(f"Saved model parameters to: {model_path}") + # Save the metadata dict as json + meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") + with open(meta_path, "w", encoding="utf-8") as f: + json.dump(meta_data, f, indent=2) + logger.info(f"Saved metadata to: {meta_path}") + # Note that optimizer state is sharded across ranks, so each rank must save its own. if optimizer_data is not None: - optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") torch.save(optimizer_data, optimizer_path) - log0(f"Saved optimizer file to: {optimizer_path}") - # Save the metadata dict as json - meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") - with open(meta_path, "w", encoding="utf-8") as f: - json.dump(meta_data, f, indent=2) - log0(f"Saved metadata file to: {meta_path}") + logger.info(f"Saved optimizer state to: {optimizer_path}") - -def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False): +def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): # Load the model state model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") model_data = torch.load(model_path, map_location=device) # Load the optimizer state if requested optimizer_data = None if load_optimizer: - optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") optimizer_data = torch.load(optimizer_path, map_location=device) # Load the metadata meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") diff --git a/nanochat/common.py b/nanochat/common.py index d4a9828..8f36f94 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -148,6 +148,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'" # Reproducibility + # Note that we set the global seeds here, but most of the code uses explicit rng objects. + # The only place where global rng might be used is nn.Module initialization of the model weights. torch.manual_seed(42) if device_type == "cuda": torch.cuda.manual_seed(42) diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 6c864d3..3271298 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -1,49 +1,87 @@ from collections import deque import torch +import pyarrow.parquet as pq from nanochat.common import get_dist_info -from nanochat.dataset import parquets_iter_batched +from nanochat.dataset import list_parquet_files from nanochat.tokenizer import get_tokenizer -def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"): - """Stream pretraining text from parquet files, tokenize, yield training batches.""" +def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None): + """ + Stream pretraining text from parquet files, tokenize, yield training batches. + + This implementation became a bit more complex because we wish to support approximate resume training. + Instead of turning this into a Class, we opt to return the state_dict with every batch, + and then the caller can pass in a state_dict to resume training from a desired point. + Note that this resumption is atm only *approximate* for simplicity. + We won't repeat the same documents but we might skip a few. + The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume. + + Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm. + """ assert split in ["train", "val"], "split must be 'train' or 'val'" + + # infinite iterator over document batches (list of text strings) ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() + def document_batches(): + parquet_paths = list_parquet_files() + parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] + resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0 + resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None + pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0) + while True: # iterate infinitely (multi-epoch) + while pq_idx < len(parquet_paths): # iterate over all parquet files + filepath = parquet_paths[pq_idx] + pf = pq.ParquetFile(filepath) + # Start from resume point if resuming on same file, otherwise from DDP rank + # I know this state resumption is a little bit tricky and a little bit hacky... sigh. + if resume_rg_idx is not None: + base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size + base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming + rg_idx = base_idx * ddp_world_size + ddp_rank + resume_rg_idx = None # set to None as we only want to do this a single time + else: + rg_idx = ddp_rank + while rg_idx < pf.num_row_groups: + rg = pf.read_row_group(rg_idx) + batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows + # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows + for i in range(0, len(batch), tokenizer_batch_size): + yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx) + rg_idx += ddp_world_size # advance to the next row group (in DDP) + pq_idx += 1 # advance to the next parquet file + batches = document_batches() + + # Now emit batches of tokens. needed_tokens = B * T + 1 # +1 is because we also need the target at the last token # get the tokenizer and the bos token tokenizer = get_tokenizer() bos_token = tokenizer.get_bos_token_id() # scratch buffer holds the tokens for one iteration token_buffer = deque() # we stream tokens on the right and pop from the left - - # infinite iterator over document batches - def document_batches(): - while True: - # batch will iterate in group size of the parquet files, usually e.g. 1024 rows - for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size): - # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows - for i in range(0, len(batch), tokenizer_batch_size): - yield batch[i:i+tokenizer_batch_size] - batches = document_batches() - - batch_index = 0 while True: # Accumulate enough tokens for one iteration before yielding. while len(token_buffer) < needed_tokens: - doc_batch = next(batches) + doc_batch, (pq_idx, rg_idx) = next(batches) token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) for tokens in token_lists: token_buffer.extend(tokens) - batch_index += 1 # Move tokens from the deque into the scratch buffer tokens = [token_buffer.popleft() for _ in range(needed_tokens)] - # CUDA supports memory pinning for faster transfers between CPU and GPU: - scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda")) + # CUDA supports memory pinning for asynchronous transfers between CPU and GPU + use_cuda_optimizations = device == "cuda" + scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64 # Create the inputs/targets as 1D tensors - inputs_cpu = scratch[:-1].to(dtype=torch.int32) + inputs_cpu = scratch[:-1] targets_cpu = scratch[1:] # Reshape to 2D and move to GPU async - inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True) - targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True) + inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) + targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) + state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training + yield inputs, targets, state_dict + +def tokenizing_distributed_data_loader(*args, **kwargs): + # helper function that only emits the inputs/targets and not the state_dict + for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs): yield inputs, targets diff --git a/scripts/base_train.py b/scripts/base_train.py index 594c709..c9ea6c9 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -20,10 +20,10 @@ import wandb import torch from nanochat.gpt import GPT, GPTConfig -from nanochat.dataloader import tokenizing_distributed_data_loader +from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type from nanochat.tokenizer import get_tokenizer, get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint +from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine from scripts.base_eval import evaluate_model @@ -52,12 +52,14 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled) warmup_ratio = 0.0 # ratio of iterations for LR warmup warmdown_ratio = 0.2 # ratio of iterations for LR warmdown final_lr_frac = 0.0 # final LR is this fraction of the initial LR +resume_from_step = -1 # resume training from this step of the optimization (-1 = disable) # Evaluation eval_every = 250 # every how many steps to evaluate the model for val bpb eval_tokens = 20*524288 # number of tokens to evaluate val loss on core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable) core_metric_max_per_task = 500 # examples per task in estimating the core metric sample_every = 2000 # every how many steps to sample from the model +save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run) # Output model_tag = "" # optionally override the model tag for the output checkpoint directory name # now allow CLI to override the settings via the configurator lol @@ -103,16 +105,31 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") + # ----------------------------------------------------------------------------- # Initialize the Model + +# Create a new model with random weights model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim) with torch.device("meta"): model_config = GPTConfig(**model_config_kwargs) model = GPT(model_config) model.to_empty(device=device) model.init_weights() -orig_model = model # original, uncompiled model, for saving raw model state_dict -model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through + +# If we are resuming, overwrite the model parameters with those of the checkpoint +base_dir = get_base_dir() +output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12 +checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname) +resuming = resume_from_step != -1 +if resuming: + print0(f"Resuming optimization from step {resume_from_step}") + model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank) + model.load_state_dict(model_data, strict=True, assign=True) + del model_data # free up this memory after the copy + +orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) +model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe num_params = sum(p.numel() for p in model.parameters()) print0(f"Number of parameters: {num_params:,}") num_flops_per_token = model.estimate_flops() @@ -143,12 +160,18 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) adamw_optimizer, muon_optimizer = optimizers +if resuming: + for opt, dat in zip(optimizers, optimizer_data): + opt.load_state_dict(dat) + del optimizer_data # free up the memory + +# ----------------------------------------------------------------------------- # Initialize the DataLoaders for train/val -base_dir = get_base_dir() tokens_dir = os.path.join(base_dir, "tokenized_data") -train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device) +dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] +train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device) -x, y = next(train_loader) # kick off load of the very first batch of data +x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data # ----------------------------------------------------------------------------- # Set up hyperparameter schedulers @@ -171,15 +194,25 @@ def get_muon_momentum(it): momentum = (1 - frac) * 0.85 + frac * 0.95 return momentum +# ----------------------------------------------------------------------------- +# Loop state (variables updated by the training loop) + +if not resuming: + step = 0 + min_val_bpb = float("inf") + smooth_train_loss = 0 # EMA of training loss + total_training_time = 0 # total wall-clock time of training +else: + step = meta_data["step"] + loop_state = meta_data["loop_state"] + min_val_bpb = loop_state["min_val_bpb"] + smooth_train_loss = loop_state["smooth_train_loss"] + total_training_time = loop_state["total_training_time"] + # ----------------------------------------------------------------------------- # Training loop -min_val_bpb = float("inf") -smooth_train_loss = 0 # EMA of training loss -ema_beta = 0.9 # EMA decay factor -total_training_time = 0 # total wall-clock time of training -# note that we run +1 steps only so that we can eval and save at the end -for step in range(num_iterations + 1): - last_step = step == num_iterations +while True: + last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end flops_so_far = num_flops_per_token * total_batch_size * step # once in a while: evaluate the val bpb (all ranks participate) @@ -237,25 +270,31 @@ for step in range(num_iterations + 1): print0(tokenizer.decode(sample[0])) model.train() - # save checkpoint at the end of the run (only on master process) - if master_process and last_step: - output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12 - checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname) + # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step + if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0): save_checkpoint( checkpoint_dir, step, - orig_model.state_dict(), - [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly - { + orig_model.state_dict(), # model parameters + [opt.state_dict() for opt in optimizers], # optimizer states + { # metadata saved as json "step": step, "val_bpb": val_bpb, # loss at last step "model_config": model_config_kwargs, "user_config": user_config, # inputs to the training script "device_batch_size": device_batch_size, "max_seq_len": max_seq_len, - } + "dataloader_state_dict": dataloader_state_dict, + "loop_state": { # all loop state (other than step) so that we can resume training + "min_val_bpb": min_val_bpb, + "smooth_train_loss": smooth_train_loss, + "total_training_time": total_training_time, + }, + }, + rank=ddp_rank, ) + # termination conditions (TODO: possibly also add loss explosions etc.) if last_step: break @@ -270,7 +309,7 @@ for step in range(num_iterations + 1): train_loss = loss.detach() # for logging loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss.backward() - x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward + x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward # gradient clipping grad_clip_enabled = grad_clip > 0.0 if grad_clip_enabled: @@ -293,6 +332,7 @@ for step in range(num_iterations + 1): # ------------------------------------------------------------------------- # logging + ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA pct_done = 100 * step / num_iterations @@ -319,6 +359,9 @@ for step in range(num_iterations + 1): log_data["train/grad_norm"] = grad_norm wandb_run.log(log_data) + # state update + step += 1 + # print a few more stats print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") From 7b7fd0fe71cf496304d0b8d4e3571c2fc412356b Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 13 Nov 2025 16:07:54 +0000 Subject: [PATCH 06/12] thank you Sophie for your help with nanochat --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 18ea5ce..faee896 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,7 @@ Current LLM policy: disclosure. When submitting a PR, please declare any parts t - Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk. - Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project. - Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance. +- Thank you to the repo czar Sophie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat. ## Cite From 9a71d1368899b7bfbb8e1fad966b683ec80a5760 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 13 Nov 2025 16:08:30 +0000 Subject: [PATCH 07/12] typo oops --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index faee896..c96ac23 100644 --- a/README.md +++ b/README.md @@ -201,7 +201,7 @@ Current LLM policy: disclosure. When submitting a PR, please declare any parts t - Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk. - Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project. - Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance. -- Thank you to the repo czar Sophie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat. +- Thank you to the repo czar Sofie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat. ## Cite From e5efb4b471cd708a5aa816462e8fce78cb2b4431 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Fri, 14 Nov 2025 11:13:42 +0100 Subject: [PATCH 08/12] add test_engine.py to file structure --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 18ea5ce..4b50d69 100644 --- a/README.md +++ b/README.md @@ -184,6 +184,7 @@ python -m pytest tests/test_rustbpe.py -v -s │ ├── smoltalk.py # Conglomerate dataset of SmolTalk from HF │ └── spellingbee.py # Task teaching model to spell/count letters ├── tests +│ └── test_engine.py │ └── test_rustbpe.py └── uv.lock ``` From a2fb3c83a66dd4199e7aa0fcaddda28e3fe85bbf Mon Sep 17 00:00:00 2001 From: svlandeg Date: Fri, 14 Nov 2025 11:20:25 +0100 Subject: [PATCH 09/12] fix typos --- nanochat/loss_eval.py | 4 ++-- scripts/chat_eval.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index 6fcbea3..5a556e6 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -9,9 +9,9 @@ import torch.distributed as dist def evaluate_bpb(model, batches, steps, token_bytes): """ Instead of the naive 'mean loss', this function returns the bits per byte (bpb), - which is a tokenization vocab size-indepedent metric, meaning you are still comparing + which is a tokenization vocab size-independent metric, meaning you are still comparing apples:apples if you change the vocab size. The way this works is that instead of just - calculating the average loss as usual, you calculate the sum loss, and indepependently + calculating the average loss as usual, you calculate the sum loss, and independently also the sum bytes (of all the target tokens), and divide. This normalizes the loss by the number of bytes that the target tokens represent. diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index 616411d..cae2f0f 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -1,6 +1,6 @@ """ Evaluate the Chat model. -All the generic code lives here, and all the evlauation-specific +All the generic code lives here, and all the evaluation-specific code lives in nanochat directory and is imported from here. Example runs: From c6f5bd67db78982f02d19d86005524819aa410fc Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Fri, 14 Nov 2025 12:20:03 +0100 Subject: [PATCH 10/12] revert change of base to sft for quick inline test --- nanochat/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/engine.py b/nanochat/engine.py index da85085..295d889 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -332,7 +332,7 @@ if __name__ == "__main__": autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() # load the model and tokenizer - model, tokenizer, meta = load_model("sft", device, phase="eval") + model, tokenizer, meta = load_model("base", device, phase="eval") bos_token_id = tokenizer.get_bos_token_id() # common hyperparameters kwargs = dict(max_tokens=64, temperature=0.0) From bc1fca39f33074fec4319ef46d96e09b8024c824 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 15 Nov 2025 15:43:37 +0000 Subject: [PATCH 11/12] mqa -> gqa to reduce confusion --- nanochat/gpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index b640f1e..8b220c3 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -8,7 +8,7 @@ Notable features: - norm after token embedding - no learnable params in rmsnorm - no bias in linear layers -- Multi-Query Attention (MQA) support for more efficient inference +- Group-Query Attention (GQA) support for more efficient inference """ import math @@ -29,7 +29,7 @@ class GPTConfig: vocab_size: int = 50304 n_layer: int = 12 n_head: int = 6 # number of query heads - n_kv_head: int = 6 # number of key/value heads (MQA) + n_kv_head: int = 6 # number of key/value heads (GQA) n_embd: int = 768 From 11e68bf4427aef8748a8c0c3978c9c03838a9466 Mon Sep 17 00:00:00 2001 From: Sam Abrahams Date: Mon, 17 Nov 2025 11:32:56 -0500 Subject: [PATCH 12/12] Fix comment: rotary embeddings final dimension size --- nanochat/gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 8b220c3..216343c 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -244,7 +244,7 @@ class GPT(nn.Module): def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): B, T = idx.size() - # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim)) + # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"