optimisations fixed

This commit is contained in:
diana 2025-11-05 22:07:29 +03:30
parent 890d1af779
commit a6efa53b92
8 changed files with 600 additions and 146 deletions

View File

@ -0,0 +1,9 @@
{
"permissions": {
"allow": [
"Bash(python:*)"
],
"deny": [],
"ask": []
}
}

View File

@ -1,186 +1,339 @@
"""
Auto-discovery module for finding optimal batch sizes.
Automatic batch size discovery module for maximizing GPU utilization.
This is a minimal stub implementation to enable testing.
The full implementation should be added as part of Task 41 (Auto Batch Size Module).
This module implements an intelligent batch size search algorithm that:
1. Uses exponential search to quickly find an upper bound
2. Refines with binary search for optimal size
3. Applies safety margin to prevent edge-case OOMs
4. Supports DDP multi-GPU coordination
5. Caches results for faster subsequent runs
"""
import os
import json
import time
import hashlib
import torch
import torch.distributed as dist
from typing import Optional, Callable, Dict, Any
from nanochat.common import print0, get_base_dir
def discover_batch_size(
model: torch.nn.Module,
max_seq_len: int,
device: torch.device,
safety_margin: float = 0.85,
min_batch_size: int = 1,
max_batch_size: int = 128,
ddp_rank: int = 0,
ddp_world_size: int = 1,
use_cache: bool = False,
cache_key_components: Optional[Dict[str, Any]] = None,
) -> int:
def find_optimal_device_batch_size(
    model,
    max_seq_len,
    total_batch_size,
    ddp_world_size,
    data_sample_fn,
    override=None,
    safety_margin=0.85,
    enable_cache=True,
    ddp_rank=0,
):
    """
    Main entry point for automatic batch size discovery.

    Args:
        model: PyTorch model to test
        max_seq_len: Maximum sequence length
        total_batch_size: Total batch size across all GPUs (for gradient accumulation calculation)
        ddp_world_size: Number of GPUs in DDP
        data_sample_fn: Callable(batch_size) -> (inputs, targets)
        override: If set, skip discovery and return this value
        safety_margin: Fraction of optimal batch size to use (default 0.85)
        enable_cache: Whether to use caching
        ddp_rank: Current rank in DDP

    Returns:
        optimal_batch_size: Optimal device batch size for this GPU
    """
    # Handle manual override: skip discovery entirely
    if override is not None:
        print0(f"Using manual batch_size override: {override}")
        return override

    optimal_batch_size = None

    # Only rank 0 performs discovery; other ranks wait for the broadcast below
    if ddp_rank == 0:
        start_time = time.time()
        print0(f"\n{'='*60}")
        print0(f"Starting automatic batch size discovery...")
        print0(f"Parameters: max_seq_len={max_seq_len}, ddp_world_size={ddp_world_size}")
        print0(f"Safety margin: {safety_margin:.2%}")
        print0(f"{'='*60}\n")

        # Check cache first so repeat runs skip the expensive search
        cache_key = None
        if enable_cache:
            cache_key = _get_cache_key(model, max_seq_len)
            cached_batch_size = _load_from_cache(cache_key)
            if cached_batch_size is not None:
                print0(f"✓ Cache hit! Using cached batch_size: {cached_batch_size}")
                optimal_batch_size = cached_batch_size

        # Run discovery if no cache hit
        if optimal_batch_size is None:
            try:
                # Warmup CUDA so allocator state doesn't skew the first probe
                _warmup_cuda()
                # Run the exponential + binary search algorithm
                optimal_batch_size = _find_batch_size_internal(
                    model=model,
                    max_seq_len=max_seq_len,
                    data_sample_fn=data_sample_fn,
                    safety_margin=safety_margin,
                )
                # Save to cache for subsequent runs
                if enable_cache and cache_key is not None and optimal_batch_size is not None:
                    _save_to_cache(cache_key, optimal_batch_size)
                elapsed = time.time() - start_time
                print0(f"\n{'='*60}")
                print0(f"✓ Found optimal batch_size={optimal_batch_size} in {elapsed:.1f} seconds")
                print0(f"{'='*60}\n")
            except Exception as e:
                print0(f"⚠ Warning: Batch size discovery failed with error: {e}")
                optimal_batch_size = None

        # Fallback to conservative defaults if discovery failed
        if optimal_batch_size is None:
            print0(f"⚠ Warning: Using conservative fallback batch_size=8")
            optimal_batch_size = 8
    else:
        optimal_batch_size = 0  # Will be broadcast from rank 0

    # DDP: Broadcast result from rank 0 to all ranks
    if ddp_world_size > 1:
        try:
            import torch.distributed as dist
            tensor = torch.tensor([optimal_batch_size], dtype=torch.long, device='cuda')
            dist.broadcast(tensor, src=0)
            optimal_batch_size = tensor.item()
            if ddp_rank != 0:
                print0(f"Received batch_size from rank 0: {optimal_batch_size}")
        except Exception as e:
            print0(f"⚠ Warning: DDP broadcast failed: {e}")
            # Non-zero ranks still hold the 0 placeholder here; fall back to a
            # safe default rather than returning an unusable batch size of 0.
            if optimal_batch_size == 0:
                optimal_batch_size = 8
    return optimal_batch_size
