Merge branch 'karpathy:master' into master

Hossein-Lakzaei 2025-12-28 20:09:01 +03:30 committed by GitHub
commit 7cdb05856d
11 changed files with 271 additions and 47 deletions

View File

@ -113,12 +113,24 @@ def print_banner():
"""
print0(banner)
def is_ddp():
# TODO is there a proper way
return int(os.environ.get('RANK', -1)) != -1
def is_ddp_requested() -> bool:
"""
True if launched by torchrun (env present), even before init.
Used to decide whether we *should* initialize a PG.
"""
return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
def is_ddp_initialized() -> bool:
"""
True if torch.distributed is available and the process group is initialized.
Used at cleanup to avoid destroying a non-existent PG.
"""
return dist.is_available() and dist.is_initialized()
def get_dist_info():
if is_ddp():
if is_ddp_requested():
# We rely on torchrun's env to decide if we SHOULD init.
# (Initialization itself happens in compute init.)
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
@ -158,11 +170,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
# Precision
if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp and device_type == "cuda":
is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if is_ddp_requested and device_type == "cuda":
device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device)
@ -173,11 +185,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
if ddp_rank == 0:
logger.info(f"Distributed world size: {ddp_world_size}")
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
def compute_cleanup():
"""Companion function to compute_init, to clean things up before script exit"""
if is_ddp():
if is_ddp_initialized():
dist.destroy_process_group()
class DummyWandb:
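For context, a minimal sketch of how the two new predicates are meant to compose, assuming they are importable from nanochat.common (the training-loop body is a placeholder, not part of this commit):

import os
import torch
import torch.distributed as dist
from nanochat.common import is_ddp_requested, is_ddp_initialized

def main():
    # Initialize a process group only when torchrun actually set the env vars.
    if is_ddp_requested():
        device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
        torch.cuda.set_device(device)
        dist.init_process_group(backend="nccl", device_id=device)
    # ... training / eval work goes here (placeholder) ...
    # Tear down only if a process group was actually created.
    if is_ddp_initialized():
        dist.destroy_process_group()

if __name__ == "__main__":
    main()

This mirrors what compute_init/compute_cleanup do: the "should we init?" question is answered from the torchrun environment, while the "is there anything to destroy?" question is answered from torch.distributed itself.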

View File

@ -19,7 +19,7 @@ from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from contextlib import nullcontext
from contextlib import nullcontext
# -----------------------------------------------------------------------------
# Calculator tool helpers
@ -107,23 +107,23 @@ class KVCache:
# 1) validate the shapes
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
# Extract dimensions explicitly
self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
# Validate dimensions
assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
# Batch size can be expanded (other can be 1, self can be larger)
assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
# Sequence length: self must be longer than other
assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
# 2) initialize the cache
dtype, device = other.kv_cache.dtype, other.kv_cache.device
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
@ -223,9 +223,7 @@ class Engine:
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :]
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
@ -242,7 +240,6 @@ class Engine:
# 4) Main generation loop
num_generated = 0
first_iteration = True
while True:
# Stop condition: we've reached max tokens
if max_tokens is not None and num_generated >= max_tokens:
@ -251,18 +248,9 @@ class Engine:
if all(state.completed for state in row_states):
break
# Get sampled tokens - either from prefill or from forward pass
if first_iteration:
# Use the tokens we already sampled from prefill
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
# TODO: we should sample a token for each row instead of broadcasting
first_iteration = False
else:
# Forward the model and get the next token for each row
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size) at last time step
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# Sample the next token for each row
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# Process each row: choose the next token, update state, optional tool use
token_column = [] # contains the next token id along each row
@ -299,8 +287,10 @@ class Engine:
# Yield the token column
yield token_column, token_masks
num_generated += 1
# Prepare ids for next iteration
# Prepare logits for next iteration
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)
def generate_batch(self, tokens, num_samples=1, **kwargs):
"""

View File

@ -41,12 +41,10 @@ def norm(x):
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
y2 = x1 * (-sin) + x2 * cos
out = torch.cat([y1, y2], 3) # re-assemble
out = out.to(x.dtype) # ensure input/output dtypes match
return out
return torch.cat([y1, y2], 3)
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
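Since the dtype cast and the intermediate variable are dropped above, here is a self-contained numerical check of the rotation itself; the shapes and the cos/sin construction are illustrative, not the model's actual tensor layout:

import math
import torch

B, H, T, D = 2, 3, 5, 8                # illustrative; only ndim == 4 matters
d = D // 2
x = torch.randn(B, H, T, D)
theta = torch.rand(T, d) * 2 * math.pi
cos, sin = torch.cos(theta), torch.sin(theta)   # broadcast over batch and heads

x1, x2 = x[..., :d], x[..., d:]                 # split last dim into two halves
y = torch.cat([x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos], dim=-1)

# Each (x1_i, x2_i) pair is rotated by theta_i, so vector norms are preserved.
assert torch.allclose(y.norm(dim=-1), x.norm(dim=-1), atol=1e-5)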

View File

@ -465,6 +465,22 @@ impl Tokenizer {
all_ids
}
/// Encode multiple texts in parallel using rayon.
/// Returns a list of token ID vectors, one per input text.
#[pyo3(signature = (texts))]
#[pyo3(text_signature = "(self, texts)")]
pub fn batch_encode(&self, py: Python<'_>, texts: Vec<String>) -> PyResult<Vec<Vec<u32>>> {
// Release Python GIL and encode in parallel using rayon
let results = py.allow_threads(|| {
texts
.par_iter()
.map(|text| self.encode(text))
.collect::<Vec<Vec<u32>>>()
});
Ok(results)
}
}
#[pymodule]
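From Python the new binding is exercised the same way as in the tests further down; a minimal sketch (the training text and the small vocab size are just for the demo):

import rustbpe

# Train a tiny tokenizer; any reasonably varied text works for a quick demo.
text = ("The quick brown fox jumps over the lazy dog. "
        "Pack my box with five dozen liquor jugs. ") * 50
tok = rustbpe.Tokenizer()
tok.train_from_iterator([text], 300)

texts = ["Hello world", "The quick brown fox", ""]
# batch_encode releases the GIL and encodes the inputs in parallel via rayon;
# the result should be identical to a sequential encode() loop.
assert tok.batch_encode(texts) == [tok.encode(t) for t in texts]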

View File

@ -149,6 +149,8 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
parser.add_argument('--model-tag', type=str, default=None, help='optional model tag for the output directory name')
parser.add_argument('--step', type=str, default=None, help='optional model step for the output directory name')
args = parser.parse_args()
# distributed / precision setup
@ -166,7 +168,7 @@ def main():
model_slug = hf_path.replace("/", "-") # for the output csv file
else:
# load a local model from the file system
model, tokenizer, meta = load_model("base", device, phase="eval")
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step)
model_name = f"base_model (step {meta['step']})" # just for logging
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
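With the two new flags, a specific checkpoint can be evaluated instead of the default one. A hypothetical invocation, assuming this hunk lives in scripts/base_eval.py (the tag and step values are placeholders for whatever exists in your checkpoint directory):

python -m scripts.base_eval --model-tag d20 --step 21400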

View File

@ -12,7 +12,7 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 -
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
from contextlib import nullcontext

View File

@ -31,6 +31,8 @@ from tasks.gsm8k import GSM8K
# RL hyperparameters
run = "dummy" # wandb run name
source = "sft" # mid|sft
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
device_batch_size = 8 # no forward pass will go above this to not OOM
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
@ -64,7 +66,7 @@ use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
# Init model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="eval")
model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer) # for sampling rollouts
# -----------------------------------------------------------------------------
@ -307,8 +309,8 @@ for step in range(num_steps):
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
base_dir = get_base_dir()
depth = model.config.n_layer
model_tag = f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(
checkpoint_dir,

View File

@ -10,7 +10,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import wandb
import torch
@ -250,8 +250,8 @@ for step in range(num_iterations):
if master_process:
base_dir = get_base_dir()
depth = model.config.n_layer
model_tag = f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(
checkpoint_dir,

View File

@ -11,7 +11,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_
from collections import deque
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
@ -207,7 +207,7 @@ while True:
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not dry_run:
output_dirname = f"d{depth}" # e.g. d12
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,

View File

@ -5,7 +5,85 @@ python -m pytest tests/test_engine.py -v
"""
import torch
from nanochat.engine import KVCache
from nanochat.engine import KVCache, Engine
from dataclasses import dataclass
# -----------------------------------------------------------------------------
# Mock classes for testing Engine without loading a real model
@dataclass
class MockConfig:
"""Minimal config for Engine tests."""
n_kv_head: int = 4
n_head: int = 4
n_embd: int = 64
n_layer: int = 2
sequence_len: int = 128
class MockModel:
"""
Mock model that returns uniform logits over the vocab.
This ensures that with temperature > 0, different samples should
(with very high probability) produce different tokens.
"""
def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens
self.vocab_size = vocab_size
self.config = MockConfig()
self._device = "cpu"
def get_device(self):
return self._device
def forward(self, ids, kv_cache=None):
"""Return uniform logits so sampling is spread across vocab."""
B, T = ids.shape
# Simulate what a real transformer does: insert k,v into the cache for each layer
if kv_cache is not None:
head_dim = self.config.n_embd // self.config.n_head
for layer_idx in range(self.config.n_layer):
k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
kv_cache.insert_kv(layer_idx, k, v)
# Uniform logits -> equal probability for all tokens
logits = torch.zeros(B, T, self.vocab_size)
return logits
class ByteTokenizer:
"""
Simple byte-level tokenizer for testing.
Tokens 0-255 are raw bytes, 256+ are special tokens.
"""
def __init__(self):
# Special tokens start at 256
self._special_tokens = {
"<|python_start|>": 256,
"<|python_end|>": 257,
"<|output_start|>": 258,
"<|output_end|>": 259,
"<|assistant_end|>": 260,
"<|bos|>": 261,
}
self._bos = 261
def encode_special(self, s):
return self._special_tokens[s]
def get_bos_token_id(self):
return self._bos
def encode(self, s, prepend=None):
tokens = list(s.encode("utf-8")) # bytes 0-255
if prepend is not None:
tokens = [prepend] + tokens
return tokens
def decode(self, tokens):
# Filter out special tokens before decoding
byte_tokens = [t for t in tokens if t < 256]
return bytes(byte_tokens).decode("utf-8", errors="replace")
def test_kv_cache_resize():
"""
@ -64,3 +142,46 @@ def test_kv_cache_resize():
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
def test_multi_sample_first_token_diversity():
"""
Test that when generating multiple samples, each sample gets an independently
sampled first token (not a broadcast of the same token to all rows).
Previously, the first token after prefill was sampled once and broadcast to all
rows, causing all samples to start identically. The fix expands the prefill logits
to num_samples and samples independently for each row.
With uniform logits over 262 tokens and 16 samples, the probability that all
samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're
all identical, it indicates tokens are being broadcast instead of independently sampled.
"""
model = MockModel(vocab_size=262)
tokenizer = ByteTokenizer()
engine = Engine(model, tokenizer)
# Generate 16 samples with temperature=1.0 (stochastic sampling)
prompt_tokens = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
num_samples = 16
# Collect the first generated token from each sample
first_tokens = []
gen = engine.generate(
prompt_tokens,
num_samples=num_samples,
max_tokens=1, # We only need the first token
temperature=1.0,
seed=42,
)
for token_column, token_masks in gen:
first_tokens = token_column # This is the first (and only) yield
# With uniform distribution and 16 samples, they should NOT all be identical
# If they are all identical, the bug exists (broadcasting instead of sampling)
unique_tokens = set(first_tokens)
assert len(unique_tokens) > 1, (
f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
f"With uniform logits, this is statistically impossible (~10^-36 probability) "
f"unless tokens are being broadcast instead of independently sampled."
)
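As a quick check of the figure in the docstring: with 262 equally likely tokens, P(all 16 independent draws are identical) = 262 * (1/262)^16 = (1/262)^15 ≈ 5 × 10^-37, which is indeed the ~10^-36 ballpark quoted above.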

View File

@ -21,6 +21,7 @@ python -m pytest tests/test_rustbpe.py -v -s
import regex as re
from collections import Counter, defaultdict
import time
import warnings
import rustbpe
import tiktoken
import pytest
@ -633,3 +634,85 @@ def test_interface(enwik8_small):
ids_reloaded = tok_reloaded.encode(encode_text)
assert ids_reloaded == ids, "Reloaded tokenizer should produce same results"
print("✅ Save/load through temporary directory OK")
def test_batch_encode_correctness(enwik8_small):
"""Quick correctness test for batch_encode()"""
text = enwik8_small
vocab_size = 512
tokenizer = rustbpe.Tokenizer()
tokenizer.train_from_iterator([text], vocab_size)
# Test with various batch sizes and edge cases
test_texts = [
"Hello world",
"The quick brown fox",
"jumps over the lazy dog",
"", # empty string
"a", # single char
]
# Compare batch vs individual encoding
individual = [tokenizer.encode(t) for t in test_texts]
batched = tokenizer.batch_encode(test_texts)
assert individual == batched, "Batch encoding should match individual encoding"
print("✅ batch_encode() correctness verified")
@pytest.mark.slow
def test_batch_encode_performance(enwik8_large):
"""
Benchmark batch_encode() vs sequential encode() loop.
Demonstrates parallelization speedup.
"""
# Setup
text = enwik8_large # 10MB dataset
vocab_size = 2048
# Train tokenizer
print("\nTraining tokenizer...")
tokenizer = rustbpe.Tokenizer()
tokenizer.train_from_iterator([text], vocab_size)
# Create test batch: split text into chunks
chunk_size = 50_000 # ~50KB per chunk
chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
chunks = chunks[:20] # Use first 20 chunks (~1MB total)
print(f"\nBatch encoding benchmark:")
print(f" Number of texts: {len(chunks)}")
print(f" Avg text length: {sum(len(c) for c in chunks) / len(chunks):.0f} chars")
# Benchmark 1: Sequential encoding (baseline)
print("\n [1/3] Sequential encode() loop...")
sequential_results, sequential_time = time_function(
lambda: [tokenizer.encode(chunk) for chunk in chunks]
)
print(f" Time: {sequential_time:.4f}s")
# Benchmark 2: Parallel batch_encode()
print(" [2/3] Parallel batch_encode()...")
batch_results, batch_time = time_function(
tokenizer.batch_encode, chunks
)
print(f" Time: {batch_time:.4f}s")
# Verify correctness
print(" [3/3] Verifying correctness...")
assert len(batch_results) == len(sequential_results), "Result count mismatch"
for i, (seq, batch) in enumerate(zip(sequential_results, batch_results)):
assert seq == batch, f"Mismatch at index {i}"
print(" ✓ All results match")
# Report speedup
speedup = sequential_time / batch_time
print(f"\n Performance Results:")
print(f" Sequential: {sequential_time:.4f}s")
print(f" Batch: {batch_time:.4f}s")
print(f" Speedup: {speedup:.2f}x")
# Warn if speedup is low (can vary by machine/load)
if speedup < 1.5:
warnings.warn(f"batch_encode() speedup was only {speedup:.2f}x (expected >1.5x)")