diff --git a/dev/continue_training.sh b/dev/continue_training.sh
new file mode 100755
index 0000000..1f1fca4
--- /dev/null
+++ b/dev/continue_training.sh
@@ -0,0 +1,230 @@
+#!/bin/bash
+# Smart training continuation script
+# Checks for existing checkpoints and continues from where you left off
+
+set -e
+
+echo "=================================="
+echo "nanochat Training Continuation"
+echo "=================================="
+echo "Started: $(date)"
+echo ""
+
+# Activate virtual environment
+source .venv/bin/activate
+
+# Memory-based configuration (same as runmac_overnight.sh)
+if [ -z "$MEMORY_SIZE" ]; then
+    if [[ "$OSTYPE" == "darwin"* ]]; then
+        MEMORY_SIZE=$(sysctl hw.memsize | awk '{print int($2/1024/1024/1024)}')
+        echo "Auto-detected memory: ${MEMORY_SIZE}GB"
+    else
+        MEMORY_SIZE=16
+    fi
+fi
+
+# Calculate optimal batch sizes
+if [ $MEMORY_SIZE -ge 128 ]; then
+    DEVICE_BATCH_SIZE=16
+    TOTAL_BATCH_SIZE=16384
+    EVAL_TOKENS=16384
+    SPLIT_TOKENS=16384
+elif [ $MEMORY_SIZE -ge 64 ]; then
+    DEVICE_BATCH_SIZE=8
+    TOTAL_BATCH_SIZE=8192
+    EVAL_TOKENS=8192
+    SPLIT_TOKENS=8192
+elif [ $MEMORY_SIZE -ge 32 ]; then
+    DEVICE_BATCH_SIZE=4
+    TOTAL_BATCH_SIZE=4096
+    EVAL_TOKENS=4096
+    SPLIT_TOKENS=4096
+else
+    DEVICE_BATCH_SIZE=1
+    TOTAL_BATCH_SIZE=1024
+    EVAL_TOKENS=2048
+    SPLIT_TOKENS=2048
+fi
+
+# Allow manual overrides
+DEVICE_BATCH_SIZE=${DEVICE_BATCH_SIZE:-16}
+MID_ITERATIONS=${MID_ITERATIONS:-150}
+SFT_ITERATIONS=${SFT_ITERATIONS:-150}
+
+echo "Configuration:"
+echo "  Memory: ${MEMORY_SIZE}GB"
+echo "  Device batch size: $DEVICE_BATCH_SIZE"
+echo "  Total batch size: $TOTAL_BATCH_SIZE"
+echo ""
+
+# Check what exists
+CACHE_DIR="$HOME/.cache/nanochat"
+BASE_DIR="$CACHE_DIR/base_checkpoints"
+MID_DIR="$CACHE_DIR/mid_checkpoints"
+SFT_DIR="$CACHE_DIR/sft_checkpoints"
+
+echo "Checking existing checkpoints..."
+echo ""
+
+# Function to find latest checkpoint and extract tag
+find_latest_checkpoint() {
+    local dir=$1
+    if [ ! -d "$dir" ]; then
+        echo "none"
+        return
+    fi
+    # Find the latest model tag directory
+    local latest_tag=$(ls -1 "$dir" 2>/dev/null | grep -E "^d[0-9]+$" | sort -V | tail -1)
+    if [ -z "$latest_tag" ]; then
+        echo "none"
+        return
+    fi
+    # Find the latest step in that tag
+    local latest_step=$(ls -1 "$dir/$latest_tag" 2>/dev/null | grep -E "^model_[0-9]+\.pt$" | sed 's/model_//;s/\.pt//' | sort -n | tail -1)
+    if [ -z "$latest_step" ]; then
+        echo "none"
+        return
+    fi
+    echo "$latest_tag/step_$latest_step"
+}
+
+BASE_CHECKPOINT=$(find_latest_checkpoint "$BASE_DIR")
+MID_CHECKPOINT=$(find_latest_checkpoint "$MID_DIR")
+SFT_CHECKPOINT=$(find_latest_checkpoint "$SFT_DIR")
+
+# Extract base model tag (e.g., "d8" from "d8/step_001000")
+BASE_TAG=$(echo $BASE_CHECKPOINT | cut -d'/' -f1)
+MID_TAG=$(echo $MID_CHECKPOINT | cut -d'/' -f1)
+SFT_TAG=$(echo $SFT_CHECKPOINT | cut -d'/' -f1)
+
+echo "Status:"
+if [ "$BASE_CHECKPOINT" != "none" ]; then
+    echo "  ✓ Base model: $BASE_CHECKPOINT"
+else
+    echo "  ✗ Base model: Not found"
+fi
+
+if [ "$MID_CHECKPOINT" != "none" ]; then
+    echo "  ✓ Midtraining: $MID_CHECKPOINT"
+else
+    echo "  ✗ Midtraining: Not found"
+fi
+
+if [ "$SFT_CHECKPOINT" != "none" ]; then
+    echo "  ✓ SFT: $SFT_CHECKPOINT"
+else
+    echo "  ✗ SFT: Not found"
+fi
+echo ""
+
+# Determine what to do
+if [ "$SFT_CHECKPOINT" != "none" ]; then
+    echo "🎉 All training stages complete!"
+    echo ""
+    echo "Your chatbot is ready. Chat with:"
+    echo "  python -m scripts.chat_cli -i sft"
+    echo ""
+    echo "Or start web UI:"
+    echo "  python -m scripts.chat_web -i sft"
+    echo ""
+    exit 0
+fi
+
+if [ "$BASE_CHECKPOINT" = "none" ]; then
+    echo "❌ No base model found. Please run base training first:"
+    echo "  bash dev/runmac_overnight.sh"
+    echo ""
+    exit 1
+fi
+
+# Download identity conversations if needed
+if [ ! -f "$CACHE_DIR/identity_conversations.jsonl" ]; then
+    echo "Downloading identity conversations..."
+    curl -L -o "$CACHE_DIR/identity_conversations.jsonl" \
+        https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
+    echo ""
+fi
+
+# Continue from where we left off
+# Check if we need midtraining for the current base model tag
+if [ "$MID_CHECKPOINT" = "none" ] || [ "$MID_TAG" != "$BASE_TAG" ]; then
+    if [ "$MID_TAG" != "$BASE_TAG" ] && [ "$MID_CHECKPOINT" != "none" ]; then
+        echo "⚠️ Found mid checkpoint for $MID_TAG but base model is $BASE_TAG"
+        echo "  Need to run midtraining for $BASE_TAG"
+    fi
+
+    echo "📍 Continuing from: Base model complete ($BASE_TAG)"
+    echo "📋 Next steps: Midtraining → SFT"
+    echo ""
+
+    # Run midtraining
+    echo "Step 1/2: Midtraining ($MID_ITERATIONS iterations)..."
+    echo "  Loading base checkpoint: $BASE_CHECKPOINT"
+    echo "  Device batch size: $DEVICE_BATCH_SIZE"
+    python -m scripts.mid_train \
+        --num_iterations=$MID_ITERATIONS \
+        --device_batch_size=$DEVICE_BATCH_SIZE \
+        --max_seq_len=1024 \
+        --total_batch_size=$TOTAL_BATCH_SIZE \
+        --eval_every=50 \
+        --eval_tokens=$EVAL_TOKENS
+
+    echo ""
+    echo "✓ Midtraining complete!"
+    echo ""
+fi
+
+# Check again for mid checkpoint and verify tag matches
+MID_CHECKPOINT=$(find_latest_checkpoint "$MID_DIR")
+MID_TAG=$(echo $MID_CHECKPOINT | cut -d'/' -f1)
+
+if [ "$MID_CHECKPOINT" = "none" ]; then
+    echo "❌ Midtraining failed to produce checkpoint"
+    exit 1
+fi
+
+# Verify tags match
+if [ "$MID_TAG" != "$BASE_TAG" ]; then
+    echo "❌ Tag mismatch: Base is $BASE_TAG but mid is $MID_TAG"
+    echo "This shouldn't happen. Please check checkpoints manually."
+    exit 1
+fi
+
+# Check if we need SFT for the current mid model tag
+if [ "$SFT_CHECKPOINT" = "none" ] || [ "$SFT_TAG" != "$MID_TAG" ]; then
+    if [ "$SFT_TAG" != "$MID_TAG" ] && [ "$SFT_CHECKPOINT" != "none" ]; then
+        echo "⚠️ Found SFT checkpoint for $SFT_TAG but mid model is $MID_TAG"
+        echo "  Need to run SFT for $MID_TAG"
+    fi
+
+    # Run SFT
+    echo "📍 Continuing from: Midtraining complete ($MID_TAG)"
+    echo "📋 Next step: SFT (final stage!)"
+    echo ""
+    echo "Step 2/2: Chat fine-tuning (SFT) ($SFT_ITERATIONS iterations)..."
+    echo "  Loading mid checkpoint: $MID_CHECKPOINT"
+    echo "  Device batch size: $DEVICE_BATCH_SIZE"
+    python -m scripts.chat_sft \
+        --num_iterations=$SFT_ITERATIONS \
+        --device_batch_size=$DEVICE_BATCH_SIZE \
+        --target_examples_per_step=$((DEVICE_BATCH_SIZE * 2)) \
+        --eval_steps=10
+else
+    echo "✓ SFT already complete for $SFT_TAG"
+fi
+
+echo ""
+echo "=================================="
+echo "🎉 All Training Complete!"
+echo "=================================="
+echo "Finished: $(date)"
+echo ""
+echo "Your chatbot is ready! Chat with:"
+echo "  python -m scripts.chat_cli -i sft"
+echo ""
+echo "Or start the web UI:"
+echo "  python -m scripts.chat_web -i sft"
+echo ""
+echo "Generate final report:"
+echo "  python -m nanochat.report generate"
+echo "=================================="
diff --git a/dev/runmac_overnight.sh b/dev/runmac_overnight.sh
new file mode 100755
index 0000000..f270171
--- /dev/null
+++ b/dev/runmac_overnight.sh
@@ -0,0 +1,178 @@
+#!/bin/bash
+# Optimized overnight training for Mac (MPS/Apple Silicon)
+# Expected runtime: 8-12 hours
+# Expected result: Much better chatbot with coherent responses
+
+set -e  # Exit on error
+
+echo "=================================="
+echo "nanochat Mac Overnight Training"
+echo "=================================="
+echo "Started: $(date)"
+echo ""
+
+# Activate virtual environment
+source .venv/bin/activate
+
+# Memory-based configuration
+# Detect system memory (in GB) or allow manual override
+if [ -z "$MEMORY_SIZE" ]; then
+    MEMORY_SIZE=$(sysctl hw.memsize | awk '{print int($2/1024/1024/1024)}')
+    echo "Auto-detected memory: ${MEMORY_SIZE}GB"
+else
+    echo "Using specified memory: ${MEMORY_SIZE}GB"
+fi
+
+# Calculate optimal batch sizes based on available memory
+# Conservative estimates for MPS (unified memory shared with system)
+# Note: total_batch_size must be divisible by (device_batch_size * max_seq_len)
+# With max_seq_len=1024: device_batch_size * 1024 must divide total_batch_size
+if [ $MEMORY_SIZE -ge 128 ]; then
+    DEVICE_BATCH_SIZE=16
+    TOTAL_BATCH_SIZE=16384  # 16 * 1024 = 16384
+    EVAL_TOKENS=16384
+    SPLIT_TOKENS=16384
+    echo "Memory profile: 128GB+ (High performance)"
+elif [ $MEMORY_SIZE -ge 64 ]; then
+    DEVICE_BATCH_SIZE=8
+    TOTAL_BATCH_SIZE=8192  # 8 * 1024 = 8192
+    EVAL_TOKENS=8192
+    SPLIT_TOKENS=8192
+    echo "Memory profile: 64GB (Good performance)"
+elif [ $MEMORY_SIZE -ge 32 ]; then
+    DEVICE_BATCH_SIZE=4
+    TOTAL_BATCH_SIZE=4096  # 4 * 1024 = 4096
+    EVAL_TOKENS=4096
+    SPLIT_TOKENS=4096
+    echo "Memory profile: 32GB (Moderate performance)"
+else
+    DEVICE_BATCH_SIZE=1
+    TOTAL_BATCH_SIZE=1024  # 1 * 1024 = 1024
+    EVAL_TOKENS=2048
+    SPLIT_TOKENS=2048
+    echo "Memory profile: <32GB (Conservative)"
+fi
+
+# Allow manual overrides
+DEPTH=${DEPTH:-6}  # Bigger model (6 layers vs 4)
+BASE_ITERATIONS=${BASE_ITERATIONS:-500}  # More base training
+MID_ITERATIONS=${MID_ITERATIONS:-150}  # More midtraining
+SFT_ITERATIONS=${SFT_ITERATIONS:-150}  # More SFT
+DATA_SHARDS=${DATA_SHARDS:-50}  # More training data
+
+echo ""
+echo "Configuration:"
+echo "  System Memory: ${MEMORY_SIZE}GB"
+echo "  Model depth: $DEPTH (~82M params for d6)"
+echo "  Device batch size: $DEVICE_BATCH_SIZE"
+echo "  Total batch size: $TOTAL_BATCH_SIZE"
+echo "  Eval tokens: $EVAL_TOKENS"
+echo "  Base iterations: $BASE_ITERATIONS"
+echo "  Mid iterations: $MID_ITERATIONS"
+echo "  SFT iterations: $SFT_ITERATIONS"
+echo "  Data shards: $DATA_SHARDS"
+echo ""
+echo "To override, set environment variables:"
+echo "  MEMORY_SIZE=64 bash dev/runmac_overnight.sh"
+echo "  DEVICE_BATCH_SIZE=8 bash dev/runmac_overnight.sh"
+echo ""
+
+# Clean up old run
+echo "Cleaning up previous training..."
+rm -f report.md
+python -m nanochat.report reset
+
+# Download training data
+echo ""
+echo "Step 1/6: Downloading training data ($DATA_SHARDS shards)..."
+python -m nanochat.dataset -n $DATA_SHARDS
+
+# Download identity conversations
+echo ""
+echo "Step 2/6: Downloading identity conversations..."
+if [ ! -f ~/.cache/nanochat/identity_conversations.jsonl ]; then
+    curl -L -o ~/.cache/nanochat/identity_conversations.jsonl \
+        https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
+else
+    echo "  Already downloaded, skipping."
+fi
+
+# Build tokenizer
+echo ""
+echo "Step 3/6: Training tokenizer..."
+python -m nanochat.tokenizer
+
+# Base model training
+echo ""
+echo "Step 4/6: Training base model ($BASE_ITERATIONS iterations)..."
+echo "  Device batch size: $DEVICE_BATCH_SIZE, Total batch size: $TOTAL_BATCH_SIZE"
+echo "  This will take ~2-4 hours..."
+python -m scripts.base_train \
+    --depth=$DEPTH \
+    --max_seq_len=1024 \
+    --device_batch_size=$DEVICE_BATCH_SIZE \
+    --total_batch_size=$TOTAL_BATCH_SIZE \
+    --num_iterations=$BASE_ITERATIONS \
+    --eval_every=100 \
+    --eval_tokens=$EVAL_TOKENS \
+    --core_metric_every=250 \
+    --core_metric_max_per_task=20 \
+    --sample_every=100
+
+# Evaluate base model
+echo ""
+echo "Evaluating base model..."
+python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH_SIZE --split_tokens=$SPLIT_TOKENS
+python -m scripts.base_eval
+
+# Midtraining
+echo ""
+echo "Step 5/6: Midtraining ($MID_ITERATIONS iterations)..."
+echo "  Device batch size: $DEVICE_BATCH_SIZE, Total batch size: $TOTAL_BATCH_SIZE"
+echo "  This will take ~2-3 hours..."
+python -m scripts.mid_train \
+    --num_iterations=$MID_ITERATIONS \
+    --device_batch_size=$DEVICE_BATCH_SIZE \
+    --max_seq_len=1024 \
+    --total_batch_size=$TOTAL_BATCH_SIZE \
+    --eval_every=50 \
+    --eval_tokens=$EVAL_TOKENS
+
+# SFT training
+echo ""
+echo "Step 6/6: Chat fine-tuning (SFT) ($SFT_ITERATIONS iterations)..."
+echo "  Device batch size: $DEVICE_BATCH_SIZE"
+echo "  This will take ~2-3 hours..."
+python -m scripts.chat_sft \
+    --num_iterations=$SFT_ITERATIONS \
+    --device_batch_size=$DEVICE_BATCH_SIZE \
+    --target_examples_per_step=$((DEVICE_BATCH_SIZE * 2)) \
+    --eval_steps=10
+
+# Final evaluation
+echo ""
+echo "Running final evaluations..."
+python -m scripts.chat_eval -i sft || echo "Chat eval had issues, skipping..."
+
+# Generate report
+echo ""
+echo "Generating final report..."
+python -m nanochat.report generate
+
+# Copy report to current directory
+cp ~/.cache/nanochat/report/report.md ./report_overnight.md
+
+echo ""
+echo "=================================="
+echo "Training Complete!"
+echo "=================================="
+echo "Finished: $(date)"
+echo ""
+echo "Your chatbot is ready! Chat with it:"
+echo "  python -m scripts.chat_cli -i sft"
+echo ""
+echo "Or start the web UI:"
+echo "  python -m scripts.chat_web -i sft"
+echo ""
+echo "Report saved to: report_overnight.md"
+echo "=================================="
diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index f71524e..f5ea407 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -100,8 +100,10 @@ def build_model(checkpoint_dir, step, device, phase):
     with torch.device("meta"):
         model = GPT(model_config)
     # Load the model state
