mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-31 00:55:18 +00:00
optimisations fixed
This commit is contained in:
parent
890d1af779
commit
a6efa53b92
9
.claude/settings.local.json
Normal file
9
.claude/settings.local.json
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(python:*)"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": []
|
||||
}
|
||||
}
|
||||
|
|
@ -1,186 +1,339 @@
|
|||
"""
|
||||
Auto-discovery module for finding optimal batch sizes.
|
||||
Automatic batch size discovery module for maximizing GPU utilization.
|
||||
|
||||
This is a minimal stub implementation to enable testing.
|
||||
The full implementation should be added as part of Task 41 (Auto Batch Size Module).
|
||||
This module implements an intelligent batch size search algorithm that:
|
||||
1. Uses exponential search to quickly find an upper bound
|
||||
2. Refines with binary search for optimal size
|
||||
3. Applies safety margin to prevent edge-case OOMs
|
||||
4. Supports DDP multi-GPU coordination
|
||||
5. Caches results for faster subsequent runs
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import hashlib
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from typing import Optional, Callable, Dict, Any
|
||||
|
||||
from nanochat.common import print0, get_base_dir
|
||||
|
||||
|
||||
def discover_batch_size(
|
||||
model: torch.nn.Module,
|
||||
max_seq_len: int,
|
||||
device: torch.device,
|
||||
safety_margin: float = 0.85,
|
||||
min_batch_size: int = 1,
|
||||
max_batch_size: int = 128,
|
||||
ddp_rank: int = 0,
|
||||
ddp_world_size: int = 1,
|
||||
use_cache: bool = False,
|
||||
cache_key_components: Optional[Dict[str, Any]] = None,
|
||||
) -> int:
|
||||
def find_optimal_device_batch_size(
|
||||
model,
|
||||
max_seq_len,
|
||||
total_batch_size,
|
||||
ddp_world_size,
|
||||
data_sample_fn,
|
||||
override=None,
|
||||
safety_margin=0.85,
|
||||
enable_cache=True,
|
||||
ddp_rank=0,
|
||||
):
|
||||
"""
|
||||
Discover the optimal batch size for a model.
|
||||
|
||||
Main entry point for automatic batch size discovery.
|
||||
|
||||
Args:
|
||||
model: The model to test
|
||||
model: PyTorch model to test
|
||||
max_seq_len: Maximum sequence length
|
||||
device: Device to run on
|
||||
safety_margin: Safety factor (e.g., 0.85 = use 85% of max)
|
||||
min_batch_size: Minimum batch size to try
|
||||
max_batch_size: Maximum batch size to try
|
||||
ddp_rank: Rank in distributed setting
|
||||
ddp_world_size: World size in distributed setting
|
||||
use_cache: Whether to use cache
|
||||
cache_key_components: Components for cache key
|
||||
|
||||
total_batch_size: Total batch size across all GPUs (for gradient accumulation calculation)
|
||||
ddp_world_size: Number of GPUs in DDP
|
||||
data_sample_fn: Callable(batch_size) -> (inputs, targets)
|
||||
override: If set, skip discovery and return this value
|
||||
safety_margin: Fraction of optimal batch size to use (default 0.85)
|
||||
enable_cache: Whether to use caching
|
||||
ddp_rank: Current rank in DDP
|
||||
|
||||
Returns:
|
||||
Discovered batch size
|
||||
optimal_batch_size: Optimal device batch size for this GPU
|
||||
"""
|
||||
# Only rank 0 performs discovery in DDP
|
||||
# Handle manual override
|
||||
if override is not None:
|
||||
print0(f"Using manual batch_size override: {override}")
|
||||
return override
|
||||
|
||||
optimal_batch_size = None
|
||||
|
||||
# Only rank 0 performs discovery
|
||||
if ddp_rank == 0:
|
||||
print0("Running auto-discovery on rank 0")
|
||||
|
||||
# Check cache first
|
||||
if use_cache and cache_key_components:
|
||||
cached_size = _load_from_cache(cache_key_components)
|
||||
if cached_size is not None:
|
||||
print0(f"Cache hit! Using batch_size={cached_size}")
|
||||
discovered_size = cached_size
|
||||
else:
|
||||
print0("Cache miss, performing discovery")
|
||||
discovered_size = _perform_discovery(
|
||||
model, max_seq_len, device, safety_margin,
|
||||
min_batch_size, max_batch_size
|
||||
start_time = time.time()
|
||||
print0(f"\n{'='*60}")
|
||||
print0(f"Starting automatic batch size discovery...")
|
||||
print0(f"Parameters: max_seq_len={max_seq_len}, ddp_world_size={ddp_world_size}")
|
||||
print0(f"Safety margin: {safety_margin:.2%}")
|
||||
print0(f"{'='*60}\n")
|
||||
|
||||
# Check cache
|
||||
cache_key = None
|
||||
if enable_cache:
|
||||
cache_key = _get_cache_key(model, max_seq_len)
|
||||
cached_batch_size = _load_from_cache(cache_key)
|
||||
if cached_batch_size is not None:
|
||||
print0(f"✓ Cache hit! Using cached batch_size: {cached_batch_size}")
|
||||
optimal_batch_size = cached_batch_size
|
||||
|
||||
# Run discovery if no cache hit
|
||||
if optimal_batch_size is None:
|
||||
try:
|
||||
# Warmup CUDA
|
||||
_warmup_cuda()
|
||||
|
||||
# Run the search algorithm
|
||||
optimal_batch_size = _find_batch_size_internal(
|
||||
model=model,
|
||||
max_seq_len=max_seq_len,
|
||||
data_sample_fn=data_sample_fn,
|
||||
safety_margin=safety_margin,
|
||||
)
|
||||
if cache_key_components:
|
||||
_save_to_cache(cache_key_components, discovered_size)
|
||||
else:
|
||||
discovered_size = _perform_discovery(
|
||||
model, max_seq_len, device, safety_margin,
|
||||
min_batch_size, max_batch_size
|
||||
)
|
||||
|
||||
print0(f"Auto-discovery found device_batch_size={discovered_size}")
|
||||
|
||||
# Save to cache
|
||||
if enable_cache and cache_key is not None and optimal_batch_size is not None:
|
||||
_save_to_cache(cache_key, optimal_batch_size)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print0(f"\n{'='*60}")
|
||||
print0(f"✓ Found optimal batch_size={optimal_batch_size} in {elapsed:.1f} seconds")
|
||||
print0(f"{'='*60}\n")
|
||||
|
||||
except Exception as e:
|
||||
print0(f"⚠ Warning: Batch size discovery failed with error: {e}")
|
||||
optimal_batch_size = None
|
||||
|
||||
# Fallback to conservative defaults if discovery failed
|
||||
if optimal_batch_size is None:
|
||||
print0(f"⚠ Warning: Using conservative fallback batch_size=8")
|
||||
optimal_batch_size = 8
|
||||
else:
|
||||
discovered_size = 0 # Will be broadcast from rank 0
|
||||
|
||||
# Broadcast to all ranks in DDP
|
||||
optimal_batch_size = 0 # Will be broadcast from rank 0
|
||||
|
||||
# DDP: Broadcast result from rank 0 to all ranks
|
||||
if ddp_world_size > 1:
|
||||
discovered_tensor = torch.tensor(discovered_size, dtype=torch.int32, device=device)
|
||||
dist.broadcast(discovered_tensor, src=0)
|
||||
discovered_size = discovered_tensor.item()
|
||||
if ddp_rank != 0:
|
||||
print0(f"Received batch size from rank 0: {discovered_size}")
|
||||
|
||||
return discovered_size
|
||||
try:
|
||||
import torch.distributed as dist
|
||||
tensor = torch.tensor([optimal_batch_size], dtype=torch.long, device='cuda')
|
||||
dist.broadcast(tensor, src=0)
|
||||
optimal_batch_size = tensor.item()
|
||||
if ddp_rank != 0:
|
||||
print0(f"Received batch_size from rank 0: {optimal_batch_size}")
|
||||
except Exception as e:
|
||||
print0(f"⚠ Warning: DDP broadcast failed: {e}")
|
||||
if optimal_batch_size == 0:
|
||||
optimal_batch_size = 8
|
||||
|
||||
return optimal_batch_size
|
||||
|
||||
|
||||
def _perform_discovery(
|
||||
model: torch.nn.Module,
|
||||
max_seq_len: int,
|
||||
device: torch.device,
|
||||
safety_margin: float,
|
||||
min_batch_size: int,
|
||||
max_batch_size: int,
|
||||
) -> int:
|
||||
def _find_batch_size_internal(model, max_seq_len, data_sample_fn, safety_margin):
|
||||
"""
|
||||
Perform the actual discovery using exponential + binary search.
|
||||
|
||||
This is a stub implementation that returns a fixed value.
|
||||
The real implementation should:
|
||||
1. Exponential search to find upper bound
|
||||
2. Binary search to refine
|
||||
3. Apply safety margin
|
||||
"""
|
||||
# Stub: return a fixed reasonable value
|
||||
# Real implementation would perform exponential + binary search
|
||||
batch_size = min(32, max_batch_size)
|
||||
return max(int(batch_size * safety_margin), min_batch_size)
|
||||
Core algorithm implementing exponential search followed by binary search.
|
||||
|
||||
|
||||
def _test_batch_size(
|
||||
model: torch.nn.Module,
|
||||
batch_size: int,
|
||||
max_seq_len: int,
|
||||
device: torch.device,
|
||||
) -> bool:
|
||||
"""
|
||||
Test if a given batch size fits in memory.
|
||||
|
||||
Returns:
|
||||
True if batch size works, False if OOM
|
||||
optimal_batch_size: The largest batch size that fits in memory (with safety margin)
|
||||
"""
|
||||
# Phase 1: Exponential search to find upper bound
|
||||
print0("Phase 1: Exponential search to find upper bound...")
|
||||
batch_size = 1
|
||||
last_successful = None
|
||||
|
||||
while True:
|
||||
print0(f" Testing batch_size={batch_size}...", end=" ")
|
||||
success = _test_batch_size(
|
||||
model=model,
|
||||
batch_size=batch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
data_sample_fn=data_sample_fn,
|
||||
)
|
||||
|
||||
if success:
|
||||
print0("✓ Success")
|
||||
last_successful = batch_size
|
||||
batch_size *= 2
|
||||
else:
|
||||
print0("✗ OOM")
|
||||
break
|
||||
|
||||
# If even batch_size=1 failed, return None
|
||||
if last_successful is None:
|
||||
print0("✗ Even batch_size=1 caused OOM!")
|
||||
return None
|
||||
|
||||
# Phase 2: Binary search refinement
|
||||
print0(f"\nPhase 2: Binary search refinement between {last_successful} and {batch_size}...")
|
||||
lower = last_successful
|
||||
upper = batch_size
|
||||
|
||||
while upper - lower > 1:
|
||||
mid = (lower + upper) // 2
|
||||
print0(f" Testing batch_size={mid}...", end=" ")
|
||||
success = _test_batch_size(
|
||||
model=model,
|
||||
batch_size=mid,
|
||||
max_seq_len=max_seq_len,
|
||||
data_sample_fn=data_sample_fn,
|
||||
)
|
||||
|
||||
if success:
|
||||
print0("✓ Success")
|
||||
lower = mid
|
||||
else:
|
||||
print0("✗ OOM")
|
||||
upper = mid
|
||||
|
||||
# Phase 3: Apply safety margin
|
||||
optimal_batch_size = int(lower * safety_margin)
|
||||
print0(f"\nApplying safety margin: {lower} × {safety_margin:.2%} = {optimal_batch_size}")
|
||||
|
||||
return optimal_batch_size
|
||||
|
||||
|
||||
def _test_batch_size(model, batch_size, max_seq_len, data_sample_fn):
|
||||
"""
|
||||
Test if a specific batch size fits in memory by simulating training loop.
|
||||
|
||||
Returns:
|
||||
bool: True if batch size fits, False if OOM
|
||||
"""
|
||||
try:
|
||||
# Create dummy inputs
|
||||
inputs = torch.randint(0, 50000, (batch_size, max_seq_len), device=device, dtype=torch.int32)
|
||||
targets = torch.randint(0, 50000, (batch_size, max_seq_len), device=device, dtype=torch.int64)
|
||||
|
||||
# Forward + backward pass
|
||||
# Clear CUDA cache before test
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Set model to training mode
|
||||
model.train()
|
||||
|
||||
# Zero gradients
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
# Generate test batch
|
||||
inputs, targets = data_sample_fn(batch_size)
|
||||
|
||||
# Forward pass with bfloat16 autocast
|
||||
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
loss = model(inputs, targets)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
# Clean up
|
||||
|
||||
# Synchronize CUDA to ensure all operations complete
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Clear cache after test
|
||||
del inputs, targets, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
return True
|
||||
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
# Clear cache and return False on OOM
|
||||
torch.cuda.empty_cache()
|
||||
return False
|
||||
except Exception as e:
|
||||
print0(f"Error testing batch size {batch_size}: {e}")
|
||||
# Handle other exceptions
|
||||
print0(f"\n⚠ Warning: Test failed with unexpected error: {e}")
|
||||
torch.cuda.empty_cache()
|
||||
return False
|
||||
|
||||
|
||||
def _get_cache_key(components: Dict[str, Any]) -> str:
|
||||
"""Generate cache key from components."""
|
||||
key_str = json.dumps(components, sort_keys=True)
|
||||
return hashlib.md5(key_str.encode()).hexdigest()
|
||||
def _warmup_cuda():
|
||||
"""Warmup CUDA by allocating and freeing a small tensor."""
|
||||
try:
|
||||
x = torch.zeros(1, device='cuda')
|
||||
del x
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as e:
|
||||
print0(f"⚠ Warning: CUDA warmup failed: {e}")
|
||||
|
||||
|
||||
def _load_from_cache(components: Dict[str, Any]) -> Optional[int]:
|
||||
"""Load batch size from cache if available."""
|
||||
def _get_cache_key(model, max_seq_len):
|
||||
"""
|
||||
Generate cache key from model config hash, GPU model, and max_seq_len.
|
||||
|
||||
Returns:
|
||||
str: Hash string to use as cache key
|
||||
"""
|
||||
try:
|
||||
# Get model config attributes
|
||||
config = model.config if hasattr(model, 'config') else None
|
||||
if config is None:
|
||||
# Try to get from original model (in case of compiled model)
|
||||
config = model._orig_mod.config if hasattr(model, '_orig_mod') else None
|
||||
|
||||
if config is None:
|
||||
return None
|
||||
|
||||
# Build config string
|
||||
config_parts = [
|
||||
f"vocab_size={config.vocab_size}",
|
||||
f"n_layer={config.n_layer}",
|
||||
f"n_embd={config.n_embd}",
|
||||
f"n_head={config.n_head}",
|
||||
f"n_kv_head={config.n_kv_head}",
|
||||
]
|
||||
config_str = "|".join(config_parts)
|
||||
|
||||
# Get GPU model name
|
||||
gpu_name = torch.cuda.get_device_name(0)
|
||||
|
||||
# Combine all components
|
||||
key_str = f"{config_str}|gpu={gpu_name}|seq_len={max_seq_len}"
|
||||
|
||||
# Hash to create a short key
|
||||
cache_key = hashlib.md5(key_str.encode()).hexdigest()
|
||||
|
||||
return cache_key
|
||||
|
||||
except Exception as e:
|
||||
print0(f"⚠ Warning: Failed to generate cache key: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _load_from_cache(cache_key):
|
||||
"""
|
||||
Load cached batch size from JSON file.
|
||||
|
||||
Returns:
|
||||
int or None: Cached batch size, or None if not found
|
||||
"""
|
||||
if cache_key is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
base_dir = get_base_dir()
|
||||
cache_dir = os.path.join(base_dir, "auto_batch_cache")
|
||||
cache_key = _get_cache_key(components)
|
||||
cache_file = os.path.join(cache_dir, f"{cache_key}.json")
|
||||
|
||||
if os.path.exists(cache_file):
|
||||
with open(cache_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
if not os.path.exists(cache_file):
|
||||
return None
|
||||
|
||||
with open(cache_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
return data.get('batch_size')
|
||||
|
||||
except Exception as e:
|
||||
print0(f"Cache load error: {e}")
|
||||
return None
|
||||
print0(f"⚠ Warning: Failed to load from cache: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _save_to_cache(components: Dict[str, Any], batch_size: int) -> None:
|
||||
"""Save batch size to cache."""
|
||||
def _save_to_cache(cache_key, batch_size):
|
||||
"""Save batch size to JSON cache file."""
|
||||
if cache_key is None or batch_size is None:
|
||||
return
|
||||
|
||||
try:
|
||||
base_dir = get_base_dir()
|
||||
cache_dir = os.path.join(base_dir, "auto_batch_cache")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
cache_key = _get_cache_key(components)
|
||||
|
||||
cache_file = os.path.join(cache_dir, f"{cache_key}.json")
|
||||
|
||||
|
||||
data = {
|
||||
'batch_size': batch_size,
|
||||
'timestamp': time.time(),
|
||||
}
|
||||
|
||||
with open(cache_file, 'w') as f:
|
||||
json.dump({
|
||||
'batch_size': batch_size,
|
||||
'components': components,
|
||||
}, f, indent=2)
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
print0(f"✓ Saved batch_size={batch_size} to cache")
|
||||
|
||||
except Exception as e:
|
||||
print0(f"Cache save error: {e}")
|
||||
print0(f"⚠ Warning: Failed to save to cache: {e}")
|
||||
|
|
|
|||
|
|
@ -219,9 +219,7 @@ class Engine:
|
|||
|
||||
# Get sampled tokens - either from prefill or from forward pass
|
||||
if first_iteration:
|
||||
# Use the tokens we already sampled from prefill
|
||||
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
|
||||
# TODO: we should sample a token for each row instead of broadcasting
|
||||
# sampled_tokens already contains num_samples independently sampled tokens
|
||||
first_iteration = False
|
||||
else:
|
||||
# Forward the model and get the next token for each row
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import torch.nn.functional as F
|
|||
from nanochat.common import get_dist_info, print0
|
||||
from nanochat.muon import Muon, DistMuon
|
||||
from nanochat.adamw import DistAdamW
|
||||
from nanochat.engine import KVCache
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
|
|
@ -304,10 +305,44 @@ class GPT(nn.Module):
|
|||
if temperature > 0:
|
||||
rng = torch.Generator(device=device)
|
||||
rng.manual_seed(seed)
|
||||
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
|
||||
for _ in range(max_tokens):
|
||||
logits = self.forward(ids) # (B, T, vocab_size)
|
||||
logits = logits[:, -1, :] # (B, vocab_size)
|
||||
|
||||
# Initialize KV cache
|
||||
kv_length_hint = len(tokens) + max_tokens
|
||||
kv_cache = KVCache(
|
||||
batch_size=1,
|
||||
num_heads=self.config.n_kv_head,
|
||||
seq_len=kv_length_hint,
|
||||
head_dim=self.config.n_embd // self.config.n_head,
|
||||
num_layers=self.config.n_layer
|
||||
)
|
||||
|
||||
# Prefill phase: process the prompt
|
||||
ids = torch.tensor([tokens], dtype=torch.long, device=device)
|
||||
logits = self.forward(ids, kv_cache=kv_cache)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
# Sample first token
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
if temperature > 0:
|
||||
logits = logits / temperature
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
|
||||
else:
|
||||
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
|
||||
# Yield first token
|
||||
token = next_ids.item()
|
||||
yield token
|
||||
|
||||
# Generation loop: process one token at a time with KV cache
|
||||
for _ in range(max_tokens - 1):
|
||||
# Forward pass with only the new token
|
||||
logits = self.forward(next_ids, kv_cache=kv_cache)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
# Sample next token
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
|
|
@ -317,6 +352,6 @@ class GPT(nn.Module):
|
|||
next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
|
||||
else:
|
||||
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
ids = torch.cat((ids, next_ids), dim=1)
|
||||
|
||||
token = next_ids.item()
|
||||
yield token
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ device_batch_size = None # per-device batch size (set to not OOM), None = auto-d
|
|||
auto_batch_size = True # whether to auto-discover optimal batch size
|
||||
batch_size_margin = 0.85 # safety margin for auto-discovered batch size
|
||||
batch_size_cache = True # whether to cache auto-discovered batch size
|
||||
max_seq_len = 2048 # maximum sequence length for fixed padding (enables torch.compile)
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
|
||||
|
|
@ -104,8 +105,8 @@ elif device_batch_size is None:
|
|||
print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}")
|
||||
|
||||
orig_model = model # original, uncompiled model
|
||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
model = torch.compile(model, dynamic=False) # enabled with fixed-length padding
|
||||
engine = Engine(orig_model, tokenizer) # use uncompiled model for engine (variable-length eval)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Task data mixture we'll train on
|
||||
|
|
@ -126,7 +127,7 @@ def sft_data_generator(dataset, batch_size):
|
|||
# prepares a list of tokenized conversations into a batch and yields
|
||||
def collate_and_yield(batch):
|
||||
nrows = len(batch)
|
||||
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
|
||||
ncols = max_seq_len - 1 # fixed length for torch.compile (seq of n creates inputs/targets of n-1)
|
||||
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
|
||||
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
|
||||
for i, (ids, mask) in enumerate(batch):
|
||||
|
|
@ -199,13 +200,13 @@ for step in range(num_iterations):
|
|||
|
||||
# evaluate the validation loss
|
||||
if last_step or step % eval_every == 0:
|
||||
model.eval()
|
||||
orig_model.eval()
|
||||
val_iter = iter(build_val_loader())
|
||||
losses = []
|
||||
for _ in range(eval_steps):
|
||||
val_inputs, val_targets = next(val_iter)
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
loss = model(val_inputs, val_targets)
|
||||
loss = orig_model(val_inputs, val_targets)
|
||||
losses.append(loss)
|
||||
val_loss = torch.stack(losses).mean() # average over eval_steps
|
||||
if ddp:
|
||||
|
|
@ -216,22 +217,24 @@ for step in range(num_iterations):
|
|||
"step": step,
|
||||
"val_loss": val_loss,
|
||||
})
|
||||
orig_model.train()
|
||||
model.train()
|
||||
|
||||
# evlauate accuracy of the multiple choice tasks (which are quick to run)
|
||||
if last_step or (step > 0 and step % eval_metrics_every == 0):
|
||||
model.eval()
|
||||
orig_model.eval()
|
||||
metrics = {}
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
**metrics,
|
||||
})
|
||||
orig_model.train()
|
||||
model.train()
|
||||
|
||||
if last_step:
|
||||
|
|
@ -276,14 +279,14 @@ for step in range(num_iterations):
|
|||
# Save the model at the end of the run
|
||||
if master_process:
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
depth = orig_model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
model_config_kwargs = orig_model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
model.state_dict(),
|
||||
orig_model.state_dict(),
|
||||
None, # note: we don't bother to save the optimizer state
|
||||
{
|
||||
"step": step,
|
||||
|
|
|
|||
57
scripts/check_checkpoint.py
Normal file
57
scripts/check_checkpoint.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
"""Quick script to inspect checkpoint structure."""
|
||||
import torch
|
||||
import sys
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python check_checkpoint.py <path_to_checkpoint>")
|
||||
sys.exit(1)
|
||||
|
||||
checkpoint_path = sys.argv[1]
|
||||
print(f"Loading checkpoint: {checkpoint_path}")
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
||||
print("\n📦 Checkpoint keys:")
|
||||
for key in checkpoint.keys():
|
||||
value = checkpoint[key]
|
||||
if isinstance(value, dict):
|
||||
print(f" - {key}: dict with {len(value)} items")
|
||||
elif isinstance(value, torch.Tensor):
|
||||
print(f" - {key}: Tensor {value.shape}")
|
||||
else:
|
||||
print(f" - {key}: {type(value).__name__}")
|
||||
|
||||
# Check for model weights
|
||||
if 'model' in checkpoint:
|
||||
print("\n✅ Model weights found")
|
||||
model_dict = checkpoint['model']
|
||||
print(f" Number of parameters: {len(model_dict)}")
|
||||
print(f" Sample keys: {list(model_dict.keys())[:5]}")
|
||||
else:
|
||||
print("\n⚠️ No 'model' key found!")
|
||||
print(" Available keys:", list(checkpoint.keys()))
|
||||
|
||||
# Check for config
|
||||
if 'config' in checkpoint:
|
||||
print("\n✅ Config found")
|
||||
config = checkpoint['config']
|
||||
print(f" Type: {type(config)}")
|
||||
if hasattr(config, '__dict__'):
|
||||
print(f" Attributes: {list(vars(config).keys())[:10]}")
|
||||
else:
|
||||
print("\n⚠️ No 'config' key found!")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("RECOMMENDATION:")
|
||||
if 'model' not in checkpoint or 'config' not in checkpoint:
|
||||
print("Your checkpoint is missing required keys.")
|
||||
print("Please check how the model was saved during training.")
|
||||
print("\nExpected checkpoint structure:")
|
||||
print(" checkpoint = {")
|
||||
print(" 'model': model.state_dict(),")
|
||||
print(" 'config': model.config,")
|
||||
print(" 'optimizer': optimizer.state_dict(), # optional")
|
||||
print(" 'step': current_step, # optional")
|
||||
print(" }")
|
||||
else:
|
||||
print("✅ Checkpoint looks good!")
|
||||
66
scripts/quick_check.py
Normal file
66
scripts/quick_check.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""Quick checkpoint structure check."""
|
||||
import torch
|
||||
import sys
|
||||
|
||||
checkpoint_path = "/raid/diana/nanochat_cache/chatsft_checkpoints/d20/model_000650.pt"
|
||||
print(f"Loading: {checkpoint_path}")
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("CHECKPOINT STRUCTURE")
|
||||
print("="*60)
|
||||
|
||||
print(f"\nTop-level keys: {list(checkpoint.keys())}\n")
|
||||
|
||||
for key in checkpoint.keys():
|
||||
value = checkpoint[key]
|
||||
if isinstance(value, dict):
|
||||
print(f"'{key}': dict with {len(value)} items")
|
||||
# Show a few sub-keys if it's a dict
|
||||
sub_keys = list(value.keys())[:3]
|
||||
print(f" Sample keys: {sub_keys}")
|
||||
elif isinstance(value, torch.Tensor):
|
||||
print(f"'{key}': Tensor {value.shape}, dtype={value.dtype}")
|
||||
else:
|
||||
print(f"'{key}': {type(value).__name__} = {value}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("DIAGNOSIS")
|
||||
print("="*60)
|
||||
|
||||
# Check what we need
|
||||
has_model = 'model' in checkpoint
|
||||
has_config = 'config' in checkpoint
|
||||
has_state_dict = 'state_dict' in checkpoint
|
||||
has_model_state_dict = 'model_state_dict' in checkpoint
|
||||
|
||||
print(f"\n✓ Has 'model' key: {has_model}")
|
||||
print(f"✓ Has 'config' key: {has_config}")
|
||||
print(f"✓ Has 'state_dict' key: {has_state_dict}")
|
||||
print(f"✓ Has 'model_state_dict' key: {has_model_state_dict}")
|
||||
|
||||
# Try to infer the structure
|
||||
print("\n" + "="*60)
|
||||
print("RECOMMENDATION")
|
||||
print("="*60)
|
||||
|
||||
if has_model and has_config:
|
||||
print("\n✅ Checkpoint has expected structure!")
|
||||
print(" No changes needed to benchmark_optimizations.py")
|
||||
elif has_state_dict:
|
||||
print("\n⚠️ Checkpoint uses 'state_dict' instead of 'model'")
|
||||
print(" Need to update benchmark to use checkpoint['state_dict']")
|
||||
elif has_model_state_dict:
|
||||
print("\n⚠️ Checkpoint uses 'model_state_dict' instead of 'model'")
|
||||
print(" Need to update benchmark to use checkpoint['model_state_dict']")
|
||||
else:
|
||||
print("\n❌ Checkpoint has unexpected structure!")
|
||||
print(" Available keys:", list(checkpoint.keys()))
|
||||
print(" You may need to check how the model was saved during training")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error loading checkpoint: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
133
speedrun_4gpu.sh
Normal file
133
speedrun_4gpu.sh
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
#!/bin/bash
|
||||
|
||||
# This script is the "Best ChatGPT clone that $100 can buy",
|
||||
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.
|
||||
|
||||
# 1) Example launch (simplest):
|
||||
# bash speedrun.sh
|
||||
# 2) Example launch in a screen session (because the run takes ~4 hours):
|
||||
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
# 3) Example launch with wandb logging, but see below for setting up wandb first:
|
||||
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
|
||||
# Default intermediate artifacts directory is in ~/.cache/nanochat
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="/raid/diana/nanochat_cache"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Python venv setup with uv
|
||||
|
||||
# install uv (if not already installed)
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# create a .venv local virtual environment (if it doesn't exist)
|
||||
[ -d ".venv" ] || uv venv
|
||||
# install the repo dependencies
|
||||
uv sync
|
||||
# activate venv so that `python` uses the project's venv instead of system python
|
||||
source .venv/bin/activate
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# wandb setup
|
||||
# If you wish to use wandb for logging (it's nice!, recommended).
|
||||
# 1) Make sure to first log in to wandb, e.g. run:
|
||||
# `wandb login`
|
||||
# 2) Set the WANDB_RUN environment variable when running this script, e.g.:
|
||||
# `WANDB_RUN=d26 bash speedrun.sh`
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
# by default use "dummy" : it's handled as a special case, skips logging to wandb
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# During the course of the run, we will be writing markdown reports to the report/
|
||||
# directory in the base dir. This command clears it out and writes a header section
|
||||
# with a bunch of system info and a timestamp that marks the start of the run.
|
||||
python -m nanochat.report reset
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tokenizer
|
||||
|
||||
# Install Rust / Cargo
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
|
||||
# Build the rustbpe Tokenizer
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
|
||||
# Download the first ~2B characters of pretraining dataset
|
||||
# look at dev/repackage_data_reference.py for details on how this data was prepared
|
||||
# each data shard is ~250M chars
|
||||
# so we download 2e9 / 250e6 = 8 data shards at this point
|
||||
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
|
||||
python -m nanochat.dataset -n 8
|
||||
# Immediately also kick off downloading more shards in the background while tokenizer trains
|
||||
# See comment below for why 240 is the right number here
|
||||
python -m nanochat.dataset -n 240 &
|
||||
DATASET_DOWNLOAD_PID=$!
|
||||
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
|
||||
python -m scripts.tok_train --max_chars=2000000000
|
||||
# evaluate the tokenizer (report compression ratio etc.)
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Base model (pretraining)
|
||||
|
||||
# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB)
|
||||
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
|
||||
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
|
||||
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
|
||||
unzip -q eval_bundle.zip
|
||||
rm eval_bundle.zip
|
||||
mv eval_bundle $NANOCHAT_BASE_DIR
|
||||
fi
|
||||
|
||||
# The d20 model is 561M parameters.
|
||||
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
|
||||
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
|
||||
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
|
||||
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
|
||||
# (The total number of shards available in the entire dataset is 1822.)
|
||||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# pretrain the d20 model
|
||||
torchrun --standalone --nproc_per_node=4 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN --total_batch_size=262144
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
torchrun --standalone --nproc_per_node=4 -m scripts.base_loss
|
||||
# evaluate the model on CORE tasks
|
||||
torchrun --standalone --nproc_per_node=4 -m scripts.base_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
||||
|
||||
# run midtraining and eval the model
|
||||
torchrun --standalone --nproc_per_node=4 -m scripts.mid_train -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i mid
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
|
||||
|
||||
# train sft and re-eval right away (should see a small bump)
|
||||
torchrun --standalone --nproc_per_node=4 -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i sft
|
||||
|
||||
# chat with the model over CLI! Leave out the -p to chat interactively
|
||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
||||
# even better, chat with your model over a pretty WebUI ChatGPT style
|
||||
# python -m scripts.chat_web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Reinforcement Learning. Optional, and currently only on GSM8K
|
||||
# (optional)
|
||||
|
||||
# run reinforcement learning
|
||||
# torchrun --standalone --nproc_per_node=4 -m scripts.chat_rl -- --run=$WANDB_RUN
|
||||
# eval the RL model only on GSM8K
|
||||
# torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i rl -a GSM8K
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generate the full report by putting together all the sections
|
||||
# report.md is the output and will be copied to current directory for convenience
|
||||
python -m nanochat.report generated
|
||||
Loading…
Reference in New Issue
Block a user