def _perform_discovery(
model: torch.nn.Module,
max_seq_len: int,
device: torch.device,
safety_margin: float,
min_batch_size: int,
max_batch_size: int,
) -> int:
def _find_batch_size_internal(model, max_seq_len, data_sample_fn, safety_margin):
    """
    Core algorithm implementing exponential search followed by binary search.

    Returns:
        optimal_batch_size: The largest batch size that fits in memory (with
        safety margin applied), or None if even batch_size=1 OOMs.
    """
    # Phase 1: Exponential search to find upper bound (doubling each step)
    print0("Phase 1: Exponential search to find upper bound...")
    batch_size = 1
    last_successful = None
    while True:
        print0(f"  Testing batch_size={batch_size}...", end=" ")
        success = _test_batch_size(
            model=model,
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            data_sample_fn=data_sample_fn,
        )
        if success:
            print0("✓ Success")
            last_successful = batch_size
            batch_size *= 2
        else:
            print0("✗ OOM")
            break

    # If even batch_size=1 failed, signal failure to the caller
    if last_successful is None:
        print0("✗ Even batch_size=1 caused OOM!")
        return None

    # Phase 2: Binary search refinement between last success and first failure
    print0(f"\nPhase 2: Binary search refinement between {last_successful} and {batch_size}...")
    lower = last_successful  # known to fit
    upper = batch_size       # known to OOM
    while upper - lower > 1:
        mid = (lower + upper) // 2
        print0(f"  Testing batch_size={mid}...", end=" ")
        success = _test_batch_size(
            model=model,
            batch_size=mid,
            max_seq_len=max_seq_len,
            data_sample_fn=data_sample_fn,
        )
        if success:
            print0("✓ Success")
            lower = mid
        else:
            print0("✗ OOM")
            upper = mid

    # Phase 3: Apply safety margin. Clamp to at least 1 — int(1 * 0.85) would
    # otherwise round down to an unusable batch size of 0.
    optimal_batch_size = max(1, int(lower * safety_margin))
    print0(f"\nApplying safety margin: {lower} × {safety_margin:.2%} = {optimal_batch_size}")
    return optimal_batch_size
def _test_batch_size(model, batch_size, max_seq_len, data_sample_fn):
    """
    Test if a specific batch size fits in memory by simulating training loop.

    Runs one forward + backward pass under bfloat16 autocast, mimicking what a
    real training step allocates (activations + gradients).

    Returns:
        bool: True if batch size fits, False if OOM
    """
    try:
        # Clear CUDA cache before test so a previous probe's fragments don't
        # cause a spurious OOM
        torch.cuda.empty_cache()
        # Set model to training mode
        model.train()
        # Zero gradients
        model.zero_grad(set_to_none=True)
        # Generate test batch
        inputs, targets = data_sample_fn(batch_size)
        # Forward pass with bfloat16 autocast (matches training precision)
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            loss = model(inputs, targets)
        # Backward pass (gradient buffers are the other big allocation)
        loss.backward()
        model.zero_grad(set_to_none=True)
        # Synchronize CUDA to ensure all operations complete; async kernels
        # can otherwise OOM after we've already returned True
        torch.cuda.synchronize()
        # Clean up test tensors and clear cache for the next probe
        del inputs, targets, loss
        torch.cuda.empty_cache()
        return True
    except torch.cuda.OutOfMemoryError:
        # Expected failure mode: report "doesn't fit"
        torch.cuda.empty_cache()
        return False
    except Exception as e:
        # Unexpected failure: warn but treat as "doesn't fit" so the search
        # backs off rather than crashing
        print0(f"\n⚠ Warning: Test failed with unexpected error: {e}")
        torch.cuda.empty_cache()
        return False