+    # Only init the rotary embedding buffers (persistent=False, absent from checkpoint),
+    # not the full weights — everything else comes from load_state_dict.
     model.to_empty(device=device)
-    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
+    model.init_rotary_embeddings()
     model.load_state_dict(model_data, strict=True, assign=True)
     # Put the model in the right training phase / mode
     if phase == "eval":
diff --git a/nanochat/engine.py b/nanochat/engine.py
index a1ba24c..d994d66 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -13,9 +13,8 @@ The whole thing is made as efficient as possible.
 """
 import torch
 import torch.nn.functional as F
-import signal
 import warnings
-from contextlib import contextmanager
+import concurrent.futures
 from collections import deque
 from nanochat.common import compute_init, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
@@ -23,25 +22,27 @@ from contextlib import nullcontext
 
 # -----------------------------------------------------------------------------
 # Calculator tool helpers
-@contextmanager
-def timeout(duration, formula):
-    def timeout_handler(signum, frame):
-        raise Exception(f"'{formula}': timed out after {duration} seconds")
-    signal.signal(signal.SIGALRM, timeout_handler)
-    signal.alarm(duration)
-    yield
-    signal.alarm(0)
+# Single shared executor — avoids spawning a thread per call.
+# signal.SIGALRM is not usable here: it only works on Unix *main thread*,
+# but chat_web.py serves requests from FastAPI worker threads. A thread-based
+# timeout via concurrent.futures works on any thread on any OS/device.
+_calc_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="calc")
 
 def eval_with_timeout(formula, max_time=3):
+    """Evaluate an expression with a timeout. Thread-safe; works on MPS, CPU, and CUDA."""
+    def _eval():
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", SyntaxWarning)
+            return eval(formula, {"__builtins__": {}}, {})
     try:
-        with timeout(max_time, formula):
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore", SyntaxWarning)
-                return eval(formula, {"__builtins__": {}}, {})
-    except Exception as e:
-        signal.alarm(0)
-        # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
+        future = _calc_executor.submit(_eval)
+        return future.result(timeout=max_time)
+    except concurrent.futures.TimeoutError:
+        future.cancel()
+        return None
+    except Exception:
+        # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok, ignore wrong calculator usage
         return None
 
 def use_calculator(expr):
@@ -309,6 +310,7 @@ if __name__ == "__main__":
    device_type = autodetect_device_type()
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+    synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 
    # load the model and tokenizer
    model, tokenizer, meta = load_model("base", device, phase="eval")
@@ -319,7 +321,7 @@
    prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
    # generate the reference sequence using the model.generate() function
    generated_tokens = []
-    torch.cuda.synchronize()
+    synchronize()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
    with autocast_ctx:
@@ -328,7 +330,7 @@
            chunk = tokenizer.decode([token])
            print(chunk, end="", flush=True)
    print()
-    torch.cuda.synchronize()
+    synchronize()
    t1 = time.time()
    print(f"Reference time: {t1 - t0:.2f}s")
    reference_ids = generated_tokens
@@ -336,7 +338,7 @@
    generated_tokens = []
    engine = Engine(model, tokenizer)
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
-    torch.cuda.synchronize()
+    synchronize()
    t0 = time.time()
    with autocast_ctx:
        for token_column, token_masks in stream:
@@ -345,7 +347,7 @@
            chunk = tokenizer.decode([token])
            print(chunk, end="", flush=True)
    print()
-    torch.cuda.synchronize()
+    synchronize()
    t1 = time.time()
    print(f"Engine time: {t1 - t0:.2f}s")
    # compare the two sequences
diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index 89ca42b..d2aa7b9 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -16,6 +16,22 @@ Usage (drop-in replacement for FA3):
 """
 import torch
 import torch.nn.functional as F
 
