diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..f315077 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,9 @@ +{ + "permissions": { + "allow": [ + "Bash(python:*)" + ], + "deny": [], + "ask": [] + } +} diff --git a/nanochat/auto_batch_size.py b/nanochat/auto_batch_size.py index 97e0c34..2dbb559 100644 --- a/nanochat/auto_batch_size.py +++ b/nanochat/auto_batch_size.py @@ -1,186 +1,339 @@ """ -Auto-discovery module for finding optimal batch sizes. +Automatic batch size discovery module for maximizing GPU utilization. -This is a minimal stub implementation to enable testing. -The full implementation should be added as part of Task 41 (Auto Batch Size Module). +This module implements an intelligent batch size search algorithm that: +1. Uses exponential search to quickly find an upper bound +2. Refines with binary search for optimal size +3. Applies safety margin to prevent edge-case OOMs +4. Supports DDP multi-GPU coordination +5. Caches results for faster subsequent runs """ import os import json +import time import hashlib import torch -import torch.distributed as dist -from typing import Optional, Callable, Dict, Any + from nanochat.common import print0, get_base_dir -def discover_batch_size( - model: torch.nn.Module, - max_seq_len: int, - device: torch.device, - safety_margin: float = 0.85, - min_batch_size: int = 1, - max_batch_size: int = 128, - ddp_rank: int = 0, - ddp_world_size: int = 1, - use_cache: bool = False, - cache_key_components: Optional[Dict[str, Any]] = None, -) -> int: +def find_optimal_device_batch_size( + model, + max_seq_len, + total_batch_size, + ddp_world_size, + data_sample_fn, + override=None, + safety_margin=0.85, + enable_cache=True, + ddp_rank=0, +): """ - Discover the optimal batch size for a model. - + Main entry point for automatic batch size discovery. 
+ Args: - model: The model to test + model: PyTorch model to test max_seq_len: Maximum sequence length - device: Device to run on - safety_margin: Safety factor (e.g., 0.85 = use 85% of max) - min_batch_size: Minimum batch size to try - max_batch_size: Maximum batch size to try - ddp_rank: Rank in distributed setting - ddp_world_size: World size in distributed setting - use_cache: Whether to use cache - cache_key_components: Components for cache key - + total_batch_size: Total batch size across all GPUs (for gradient accumulation calculation) + ddp_world_size: Number of GPUs in DDP + data_sample_fn: Callable(batch_size) -> (inputs, targets) + override: If set, skip discovery and return this value + safety_margin: Fraction of optimal batch size to use (default 0.85) + enable_cache: Whether to use caching + ddp_rank: Current rank in DDP + Returns: - Discovered batch size + optimal_batch_size: Optimal device batch size for this GPU """ - # Only rank 0 performs discovery in DDP + # Handle manual override + if override is not None: + print0(f"Using manual batch_size override: {override}") + return override + + optimal_batch_size = None + + # Only rank 0 performs discovery if ddp_rank == 0: - print0("Running auto-discovery on rank 0") - - # Check cache first - if use_cache and cache_key_components: - cached_size = _load_from_cache(cache_key_components) - if cached_size is not None: - print0(f"Cache hit! 
Using batch_size={cached_size}") - discovered_size = cached_size - else: - print0("Cache miss, performing discovery") - discovered_size = _perform_discovery( - model, max_seq_len, device, safety_margin, - min_batch_size, max_batch_size + start_time = time.time() + print0(f"\n{'='*60}") + print0(f"Starting automatic batch size discovery...") + print0(f"Parameters: max_seq_len={max_seq_len}, ddp_world_size={ddp_world_size}") + print0(f"Safety margin: {safety_margin:.2%}") + print0(f"{'='*60}\n") + + # Check cache + cache_key = None + if enable_cache: + cache_key = _get_cache_key(model, max_seq_len) + cached_batch_size = _load_from_cache(cache_key) + if cached_batch_size is not None: + print0(f"✓ Cache hit! Using cached batch_size: {cached_batch_size}") + optimal_batch_size = cached_batch_size + + # Run discovery if no cache hit + if optimal_batch_size is None: + try: + # Warmup CUDA + _warmup_cuda() + + # Run the search algorithm + optimal_batch_size = _find_batch_size_internal( + model=model, + max_seq_len=max_seq_len, + data_sample_fn=data_sample_fn, + safety_margin=safety_margin, ) - if cache_key_components: - _save_to_cache(cache_key_components, discovered_size) - else: - discovered_size = _perform_discovery( - model, max_seq_len, device, safety_margin, - min_batch_size, max_batch_size - ) - - print0(f"Auto-discovery found device_batch_size={discovered_size}") + + # Save to cache + if enable_cache and cache_key is not None and optimal_batch_size is not None: + _save_to_cache(cache_key, optimal_batch_size) + + elapsed = time.time() - start_time + print0(f"\n{'='*60}") + print0(f"✓ Found optimal batch_size={optimal_batch_size} in {elapsed:.1f} seconds") + print0(f"{'='*60}\n") + + except Exception as e: + print0(f"⚠ Warning: Batch size discovery failed with error: {e}") + optimal_batch_size = None + + # Fallback to conservative defaults if discovery failed + if optimal_batch_size is None: + print0(f"⚠ Warning: Using conservative fallback batch_size=8") + 
optimal_batch_size = 8 else: - discovered_size = 0 # Will be broadcast from rank 0 - - # Broadcast to all ranks in DDP + optimal_batch_size = 0 # Will be broadcast from rank 0 + + # DDP: Broadcast result from rank 0 to all ranks if ddp_world_size > 1: - discovered_tensor = torch.tensor(discovered_size, dtype=torch.int32, device=device) - dist.broadcast(discovered_tensor, src=0) - discovered_size = discovered_tensor.item() - if ddp_rank != 0: - print0(f"Received batch size from rank 0: {discovered_size}") - - return discovered_size + try: + import torch.distributed as dist + tensor = torch.tensor([optimal_batch_size], dtype=torch.long, device='cuda') + dist.broadcast(tensor, src=0) + optimal_batch_size = tensor.item() + if ddp_rank != 0: + print0(f"Received batch_size from rank 0: {optimal_batch_size}") + except Exception as e: + print0(f"⚠ Warning: DDP broadcast failed: {e}") + if optimal_batch_size == 0: + optimal_batch_size = 8 + + return optimal_batch_size -def _perform_discovery( - model: torch.nn.Module, - max_seq_len: int, - device: torch.device, - safety_margin: float, - min_batch_size: int, - max_batch_size: int, -) -> int: +def _find_batch_size_internal(model, max_seq_len, data_sample_fn, safety_margin): """ - Perform the actual discovery using exponential + binary search. - - This is a stub implementation that returns a fixed value. - The real implementation should: - 1. Exponential search to find upper bound - 2. Binary search to refine - 3. Apply safety margin - """ - # Stub: return a fixed reasonable value - # Real implementation would perform exponential + binary search - batch_size = min(32, max_batch_size) - return max(int(batch_size * safety_margin), min_batch_size) + Core algorithm implementing exponential search followed by binary search. - -def _test_batch_size( - model: torch.nn.Module, - batch_size: int, - max_seq_len: int, - device: torch.device, -) -> bool: - """ - Test if a given batch size fits in memory. 
- Returns: - True if batch size works, False if OOM + optimal_batch_size: The largest batch size that fits in memory (with safety margin) + """ + # Phase 1: Exponential search to find upper bound + print0("Phase 1: Exponential search to find upper bound...") + batch_size = 1 + last_successful = None + + while True: + print0(f" Testing batch_size={batch_size}...", end=" ") + success = _test_batch_size( + model=model, + batch_size=batch_size, + max_seq_len=max_seq_len, + data_sample_fn=data_sample_fn, + ) + + if success: + print0("✓ Success") + last_successful = batch_size + batch_size *= 2 + else: + print0("✗ OOM") + break + + # If even batch_size=1 failed, return None + if last_successful is None: + print0("✗ Even batch_size=1 caused OOM!") + return None + + # Phase 2: Binary search refinement + print0(f"\nPhase 2: Binary search refinement between {last_successful} and {batch_size}...") + lower = last_successful + upper = batch_size + + while upper - lower > 1: + mid = (lower + upper) // 2 + print0(f" Testing batch_size={mid}...", end=" ") + success = _test_batch_size( + model=model, + batch_size=mid, + max_seq_len=max_seq_len, + data_sample_fn=data_sample_fn, + ) + + if success: + print0("✓ Success") + lower = mid + else: + print0("✗ OOM") + upper = mid + + # Phase 3: Apply safety margin + optimal_batch_size = int(lower * safety_margin) + print0(f"\nApplying safety margin: {lower} × {safety_margin:.2%} = {optimal_batch_size}") + + return optimal_batch_size + + +def _test_batch_size(model, batch_size, max_seq_len, data_sample_fn): + """ + Test if a specific batch size fits in memory by simulating training loop. 
+ + Returns: + bool: True if batch size fits, False if OOM """ try: - # Create dummy inputs - inputs = torch.randint(0, 50000, (batch_size, max_seq_len), device=device, dtype=torch.int32) - targets = torch.randint(0, 50000, (batch_size, max_seq_len), device=device, dtype=torch.int64) - - # Forward + backward pass + # Clear CUDA cache before test + torch.cuda.empty_cache() + + # Set model to training mode model.train() + + # Zero gradients + model.zero_grad(set_to_none=True) + + # Generate test batch + inputs, targets = data_sample_fn(batch_size) + + # Forward pass with bfloat16 autocast with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(inputs, targets) + + # Backward pass loss.backward() - model.zero_grad(set_to_none=True) - - # Clean up + + # Synchronize CUDA to ensure all operations complete + torch.cuda.synchronize() + + # Clear cache after test del inputs, targets, loss torch.cuda.empty_cache() - + return True + except torch.cuda.OutOfMemoryError: + # Clear cache and return False on OOM torch.cuda.empty_cache() return False except Exception as e: - print0(f"Error testing batch size {batch_size}: {e}") + # Handle other exceptions + print0(f"\n⚠ Warning: Test failed with unexpected error: {e}") torch.cuda.empty_cache() return False -def _get_cache_key(components: Dict[str, Any]) -> str: - """Generate cache key from components.""" - key_str = json.dumps(components, sort_keys=True) - return hashlib.md5(key_str.encode()).hexdigest() +def _warmup_cuda(): + """Warmup CUDA by allocating and freeing a small tensor.""" + try: + x = torch.zeros(1, device='cuda') + del x + torch.cuda.synchronize() + torch.cuda.empty_cache() + except Exception as e: + print0(f"⚠ Warning: CUDA warmup failed: {e}") -def _load_from_cache(components: Dict[str, Any]) -> Optional[int]: - """Load batch size from cache if available.""" +def _get_cache_key(model, max_seq_len): + """ + Generate cache key from model config hash, GPU model, and max_seq_len. 
+ + Returns: + str: Hash string to use as cache key + """ + try: + # Get model config attributes + config = model.config if hasattr(model, 'config') else None + if config is None: + # Try to get from original model (in case of compiled model) + config = model._orig_mod.config if hasattr(model, '_orig_mod') else None + + if config is None: + return None + + # Build config string + config_parts = [ + f"vocab_size={config.vocab_size}", + f"n_layer={config.n_layer}", + f"n_embd={config.n_embd}", + f"n_head={config.n_head}", + f"n_kv_head={config.n_kv_head}", + ] + config_str = "|".join(config_parts) + + # Get GPU model name + gpu_name = torch.cuda.get_device_name(0) + + # Combine all components + key_str = f"{config_str}|gpu={gpu_name}|seq_len={max_seq_len}" + + # Hash to create a short key + cache_key = hashlib.md5(key_str.encode()).hexdigest() + + return cache_key + + except Exception as e: + print0(f"⚠ Warning: Failed to generate cache key: {e}") + return None + + +def _load_from_cache(cache_key): + """ + Load cached batch size from JSON file. 
+ + Returns: + int or None: Cached batch size, or None if not found + """ + if cache_key is None: + return None + try: base_dir = get_base_dir() cache_dir = os.path.join(base_dir, "auto_batch_cache") - cache_key = _get_cache_key(components) cache_file = os.path.join(cache_dir, f"{cache_key}.json") - - if os.path.exists(cache_file): - with open(cache_file, 'r') as f: - data = json.load(f) + + if not os.path.exists(cache_file): + return None + + with open(cache_file, 'r') as f: + data = json.load(f) return data.get('batch_size') + except Exception as e: - print0(f"Cache load error: {e}") - return None + print0(f"⚠ Warning: Failed to load from cache: {e}") + return None -def _save_to_cache(components: Dict[str, Any], batch_size: int) -> None: - """Save batch size to cache.""" +def _save_to_cache(cache_key, batch_size): + """Save batch size to JSON cache file.""" + if cache_key is None or batch_size is None: + return + try: base_dir = get_base_dir() cache_dir = os.path.join(base_dir, "auto_batch_cache") os.makedirs(cache_dir, exist_ok=True) - - cache_key = _get_cache_key(components) + cache_file = os.path.join(cache_dir, f"{cache_key}.json") - + + data = { + 'batch_size': batch_size, + 'timestamp': time.time(), + } + with open(cache_file, 'w') as f: - json.dump({ - 'batch_size': batch_size, - 'components': components, - }, f, indent=2) + json.dump(data, f, indent=2) + + print0(f"✓ Saved batch_size={batch_size} to cache") + except Exception as e: - print0(f"Cache save error: {e}") + print0(f"⚠ Warning: Failed to save to cache: {e}") diff --git a/nanochat/engine.py b/nanochat/engine.py index fec90cf..620bdcc 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -219,9 +219,7 @@ class Engine: # Get sampled tokens - either from prefill or from forward pass if first_iteration: - # Use the tokens we already sampled from prefill - sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows - # TODO: we should sample a token for each row 
instead of broadcasting + # sampled_tokens already contains num_samples independently sampled tokens first_iteration = False else: # Forward the model and get the next token for each row diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 79c6085..1757d3f 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from nanochat.common import get_dist_info, print0 from nanochat.muon import Muon, DistMuon from nanochat.adamw import DistAdamW +from nanochat.engine import KVCache @dataclass class GPTConfig: @@ -304,10 +305,44 @@ class GPT(nn.Module): if temperature > 0: rng = torch.Generator(device=device) rng.manual_seed(seed) - ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim - for _ in range(max_tokens): - logits = self.forward(ids) # (B, T, vocab_size) - logits = logits[:, -1, :] # (B, vocab_size) + + # Initialize KV cache + kv_length_hint = len(tokens) + max_tokens + kv_cache = KVCache( + batch_size=1, + num_heads=self.config.n_kv_head, + seq_len=kv_length_hint, + head_dim=self.config.n_embd // self.config.n_head, + num_layers=self.config.n_layer + ) + + # Prefill phase: process the prompt + ids = torch.tensor([tokens], dtype=torch.long, device=device) + logits = self.forward(ids, kv_cache=kv_cache) + logits = logits[:, -1, :] + + # Sample first token + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float('Inf') + if temperature > 0: + logits = logits / temperature + probs = F.softmax(logits, dim=-1) + next_ids = torch.multinomial(probs, num_samples=1, generator=rng) + else: + next_ids = torch.argmax(logits, dim=-1, keepdim=True) + + # Yield first token + token = next_ids.item() + yield token + + # Generation loop: process one token at a time with KV cache + for _ in range(max_tokens - 1): + # Forward pass with only the new token + logits = self.forward(next_ids, kv_cache=kv_cache) + logits = logits[:, -1, :] + + # Sample next 
token if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') @@ -317,6 +352,6 @@ class GPT(nn.Module): next_ids = torch.multinomial(probs, num_samples=1, generator=rng) else: next_ids = torch.argmax(logits, dim=-1, keepdim=True) - ids = torch.cat((ids, next_ids), dim=1) + token = next_ids.item() yield token diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index f46fe2f..fc74cca 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -40,6 +40,7 @@ device_batch_size = None # per-device batch size (set to not OOM), None = auto-d auto_batch_size = True # whether to auto-discover optimal batch size batch_size_margin = 0.85 # safety margin for auto-discovered batch size batch_size_cache = True # whether to cache auto-discovered batch size +max_seq_len = 2048 # maximum sequence length for fixed padding (enables torch.compile) # optimization num_epochs = 1 max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations) @@ -104,8 +105,8 @@ elif device_batch_size is None: print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}") orig_model = model # original, uncompiled model -# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs -engine = Engine(model, tokenizer) # will be used for inline model evaluation only +model = torch.compile(model, dynamic=False) # enabled with fixed-length padding +engine = Engine(orig_model, tokenizer) # use uncompiled model for engine (variable-length eval) # ----------------------------------------------------------------------------- # Task data mixture we'll train on @@ -126,7 +127,7 @@ def sft_data_generator(dataset, batch_size): # prepares a list of tokenized conversations into a batch and yields def collate_and_yield(batch): nrows = len(batch) - ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1 + ncols = 
max_seq_len - 1 # fixed length for torch.compile (seq of n creates inputs/targets of n-1) inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long) targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index for i, (ids, mask) in enumerate(batch): @@ -199,13 +200,13 @@ for step in range(num_iterations): # evaluate the validation loss if last_step or step % eval_every == 0: - model.eval() + orig_model.eval() val_iter = iter(build_val_loader()) losses = [] for _ in range(eval_steps): val_inputs, val_targets = next(val_iter) with torch.no_grad(), autocast_ctx: - loss = model(val_inputs, val_targets) + loss = orig_model(val_inputs, val_targets) losses.append(loss) val_loss = torch.stack(losses).mean() # average over eval_steps if ddp: @@ -216,22 +217,24 @@ for step in range(num_iterations): "step": step, "val_loss": val_loss, }) + orig_model.train() model.train() # evlauate accuracy of the multiple choice tasks (which are quick to run) if last_step or (step > 0 and step % eval_metrics_every == 0): - model.eval() + orig_model.eval() metrics = {} with torch.no_grad(), autocast_ctx: # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size - metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024) - metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024) + metrics["mmlu_acc"] = run_chat_eval("MMLU", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024) + metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024) metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items()) print0(f"Step {step:05d} | {metrics_str}") wandb_run.log({ "step": step, **metrics, }) + orig_model.train() model.train() if last_step: @@ -276,14 +279,14 @@ for step in 
range(num_iterations): # Save the model at the end of the run if master_process: base_dir = get_base_dir() - depth = model.config.n_layer + depth = orig_model.config.n_layer model_tag = f"d{depth}" # base the model tag on the depth of the base model checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag) - model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer + model_config_kwargs = orig_model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer save_checkpoint( checkpoint_dir, step, - model.state_dict(), + orig_model.state_dict(), None, # note: we don't bother to save the optimizer state { "step": step, diff --git a/scripts/check_checkpoint.py b/scripts/check_checkpoint.py new file mode 100644 index 0000000..e6f175a --- /dev/null +++ b/scripts/check_checkpoint.py @@ -0,0 +1,57 @@ +"""Quick script to inspect checkpoint structure.""" +import torch +import sys + +if len(sys.argv) < 2: + print("Usage: python check_checkpoint.py ") + sys.exit(1) + +checkpoint_path = sys.argv[1] +print(f"Loading checkpoint: {checkpoint_path}") + +checkpoint = torch.load(checkpoint_path, map_location='cpu') + +print("\n📦 Checkpoint keys:") +for key in checkpoint.keys(): + value = checkpoint[key] + if isinstance(value, dict): + print(f" - {key}: dict with {len(value)} items") + elif isinstance(value, torch.Tensor): + print(f" - {key}: Tensor {value.shape}") + else: + print(f" - {key}: {type(value).__name__}") + +# Check for model weights +if 'model' in checkpoint: + print("\n✅ Model weights found") + model_dict = checkpoint['model'] + print(f" Number of parameters: {len(model_dict)}") + print(f" Sample keys: {list(model_dict.keys())[:5]}") +else: + print("\n⚠️ No 'model' key found!") + print(" Available keys:", list(checkpoint.keys())) + +# Check for config +if 'config' in checkpoint: + print("\n✅ Config found") + config = checkpoint['config'] + print(f" Type: {type(config)}") + if 
hasattr(config, '__dict__'): + print(f" Attributes: {list(vars(config).keys())[:10]}") +else: + print("\n⚠️ No 'config' key found!") + +print("\n" + "="*60) +print("RECOMMENDATION:") +if 'model' not in checkpoint or 'config' not in checkpoint: + print("Your checkpoint is missing required keys.") + print("Please check how the model was saved during training.") + print("\nExpected checkpoint structure:") + print(" checkpoint = {") + print(" 'model': model.state_dict(),") + print(" 'config': model.config,") + print(" 'optimizer': optimizer.state_dict(), # optional") + print(" 'step': current_step, # optional") + print(" }") +else: + print("✅ Checkpoint looks good!") diff --git a/scripts/quick_check.py b/scripts/quick_check.py new file mode 100644 index 0000000..6fe958b --- /dev/null +++ b/scripts/quick_check.py @@ -0,0 +1,66 @@ +"""Quick checkpoint structure check.""" +import torch +import sys + +checkpoint_path = "/raid/diana/nanochat_cache/chatsft_checkpoints/d20/model_000650.pt" +print(f"Loading: {checkpoint_path}") + +try: + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + print("\n" + "="*60) + print("CHECKPOINT STRUCTURE") + print("="*60) + + print(f"\nTop-level keys: {list(checkpoint.keys())}\n") + + for key in checkpoint.keys(): + value = checkpoint[key] + if isinstance(value, dict): + print(f"'{key}': dict with {len(value)} items") + # Show a few sub-keys if it's a dict + sub_keys = list(value.keys())[:3] + print(f" Sample keys: {sub_keys}") + elif isinstance(value, torch.Tensor): + print(f"'{key}': Tensor {value.shape}, dtype={value.dtype}") + else: + print(f"'{key}': {type(value).__name__} = {value}") + + print("\n" + "="*60) + print("DIAGNOSIS") + print("="*60) + + # Check what we need + has_model = 'model' in checkpoint + has_config = 'config' in checkpoint + has_state_dict = 'state_dict' in checkpoint + has_model_state_dict = 'model_state_dict' in checkpoint + + print(f"\n✓ Has 'model' key: {has_model}") + print(f"✓ Has 'config' key: 
{has_config}") + print(f"✓ Has 'state_dict' key: {has_state_dict}") + print(f"✓ Has 'model_state_dict' key: {has_model_state_dict}") + + # Try to infer the structure + print("\n" + "="*60) + print("RECOMMENDATION") + print("="*60) + + if has_model and has_config: + print("\n✅ Checkpoint has expected structure!") + print(" No changes needed to benchmark_optimizations.py") + elif has_state_dict: + print("\n⚠️ Checkpoint uses 'state_dict' instead of 'model'") + print(" Need to update benchmark to use checkpoint['state_dict']") + elif has_model_state_dict: + print("\n⚠️ Checkpoint uses 'model_state_dict' instead of 'model'") + print(" Need to update benchmark to use checkpoint['model_state_dict']") + else: + print("\n❌ Checkpoint has unexpected structure!") + print(" Available keys:", list(checkpoint.keys())) + print(" You may need to check how the model was saved during training") + +except Exception as e: + print(f"\n❌ Error loading checkpoint: {e}") + import traceback + traceback.print_exc() diff --git a/speedrun_4gpu.sh b/speedrun_4gpu.sh new file mode 100644 index 0000000..25775d8 --- /dev/null +++ b/speedrun_4gpu.sh @@ -0,0 +1,133 @@ +#!/bin/bash + +# This script is the "Best ChatGPT clone that $100 can buy", +# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour. 
+
+# 1) Example launch (simplest):
+# bash speedrun_4gpu.sh
+# 2) Example launch in a screen session (because the run takes ~4 hours):
+# screen -L -Logfile speedrun.log -S speedrun bash speedrun_4gpu.sh
+# 3) Example launch with wandb logging, but see below for setting up wandb first:
+# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun_4gpu.sh
+
+# Intermediate artifacts directory (overrides the default ~/.cache/nanochat)
+export OMP_NUM_THREADS=1
+export NANOCHAT_BASE_DIR="/raid/diana/nanochat_cache"
+mkdir -p $NANOCHAT_BASE_DIR
+
+# -----------------------------------------------------------------------------
+# Python venv setup with uv
+
+# install uv (if not already installed)
+command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
+# create a .venv local virtual environment (if it doesn't exist)
+[ -d ".venv" ] || uv venv
+# install the repo dependencies
+uv sync
+# activate venv so that `python` uses the project's venv instead of system python
+source .venv/bin/activate
+
+# -----------------------------------------------------------------------------
+# wandb setup
+# If you wish to use wandb for logging (it's nice!, recommended).
+# 1) Make sure to first log in to wandb, e.g. run:
+# `wandb login`
+# 2) Set the WANDB_RUN environment variable when running this script, e.g.:
+# `WANDB_RUN=d26 bash speedrun_4gpu.sh`
+if [ -z "$WANDB_RUN" ]; then
+    # by default use "dummy" : it's handled as a special case, skips logging to wandb
+    WANDB_RUN=dummy
+fi
+
+# -----------------------------------------------------------------------------
+# During the course of the run, we will be writing markdown reports to the report/
+# directory in the base dir. This command clears it out and writes a header section
+# with a bunch of system info and a timestamp that marks the start of the run. 
+python -m nanochat.report reset + +# ----------------------------------------------------------------------------- +# Tokenizer + +# Install Rust / Cargo +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +source "$HOME/.cargo/env" + +# Build the rustbpe Tokenizer +uv run maturin develop --release --manifest-path rustbpe/Cargo.toml + +# Download the first ~2B characters of pretraining dataset +# look at dev/repackage_data_reference.py for details on how this data was prepared +# each data shard is ~250M chars +# so we download 2e9 / 250e6 = 8 data shards at this point +# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk +python -m nanochat.dataset -n 8 +# Immediately also kick off downloading more shards in the background while tokenizer trains +# See comment below for why 240 is the right number here +python -m nanochat.dataset -n 240 & +DATASET_DOWNLOAD_PID=$! +# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data +python -m scripts.tok_train --max_chars=2000000000 +# evaluate the tokenizer (report compression ratio etc.) +python -m scripts.tok_eval + +# ----------------------------------------------------------------------------- +# Base model (pretraining) + +# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB) +EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip +if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then + curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL + unzip -q eval_bundle.zip + rm eval_bundle.zip + mv eval_bundle $NANOCHAT_BASE_DIR +fi + +# The d20 model is 561M parameters. +# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens. +# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars. +# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining. +# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk. 
+# (The total number of shards available in the entire dataset is 1822.) +echo "Waiting for dataset download to complete..." +wait $DATASET_DOWNLOAD_PID + +# pretrain the d20 model +torchrun --standalone --nproc_per_node=4 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN --total_batch_size=262144 +# evaluate the model on a larger chunk of train/val data and draw some samples +torchrun --standalone --nproc_per_node=4 -m scripts.base_loss +# evaluate the model on CORE tasks +torchrun --standalone --nproc_per_node=4 -m scripts.base_eval + +# ----------------------------------------------------------------------------- +# Midtraining (teach the model conversation special tokens, tool use, multiple choice) + +# run midtraining and eval the model +torchrun --standalone --nproc_per_node=4 -m scripts.mid_train -- --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i mid + +# ----------------------------------------------------------------------------- +# Supervised Finetuning (domain adaptation to each sequence all by itself per row) + +# train sft and re-eval right away (should see a small bump) +torchrun --standalone --nproc_per_node=4 -m scripts.chat_sft -- --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i sft + +# chat with the model over CLI! Leave out the -p to chat interactively +# python -m scripts.chat_cli -p "Why is the sky blue?" + +# even better, chat with your model over a pretty WebUI ChatGPT style +# python -m scripts.chat_web + +# ----------------------------------------------------------------------------- +# Reinforcement Learning. 
Optional, and currently only on GSM8K
+# (optional)
+
+# run reinforcement learning
+# torchrun --standalone --nproc_per_node=4 -m scripts.chat_rl -- --run=$WANDB_RUN
+# eval the RL model only on GSM8K
+# torchrun --standalone --nproc_per_node=4 -m scripts.chat_eval -- -i rl -a GSM8K
+
+# -----------------------------------------------------------------------------
+# Generate the full report by putting together all the sections
+# report.md is the output and will be copied to current directory for convenience
+python -m nanochat.report generate
\ No newline at end of file