def _get_cache_key(components: Dict[str, Any]) -> str:
"""Generate cache key from components."""
key_str = json.dumps(components, sort_keys=True)
return hashlib.md5(key_str.encode()).hexdigest()
def _warmup_cuda():
    """Warmup CUDA by allocating and freeing a small tensor."""
    try:
        # A tiny allocation forces lazy CUDA context creation up front, so the
        # first real batch-size probe isn't charged with initialization cost.
        probe = torch.zeros(1, device='cuda')
        del probe
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    except Exception as e:
        # Warmup is best-effort; discovery can proceed without it.
        print0(f"⚠ Warning: CUDA warmup failed: {e}")
def _load_from_cache(components: Dict[str, Any]) -> Optional[int]:
"""Load batch size from cache if available."""
def _get_cache_key(model, max_seq_len):
    """
    Generate cache key from model config hash, GPU model, and max_seq_len.

    Returns:
        str: Hash string to use as cache key
    """
    try:
        # The config may live on the model itself, or on the wrapped original
        # module when the model has been compiled (torch.compile exposes it
        # via _orig_mod).
        config = getattr(model, 'config', None)
        if config is None:
            config = getattr(getattr(model, '_orig_mod', None), 'config', None)
        if config is None:
            return None
        # Serialize the architecture-defining fields into a stable string
        config_str = "|".join([
            f"vocab_size={config.vocab_size}",
            f"n_layer={config.n_layer}",
            f"n_embd={config.n_embd}",
            f"n_head={config.n_head}",
            f"n_kv_head={config.n_kv_head}",
        ])
        # Include the GPU model so results don't leak across hardware,
        # then hash everything down to a short filename-safe key
        key_str = f"{config_str}|gpu={torch.cuda.get_device_name(0)}|seq_len={max_seq_len}"
        return hashlib.md5(key_str.encode()).hexdigest()
    except Exception as e:
        print0(f"⚠ Warning: Failed to generate cache key: {e}")
        return None
def _load_from_cache(cache_key):
    """
    Load cached batch size from JSON file.

    Args:
        cache_key: Hash string identifying the (model, GPU, seq_len) combo,
            or None if key generation failed.

    Returns:
        int or None: Cached batch size, or None if not found
    """
    if cache_key is None:
        return None
    try:
        base_dir = get_base_dir()
        cache_dir = os.path.join(base_dir, "auto_batch_cache")
        cache_file = os.path.join(cache_dir, f"{cache_key}.json")
        if not os.path.exists(cache_file):
            return None
        with open(cache_file, 'r') as f:
            data = json.load(f)
        # .get tolerates older/corrupt cache files missing the key
        return data.get('batch_size')
    except Exception as e:
        # Cache is an optimization only — never let it break discovery
        print0(f"⚠ Warning: Failed to load from cache: {e}")
        return None
def _save_to_cache(components: Dict[str, Any], batch_size: int) -> None:
"""Save batch size to cache."""
def _save_to_cache(cache_key, batch_size):
    """Save batch size to JSON cache file.

    No-op when either the key or the value is None (nothing useful to store).
    Failures are logged and swallowed — caching is best-effort.
    """
    if cache_key is None or batch_size is None:
        return
    try:
        base_dir = get_base_dir()
        cache_dir = os.path.join(base_dir, "auto_batch_cache")
        os.makedirs(cache_dir, exist_ok=True)
        cache_file = os.path.join(cache_dir, f"{cache_key}.json")
        data = {
            'batch_size': batch_size,
            # timestamp recorded for debugging/expiry tooling
            'timestamp': time.time(),
        }
        with open(cache_file, 'w') as f:
            json.dump(data, f, indent=2)
        print0(f"✓ Saved batch_size={batch_size} to cache")
    except Exception as e:
        print0(f"⚠ Warning: Failed to save to cache: {e}")