+# enable_gqa was added to F.scaled_dot_product_attention in PyTorch 2.5.
+# On older builds (2.2, 2.4) we fall back to manually repeating KV heads.
+# inspect.signature raises ValueError on C builtins in older Python/torch, so probe with a call instead.
+def _check_sdpa_gqa():
+    try:
+        q = torch.zeros(1, 2, 1, 8)
+        k = torch.zeros(1, 1, 1, 8)
+        F.scaled_dot_product_attention(q, k, k, enable_gqa=True)
+        return True
+    except TypeError:
+        return False
+    except Exception:
+        return True  # some other error — assume supported and let runtime decide
+
+_SDPA_HAS_GQA = _check_sdpa_gqa()
+
 # =============================================================================
 # Detection: Try to load FA3 on Hopper+ GPUs
@@ -67,9 +83,19 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
     Tk = k.size(2)
     window = window_size[0]
 
+    # If GQA is needed but SDPA doesn't support enable_gqa (torch < 2.5), repeat KV heads manually
+    if enable_gqa and not _SDPA_HAS_GQA:
+        n_rep = q.size(1) // k.size(1)
+        if n_rep > 1:
+            k = k.repeat_interleave(n_rep, dim=1)
+            v = v.repeat_interleave(n_rep, dim=1)
+        enable_gqa = False
+
+    gqa_kwargs = {"enable_gqa": enable_gqa} if _SDPA_HAS_GQA else {}
+
     # Full context, same length
     if (window < 0 or window >= Tq) and Tq == Tk:
-        return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
+        return F.scaled_dot_product_attention(q, k, v, is_causal=True, **gqa_kwargs)
 
     # Single token generation
     if Tq == 1:
@@ -78,7 +104,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
         start = max(0, Tk - (window + 1))
         k = k[:, :, start:, :]
         v = v[:, :, start:, :]
-        return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
+        return F.scaled_dot_product_attention(q, k, v, is_causal=False, **gqa_kwargs)
 
     # Need explicit mask for sliding window/chunk inference
     device = q.device
@@ -90,8 +116,8 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
     # sliding window (left)
     if window >= 0 and window < Tk:
         mask = mask & ((row_idx - col_idx) <= window)
-
-    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
+
+    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, **gqa_kwargs)
 
 # =============================================================================
 # Public API: Same interface as FA3
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 208acd1..655297c 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -40,8 +40,12 @@ class GPTConfig:
 
 
 def norm(x):
-    # Purely functional rmsnorm with no learnable params
-    return F.rms_norm(x, (x.size(-1),))
+    # Purely functional rmsnorm with no learnable params.
+    # F.rms_norm was added in PyTorch 2.4; fall back to manual implementation on older builds.
+    if hasattr(F, 'rms_norm'):
+        return F.rms_norm(x, (x.size(-1),))
+    variance = x.pow(2).mean(-1, keepdim=True)
+    return x * torch.rsqrt(variance + 1e-6)
 
 
 def has_ve(layer_idx, n_layer):