View File

@ -219,9 +219,7 @@ class Engine:
# Get sampled tokens - either from prefill or from forward pass
if first_iteration:
# Use the tokens we already sampled from prefill
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
# TODO: we should sample a token for each row instead of broadcasting
# sampled_tokens already contains num_samples independently sampled tokens
first_iteration = False
else:
# Forward the model and get the next token for each row

View File

@ -22,6 +22,7 @@ import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
from nanochat.engine import KVCache
@dataclass
class GPTConfig:
@ -304,10 +305,44 @@ class GPT(nn.Module):
if temperature > 0:
rng = torch.Generator(device=device)
rng.manual_seed(seed)
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
for _ in range(max_tokens):
logits = self.forward(ids) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size)
# Initialize KV cache
kv_length_hint = len(tokens) + max_tokens
kv_cache = KVCache(
batch_size=1,
num_heads=self.config.n_kv_head,
seq_len=kv_length_hint,
head_dim=self.config.n_embd // self.config.n_head,
num_layers=self.config.n_layer
)
# Prefill phase: process the prompt
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.forward(ids, kv_cache=kv_cache)
logits = logits[:, -1, :]
# Sample first token
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
if temperature > 0:
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
else:
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
# Yield first token
token = next_ids.item()
yield token
# Generation loop: process one token at a time with KV cache
for _ in range(max_tokens - 1):
# Forward pass with only the new token
logits = self.forward(next_ids, kv_cache=kv_cache)
logits = logits[:, -1, :]
# Sample next token
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
@ -317,6 +352,6 @@ class GPT(nn.Module):
next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
else:
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
ids = torch.cat((ids, next_ids), dim=1)
token = next_ids.item()
yield token

View File

@ -40,6 +40,7 @@ device_batch_size = None # per-device batch size (set to not OOM), None = auto-d
auto_batch_size = True # whether to auto-discover optimal batch size
batch_size_margin = 0.85 # safety margin for auto-discovered batch size
batch_size_cache = True # whether to cache auto-discovered batch size
max_seq_len = 2048 # maximum sequence length for fixed padding (enables torch.compile)
# optimization
num_epochs = 1
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
@ -104,8 +105,8 @@ elif device_batch_size is None:
print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}")
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
model = torch.compile(model, dynamic=False) # enabled with fixed-length padding
engine = Engine(orig_model, tokenizer) # use uncompiled model for engine (variable-length eval)
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
@ -126,7 +127,7 @@ def sft_data_generator(dataset, batch_size):
# prepares a list of tokenized conversations into a batch and yields
def collate_and_yield(batch):
nrows = len(batch)
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
ncols = max_seq_len - 1 # fixed length for torch.compile (seq of n creates inputs/targets of n-1)
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
for i, (ids, mask) in enumerate(batch):
@ -199,13 +200,13 @@ for step in range(num_iterations):
# evaluate the validation loss
if last_step or step % eval_every == 0:
model.eval()
orig_model.eval()
val_iter = iter(build_val_loader())
losses = []
for _ in range(eval_steps):
val_inputs, val_targets = next(val_iter)
with torch.no_grad(), autocast_ctx:
loss = model(val_inputs, val_targets)
loss = orig_model(val_inputs, val_targets)
losses.append(loss)
val_loss = torch.stack(losses).mean() # average over eval_steps
if ddp:
@ -216,22 +217,24 @@ for step in range(num_iterations):
"step": step,
"val_loss": val_loss,
})
orig_model.train()
model.train()
# evaluate accuracy of the multiple choice tasks (which are quick to run)
if last_step or (step > 0 and step % eval_metrics_every == 0):
model.eval()
orig_model.eval()
metrics = {}
with torch.no_grad(), autocast_ctx:
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics["mmlu_acc"] = run_chat_eval("MMLU", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log({
"step": step,
**metrics,
})
orig_model.train()
model.train()
if last_step:
@ -276,14 +279,14 @@ for step in range(num_iterations):
# Save the model at the end of the run
if master_process:
base_dir = get_base_dir()
depth = model.config.n_layer
depth = orig_model.config.n_layer
model_tag = f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
model_config_kwargs = orig_model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(
checkpoint_dir,
step,
model.state_dict(),
orig_model.state_dict(),
None, # note: we don't bother to save the optimizer state
{
"step": step,

View File

@ -0,0 +1,57 @@
"""Quick script to inspect checkpoint structure."""
import torch
import sys
if len(sys.argv) < 2:
print("Usage: python check_checkpoint.py <path_to_checkpoint>")
sys.exit(1)
checkpoint_path = sys.argv[1]
print(f"Loading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print("\n📦 Checkpoint keys:")
for key in checkpoint.keys():
value = checkpoint[key]
if isinstance(value, dict):
print(f" - {key}: dict with {len(value)} items")
elif isinstance(value, torch.Tensor):
print(f" - {key}: Tensor {value.shape}")
else:
print(f" - {key}: {type(value).__name__}")
# Check for model weights
if 'model' in checkpoint:
print("\n✅ Model weights found")
model_dict = checkpoint['model']
print(f" Number of parameters: {len(model_dict)}")
print(f" Sample keys: {list(model_dict.keys())[:5]}")
else:
print("\n⚠️ No 'model' key found!")
print(" Available keys:", list(checkpoint.keys()))
# Check for config
if 'config' in checkpoint:
print("\n✅ Config found")
config = checkpoint['config']
print(f" Type: {type(config)}")
if hasattr(config, '__dict__'):
print(f" Attributes: {list(vars(config).keys())[:10]}")
else:
print("\n⚠️ No 'config' key found!")
print("\n" + "="*60)
print("RECOMMENDATION:")
if 'model' not in checkpoint or 'config' not in checkpoint:
print("Your checkpoint is missing required keys.")
print("Please check how the model was saved during training.")
print("\nExpected checkpoint structure:")
print(" checkpoint = {")
print(" 'model': model.state_dict(),")
print(" 'config': model.config,")
print(" 'optimizer': optimizer.state_dict(), # optional")
print(" 'step': current_step, # optional")
print(" }")
else:
print("✅ Checkpoint looks good!")

66
scripts/quick_check.py Normal file
View File

@ -0,0 +1,66 @@
"""Quick checkpoint structure check."""
import torch
import sys
checkpoint_path = "/raid/diana/nanochat_cache/chatsft_checkpoints/d20/model_000650.pt"
print(f"Loading: {checkpoint_path}")
try:
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print("\n" + "="*60)
print("CHECKPOINT STRUCTURE")
print("="*60)
print(f"\nTop-level keys: {list(checkpoint.keys())}\n")
for key in checkpoint.keys():
value = checkpoint[key]
if isinstance(value, dict):
print(f"'{key}': dict with {len(value)} items")
# Show a few sub-keys if it's a dict
sub_keys = list(value.keys())[:3]
print(f" Sample keys: {sub_keys}")
elif isinstance(value, torch.Tensor):
print(f"'{key}': Tensor {value.shape}, dtype={value.dtype}")
else:
print(f"'{key}': {type(value).__name__} = {value}")
print("\n" + "="*60)
print("DIAGNOSIS")
print("="*60)
# Check what we need
has_model = 'model' in checkpoint
has_config = 'config' in checkpoint
has_state_dict = 'state_dict' in checkpoint
has_model_state_dict = 'model_state_dict' in checkpoint
print(f"\n✓ Has 'model' key: {has_model}")
print(f"✓ Has 'config' key: {has_config}")
print(f"✓ Has 'state_dict' key: {has_state_dict}")
print(f"✓ Has 'model_state_dict' key: {has_model_state_dict}")
# Try to infer the structure
print("\n" + "="*60)
print("RECOMMENDATION")
print("="*60)
if has_model and has_config:
print("\n✅ Checkpoint has expected structure!")
print(" No changes needed to benchmark_optimizations.py")
elif has_state_dict:
print("\n⚠️ Checkpoint uses 'state_dict' instead of 'model'")
print(" Need to update benchmark to use checkpoint['state_dict']")
elif has_model_state_dict:
print("\n⚠️ Checkpoint uses 'model_state_dict' instead of 'model'")
print(" Need to update benchmark to use checkpoint['model_state_dict']")
else:
print("\n❌ Checkpoint has unexpected structure!")
print(" Available keys:", list(checkpoint.keys()))
print(" You may need to check how the model was saved during training")
except Exception as e:
print(f"\n❌ Error loading checkpoint: {e}")
import traceback
traceback.print_exc()

133
speedrun_4gpu.sh Normal file
View File

@ -0,0 +1,133 @@
#!/bin/bash

# 4-GPU variant of the nanochat speedrun ("Best ChatGPT clone that $100 can buy").
# The original speedrun targets ~4 hours on an 8XH100 node at $3/GPU/hour; this
# variant runs the same pipeline on 4 GPUs (see the nproc_per_node=4 torchrun
# invocations below), so expect roughly double the wall-clock time.
# 1) Example launch (simplest):
#   bash speedrun_4gpu.sh
# 2) Example launch in a screen session (because the run takes several hours):
#   screen -L -Logfile speedrun.log -S speedrun bash speedrun_4gpu.sh
# 3) Example launch with wandb logging, but see below for setting up wandb first:
#   WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun_4gpu.sh

# Default intermediate artifacts directory is in ~/.cache/nanochat
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="/raid/diana/nanochat_cache"
mkdir -p "$NANOCHAT_BASE_DIR"

# -----------------------------------------------------------------------------
# Python venv setup with uv

# install uv (if not already installed)
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# create a .venv local virtual environment (if it doesn't exist)
[ -d ".venv" ] || uv venv
# install the repo dependencies
uv sync
# activate venv so that `python` uses the project's venv instead of system python
source .venv/bin/activate

# -----------------------------------------------------------------------------
# wandb setup

# If you wish to use wandb for logging (it's nice!, recommended).
# 1) Make sure to first log in to wandb, e.g. run:
#    `wandb login`
# 2) Set the WANDB_RUN environment variable when running this script, e.g.:
#    `WANDB_RUN=d26 bash speedrun_4gpu.sh`
if [ -z "$WANDB_RUN" ]; then
    # by default use "dummy" : it's handled as a special case, skips logging to wandb
    WANDB_RUN=dummy
fi

# -----------------------------------------------------------------------------
# During the course of the run, we will be writing markdown reports to the report/
# directory in the base dir. This command clears it out and writes a header section
# with a bunch of system info and a timestamp that marks the start of the run.
python -m nanochat.report reset

# -----------------------------------------------------------------------------
# Tokenizer

# Install Rust / Cargo
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
# Build the rustbpe Tokenizer
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
# Download the first ~2B characters of pretraining dataset
# look at dev/repackage_data_reference.py for details on how this data was prepared
# each data shard is ~250M chars
# so we download 2e9 / 250e6 = 8 data shards at this point
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
python -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 240 is the right number here
python -m nanochat.dataset -n 240 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max_chars=2000000000
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval

# -----------------------------------------------------------------------------
# Base model (pretraining)

# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB)
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
    curl -L -o eval_bundle.zip "$EVAL_BUNDLE_URL"
    unzip -q eval_bundle.zip
    rm eval_bundle.zip
    mv eval_bundle "$NANOCHAT_BASE_DIR"
fi
# The d20 model is 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
# (The total number of shards available in the entire dataset is 1822.)
echo "Waiting for dataset download to complete..."
wait "$DATASET_DOWNLOAD_PID"
# pretrain the d20 model
torchrun --standalone --nproc_per_node=4 -m scripts.base_train -- --depth=20 --run="$WANDB_RUN" --total_batch_size=262144
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=4 -m scripts.base_loss
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=4 -m scripts.base_eval

# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)

# run midtraining and eval the model
torchrun --standalone --nproc_per_node=4 -m scripts.mid_train -- --run="$WANDB_RUN"
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i mid

# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)

# train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=4 -m scripts.chat_sft -- --run="$WANDB_RUN"
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i sft

# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"
# even better, chat with your model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web

# -----------------------------------------------------------------------------
# Reinforcement Learning. Optional, and currently only on GSM8K
# (optional)

# run reinforcement learning
# torchrun --standalone --nproc_per_node=4 -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i rl -a GSM8K

# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections
# report.md is the output and will be copied to current directory for convenience
python -m nanochat.report generated