@@ -230,9 +234,7 @@ class GPT(nn.Module):
                 torch.nn.init.zeros_(block.attn.ve_gate.weight)
 
         # Rotary embeddings
-        head_dim = self.config.n_embd // self.config.n_head
-        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
-        self.cos, self.sin = cos, sin
+        self.init_rotary_embeddings()
 
         # Cast embeddings to bf16: optimizer can tolerate it and it saves memory
         if self.transformer.wte.weight.device.type == "cuda":
@@ -240,6 +242,15 @@ class GPT(nn.Module):
             for ve in self.value_embeds.values():
                 ve.to(dtype=torch.bfloat16)
 
+    def init_rotary_embeddings(self):
+        """Initialize (or re-initialize) only the non-persistent rotary embedding buffers.
+        Call this instead of the full init_weights() when loading a checkpoint, since these
+        buffers have persistent=False and are therefore absent from the saved state_dict.
+        """
+        head_dim = self.config.n_embd // self.config.n_head
+        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+        self.cos, self.sin = cos, sin
+
    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        # TODO: bump base theta more? e.g. 100K is more common more recently
        # autodetect the device from model embeddings
@@ -253,7 +264,9 @@ class GPT(nn.Module):
        # calculate the rotation frequencies at each (time, channel) pair
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
-        cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
+        # bfloat16 saves memory on CUDA; MPS only supports it in torch>=2.4 so use float32 there
+        if device.type == "cuda":
+            cos, sin = cos.bfloat16(), sin.bfloat16()
        cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
        return cos, sin
 
@@ -391,7 +404,9 @@ class GPT(nn.Module):
        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
        assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
        assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
-        assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
+        expected_rot_dtype = torch.bfloat16 if self.cos.device.type == "cuda" else torch.float32
+        assert self.cos.dtype == expected_rot_dtype, \
+            f"Rotary embeddings dtype mismatch: expected {expected_rot_dtype}, got {self.cos.dtype}"
        # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
diff --git a/nanochat/optim.py b/nanochat/optim.py
index 42d862b..5b15c29 100644
--- a/nanochat/optim.py
+++ b/nanochat/optim.py
@@ -11,13 +11,22 @@ import torch
 import torch.distributed as dist
 from torch import Tensor
 
+# Fused kernels via torch.compile give large speedups on CUDA.
+# On MPS / CPU, torch.compile either doesn't support fullgraph=True or adds overhead,
+# so we fall back to eager execution instead.
+def _cuda_compile(fn):
+    """Apply torch.compile only when CUDA is available; otherwise return fn unchanged."""
+    if torch.cuda.is_available():
+        return torch.compile(fn, dynamic=False, fullgraph=True)
+    return fn
+
 # -----------------------------------------------------------------------------
 """
 Good old AdamW optimizer, fused kernel.
 https://arxiv.org/abs/1711.05101
 """
 
-@torch.compile(dynamic=False, fullgraph=True)
+@_cuda_compile
 def adamw_step_fused(
     p: Tensor,  # (32768, 768) - parameter tensor
     grad: Tensor,  # (32768, 768) - gradient, same shape as p
@@ -33,8 +42,14 @@ def adamw_step_fused(
     """
     Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
     All in one compiled graph to eliminate Python overhead between ops.
-    The 0-D CPU tensors avoid recompilation when hyperparameter values change.
+    The 0-D CPU tensors avoid recompilation when hyperparameter values change (CUDA).
+    On MPS/CPU they are moved to the parameter device so eager ops stay on one device.
     """
+    if p.device.type != "cuda":
+        device = p.device
+        step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t = [
+            t.to(device) for t in (step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t)
+        ]
     # Weight decay (decoupled, applied before the update)
     p.mul_(1 - lr_t * wd_t)
     # Update running averages (lerp_ is cleaner and fuses well)
@@ -87,7 +102,7 @@ polar_express_coeffs = [
     (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
 ]
 
-@torch.compile(dynamic=False, fullgraph=True)
+@_cuda_compile
 def muon_step_fused(
     stacked_grads: Tensor,  # (12, 768, 3072) - stacked gradients
     stacked_params: Tensor,  # (12, 768, 3072) - stacked parameters
@@ -103,16 +118,22 @@ def muon_step_fused(
     """
     Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
     All in one compiled graph to eliminate Python overhead between ops.
-    Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
+    Some of the constants are 0-D CPU tensors to avoid recompilation when values change (CUDA).
+    On MPS/CPU they are moved to the parameter device so eager ops stay on one device.
     """
+    if stacked_grads.device.type != "cuda":
+        device = stacked_grads.device
+        momentum_t, lr_t, wd_t, beta2_t = [
+            t.to(device) for t in (momentum_t, lr_t, wd_t, beta2_t)
+        ]
     # Nesterov momentum
     momentum = momentum_t.to(stacked_grads.dtype)
     momentum_buffer.lerp_(stacked_grads, 1 - momentum)
     g = stacked_grads.lerp_(momentum_buffer, momentum)
-    # Polar express
-    X = g.bfloat16()
+    # Polar express — bfloat16 cuts memory bandwidth on CUDA; fall back to float32 on MPS/CPU
+    X = g.bfloat16() if g.device.type == "cuda" else g.float()
     X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
     if g.size(-2) > g.size(-1):  # Tall matrix
         for a, b, c in polar_express_coeffs[:ns_steps]:
@@ -277,8 +298,14 @@ class MuonAdamW(torch.optim.Optimizer):
                 red_dim,
             )
 
-        # Copy back to original params
-        torch._foreach_copy_(params, list(stacked_params.unbind(0)))
+        # Copy back to original params.
+        # torch._foreach_copy_ is fast but not implemented on MPS in torch<2.4; fall back to a loop.
+        unstacked = list(stacked_params.unbind(0))
+        if hasattr(torch, '_foreach_copy_') and params[0].device.type == "cuda":
+            torch._foreach_copy_(params, unstacked)
+        else:
+            for p, s in zip(params, unstacked):
+                p.copy_(s)
 
     @torch.no_grad()
     def step(self):
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 24091b6..9eb7c84 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -238,7 +238,9 @@ def disable_fp8(model):
 
 # Compile the model
 orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
-model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
+# torch.compile speeds up CUDA significantly but is unsupported on MPS and pointless on CPU
+if device_type == "cuda":
+    model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
 
 # -----------------------------------------------------------------------------
 # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
@@ -521,7 +523,13 @@ while True:
        pct_done = 100 * step / num_iterations
        tok_per_sec = int(total_batch_size / dt)
        flops_per_sec = num_flops_per_token * total_batch_size / dt
-        mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
+        # MFU is only meaningful on CUDA with a known GPU spec; float('inf') sentinel means non-CUDA
+        if device_type == "cuda":
+            mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
+            mfu_str = f"bf16_mfu: {mfu:.2f}% | "
+        else:
+            mfu = None
+            mfu_str = ""
        if step > 10:
            total_training_time += dt # only count the time after the first 10 steps
        # Calculate ETA based on average time per step (excluding first 10 steps)
@@ -534,7 +542,7 @@ while True:
        else:
            eta_str = ""
        epoch = dataloader_state_dict["epoch"]
-        print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
+        print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | {mfu_str}epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
        if step % 100 == 0:
            log_data = {
                "step": step,
@@ -544,9 +552,10 @@ while True:
                "train/lrm": lrm,
                "train/dt": dt,
                "train/tok_per_sec": tok_per_sec,
-                "train/mfu": mfu,
                "train/epoch": epoch,
            }
+            if mfu is not None:
+                log_data["train/mfu"] = mfu
            wandb_run.log(log_data)
 
    # state update
@@ -588,7 +597,7 @@ get_report().log(section="Base model training", data=[
    "Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
    "Final validation bpb": val_bpb,
    "CORE metric estimate": results.get("core_metric", None),
-    "MFU %": f"{mfu:.2f}%",
+    "MFU %": f"{mfu:.2f}%" if mfu is not None else "N/A (non-CUDA device)",
    "Total training flops": f"{flops_so_far:e}",
    "Total training time": f"{total_training_time/60:.2f}m",
    "Peak memory
"""
Tests for MPS/CPU compatibility fixes introduced in the cpu-mps-dev PR.

Each test exercises one specific fix and can run entirely on MPS or CPU — no
GPU needed.
Run with: python -m pytest tests/test_mps_compat.py -v
"""
import threading
import time
import pytest
import torch
import torch.nn as nn

# Prefer Apple-silicon MPS when available; otherwise everything runs on CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device_type = device.type


# ---------------------------------------------------------------------------
# Fix 1: optim.py — _cuda_compile skips torch.compile on non-CUDA,
# and CPU 0-D scalar tensors are moved to device before ops
# ---------------------------------------------------------------------------

class TestOptimMPSCompat:
    def test_cuda_compile_is_identity_on_non_cuda(self):
        """_cuda_compile must return the function unchanged when CUDA is not available."""
        from nanochat.optim import _cuda_compile
        sentinel = lambda x: x
        result = _cuda_compile(sentinel)
        assert result is sentinel, "_cuda_compile should be a no-op without CUDA"

    def test_fused_functions_not_compiled(self):
        """adamw_step_fused and muon_step_fused must NOT be torch.compile'd on MPS/CPU."""
        from nanochat.optim import adamw_step_fused, muon_step_fused
        # torch.compile'd callables expose the original function via `_orig_mod`,
        # so its absence is a reliable "not compiled" check.
        assert not hasattr(adamw_step_fused, "_orig_mod"), \
            "adamw_step_fused should not be torch.compile'd on MPS/CPU"
        assert not hasattr(muon_step_fused, "_orig_mod"), \
            "muon_step_fused should not be torch.compile'd on MPS/CPU"

    def test_adamw_step_fused_runs_on_device(self):
        """adamw_step_fused must execute successfully with device tensors + CPU scalar tensors."""
        from nanochat.optim import adamw_step_fused
        p = torch.randn(16, 16, device=device)
        grad = torch.randn(16, 16, device=device)
        exp_avg = torch.zeros(16, 16, device=device)
        exp_avg_sq = torch.zeros(16, 16, device=device)
        # Scalars intentionally kept on CPU (as the optimizer allocates them)
        p_before = p.clone()
        adamw_step_fused(
            p, grad, exp_avg, exp_avg_sq,
            torch.tensor(1.0),    # step_t — CPU scalar
            torch.tensor(1e-3),   # lr_t
            torch.tensor(0.9),    # beta1_t
            torch.tensor(0.999),  # beta2_t
            torch.tensor(1e-8),   # eps_t
            torch.tensor(0.01),   # wd_t
        )
        assert not torch.equal(p, p_before), "adamw_step_fused should update p in-place"

    def test_muon_step_fused_runs_on_device(self):
        """muon_step_fused must execute successfully on MPS/CPU."""
        from nanochat.optim import muon_step_fused
        B, M, N = 4, 32, 64
        stacked_grads = torch.randn(B, M, N, device=device)
        stacked_params = torch.randn(B, M, N, device=device)
        momentum_buf = torch.zeros(B, M, N, device=device)
        # red_dim=-1 means we reduce over N, so second moment buffer is (B, M, 1)
        second_mom_buf = torch.ones(B, M, 1, device=device)
        params_before = stacked_params.clone()
        muon_step_fused(
            stacked_grads, stacked_params, momentum_buf, second_mom_buf,
            torch.tensor(0.95),  # momentum_t — CPU scalar
            torch.tensor(0.02),  # lr_t
            torch.tensor(0.0),   # wd_t
            torch.tensor(0.95),  # beta2_t
            ns_steps=5,
            red_dim=-1,
        )
        assert not torch.equal(stacked_params, params_before), \
            "muon_step_fused should update stacked_params in-place"

    def test_full_optimizer_step_on_device(self):
        """MuonAdamW optimizer.step() must work end-to-end on MPS/CPU."""
        from nanochat.optim import MuonAdamW
        # A small model with both matrix (Muon) and non-matrix (AdamW) params
        model = nn.Sequential(
            nn.Embedding(32, 16),            # AdamW — 1-D embedding
            nn.Linear(16, 32, bias=False),   # Muon — 2-D matrix
        ).to(device)
        embedding_params = list(model[0].parameters())
        matrix_params = list(model[1].parameters())
        param_groups = [
            dict(kind="adamw", params=embedding_params, lr=1e-3,
                 betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0),
            dict(kind="muon", params=matrix_params, lr=0.02,
                 momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0),
        ]
        opt = MuonAdamW(param_groups)
        x = torch.randint(0, 32, (4,), device=device)
        out = model[1](model[0](x))
        loss = out.sum()
        loss.backward()
        params_before = {n: p.clone() for n, p in model.named_parameters()}
        opt.step()
        opt.zero_grad()
        for name, p in model.named_parameters():
            assert not torch.equal(p, params_before[name]), \
                f"Parameter {name} should have been updated by optimizer.step()"


# ---------------------------------------------------------------------------
# Fix 2: engine.py — concurrent.futures timeout works from any thread
# ---------------------------------------------------------------------------

class TestCalculatorThreadSafety:
    def test_calculator_works_from_main_thread(self):
        """Basic sanity: use_calculator works from the main thread."""
        from nanochat.engine import use_calculator
        assert use_calculator("2 + 2") == 4
        assert use_calculator("10 * 3") == 30
        assert use_calculator("1 / 4") == 0.25

    def test_calculator_works_from_background_thread(self):
        """Critical: use_calculator must work when called from a non-main thread (FastAPI worker scenario)."""
        from nanochat.engine import use_calculator
        results = {}
        errors = {}

        def worker():
            try:
                results["2+2"] = use_calculator("2+2")
                results["10*3"] = use_calculator("10*3")
            except Exception as e:
                errors["exc"] = e

        t = threading.Thread(target=worker)
        t.start()
        t.join(timeout=10)
        assert not t.is_alive(), "Worker thread hung"
        assert not errors, f"Exception in worker thread: {errors.get('exc')}"
        assert results["2+2"] == 4
        assert results["10*3"] == 30

    def test_calculator_timeout_does_not_hang(self):
        """eval_with_timeout with a short max_time must return promptly, not block the caller."""
        from nanochat.engine import eval_with_timeout
        # We can't easily trigger a true infinite loop through use_calculator's sanitizer,
        # but we can call eval_with_timeout directly with a very short timeout.
        t0 = time.time()
        result = eval_with_timeout("1+1", max_time=0.1)
        elapsed = time.time() - t0
        assert result == 2
        assert elapsed < 2.0, f"eval_with_timeout took too long: {elapsed:.2f}s"

    def test_calculator_rejects_unsafe_input(self):
        """use_calculator must return None for non-numeric / unsafe expressions."""
        from nanochat.engine import use_calculator
        assert use_calculator("__import__('os').system('echo pwned')") is None
        assert use_calculator("2 ** 100") is None  # power operator blocked
        assert use_calculator("open('/etc/passwd')") is None

    def test_no_sigalrm_usage(self):
        """engine.py must not CALL signal.alarm() or signal.signal(SIGALRM) — comments are fine."""
        with open("nanochat/engine.py") as f:
            source = f.read()
        # Strip comment lines before checking for actual usage
        non_comment_lines = [
            line for line in source.splitlines()
            if not line.lstrip().startswith("#")
        ]
        code = "\n".join(non_comment_lines)
        assert "signal.alarm(" not in code, \
            "engine.py still calls signal.alarm() — should use concurrent.futures"
        assert "signal.signal(signal.SIGALRM" not in code, \
            "engine.py still registers SIGALRM handler — should use concurrent.futures"
        # Accept either import form — `import concurrent.futures` or
        # `from concurrent.futures import ...`; the old test only matched the first.
        assert ("import concurrent.futures" in source
                or "from concurrent.futures import" in source), \
            "engine.py should import concurrent.futures for thread-safe timeout"


# ---------------------------------------------------------------------------
# Fix 3: gpt.py — init_rotary_embeddings() exists and works standalone
# ---------------------------------------------------------------------------

class TestInitRotaryEmbeddings:
    def _make_small_model(self):
        # Build a tiny GPT on the meta device, then materialize it on the test
        # device without running any expensive initialization I/O.
        from nanochat.gpt import GPT, GPTConfig
        cfg = GPTConfig(sequence_len=64, vocab_size=256, n_layer=2,
                        n_head=2, n_kv_head=2, n_embd=64)
        with torch.device("meta"):
            model = GPT(cfg)
        model.to_empty(device=device)
        return model

    def test_method_exists(self):
        """GPT must expose init_rotary_embeddings()."""
        from nanochat.gpt import GPT
        assert hasattr(GPT, "init_rotary_embeddings"), \
            "GPT should have init_rotary_embeddings() method"

    def test_rotary_buffers_populated_after_call(self):
        """init_rotary_embeddings() alone must produce valid cos/sin buffers."""
        model = self._make_small_model()
        model.init_rotary_embeddings()
        assert model.cos is not None
        assert model.sin is not None
        assert model.cos.shape[1] == model.rotary_seq_len
        assert not model.cos.isnan().any(), "cos buffer contains NaN"
        assert not model.sin.isnan().any(), "sin buffer contains NaN"

    def test_init_rotary_does_not_touch_parameters(self):
        """init_rotary_embeddings() must not change learnable parameters."""
        model = self._make_small_model()
        model.init_weights()  # proper full init
        params_before = {n: p.clone() for n, p in model.named_parameters()}
        model.init_rotary_embeddings()  # should only touch buffers
        for name, p in model.named_parameters():
            assert torch.equal(p, params_before[name]), \
                f"init_rotary_embeddings() should not modify parameter {name}"

    def test_forward_works_after_init_rotary_only(self):
        """A model initialized only via init_weights (which calls init_rotary) must forward cleanly."""
        model = self._make_small_model()
        model.init_weights()
        model.eval()
        ids = torch.randint(0, 256, (1, 16), device=device)
        with torch.no_grad():
            logits = model(ids)
        assert logits.shape == (1, 16, 256)
        assert not logits.isnan().any()


# ---------------------------------------------------------------------------
# Fix 4: base_train.py / mid_train.py — torch.compile guarded on non-CUDA
# ---------------------------------------------------------------------------

class TestTorchCompileGuard:
    def test_compile_guard_in_base_train_source(self):
        """base_train.py must only call torch.compile when device_type == 'cuda'."""
        with open("scripts/base_train.py") as f:
            source = f.read()
        # Find the torch.compile call and verify it's inside a CUDA guard
        compile_idx = source.find('model = torch.compile(model')
        assert compile_idx != -1, "Could not find torch.compile call in base_train.py"
        # The nearest preceding if-statement should reference cuda
        preceding = source[max(0, compile_idx - 200):compile_idx]
        assert 'device_type == "cuda"' in preceding, \
            'torch.compile in base_train.py is not guarded by `if device_type == "cuda":`'

    def test_mfu_none_on_non_cuda(self):
        """On MPS/CPU, base_train.py must set mfu = None (not a misleading float)."""
        # The old version of this test re-implemented the guard on local
        # variables and asserted on them — a tautology that never touched the
        # shipped code. Inspect the actual source instead, mirroring
        # test_compile_guard_in_base_train_source.
        with open("scripts/base_train.py") as f:
            source = f.read()
        assert "mfu = None" in source, \
            "base_train.py should null out mfu on non-CUDA devices instead of " \
            "dividing by the float('inf') gpu_peak_flops sentinel"
        # The None branch must sit alongside the CUDA-guarded computation
        assert 'device_type == "cuda"' in source, \
            "base_train.py should guard the MFU computation on device_type"