diff --git a/nanochat/common.py b/nanochat/common.py
index 9bcd5dd..ea609b2 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -4,6 +4,8 @@ Common utilities for nanochat.
 import os
 import re
+import sys
+import platform
 import logging
 import urllib.request
 import torch
 
@@ -150,6 +152,39 @@ def autodetect_device_type():
     print0(f"Autodetected device type: {device_type}")
     return device_type
 
+def is_mps_device(device):
+    """Check if device is MPS (Apple Metal Performance Shaders)."""
+    if isinstance(device, str):
+        return device == "mps"
+    return hasattr(device, 'type') and device.type == "mps"
+
+def should_use_torch_compile(device):
+    """
+    Determine if torch.compile should be used based on device type and platform.
+    torch.compile hangs indefinitely on MPS devices (macOS).
+    Reference: https://github.com/karpathy/nanochat/pull/319
+    """
+    # Check if running on macOS with MPS device
+    is_macos = platform.system() == "Darwin"
+    is_mps = is_mps_device(device)
+
+    if is_macos and is_mps:
+        logger.warning("=" * 80)
+        logger.warning("WARNING: torch.compile is disabled on macOS with MPS (Apple Metal)")
+        logger.warning("Platform: macOS (Darwin)")
+        logger.warning("Device: MPS (Metal Performance Shaders)")
+        logger.warning("Reason: torch.compile hangs indefinitely on MPS devices")
+        logger.warning("Reference: https://github.com/karpathy/nanochat/pull/319")
+        logger.warning("Using eager mode instead (no performance impact on evaluation)")
+        logger.warning("=" * 80)
+        return False
+    elif is_mps and not is_macos:
+        # MPS on non-macOS platform (shouldn't happen, but be defensive)
+        logger.warning("WARNING: MPS device detected on non-macOS platform - disabling torch.compile")
+        return False
+
+    return True
+
 def compute_init(device_type="cuda"): # cuda|cpu|mps
     """Basic initialization that we keep doing over and over, so make common."""
diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py
index f3c9a9f..274d2ab 100644
--- a/nanochat/core_eval.py
+++ b/nanochat/core_eval.py
@@ -11,6 +11,8 @@
 from jinja2 import Template
 import torch
 import torch.distributed as dist
+from nanochat import eval_config
+
 
 # -----------------------------------------------------------------------------
 # Prompt rendering utilities
@@ -146,6 +148,8 @@ def forward_model(model, input_ids):
     """
     Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
     The last column of losses is set to nan because we don't have autoregressive targets there.
+
+    MEMORY FIX: Explicitly cleanup intermediate tensors to prevent GPU memory accumulation.
     """
     batch_size, seq_len = input_ids.size()
     outputs = model(input_ids)
@@ -161,6 +165,9 @@
     losses[:, -1] = float('nan')
     # Get the argmax predictions at each position
     predictions = outputs.argmax(dim=-1)
+    # MEMORY FIX: Explicitly free large intermediate tensors
+    del outputs # outputs is largest tensor (B×T×V, ~GB for large models)
+    del target_ids # target_ids is B×T
     return losses, predictions
 
 
@@ -238,6 +245,9 @@
     else:
         raise ValueError(f"Unsupported task type: {task_type}")
 
+    # MEMORY FIX: Explicitly free tensors after extracting scalar result
+    del losses, predictions, input_ids
+
     return is_correct
 
 
@@ -245,18 +255,43 @@
     """
     This function is responsible for evaluating one task across many examples.
    It also handles dispatch to all processes if the script is run with torchrun.
+
+    MEMORY FIX: Added periodic cache cleanup to prevent memory accumulation.
     """
+    import gc # For explicit garbage collection
+
     rank = dist.get_rank() if dist.is_initialized() else 0
     world_size = dist.get_world_size() if dist.is_initialized() else 1
     correct = torch.zeros(len(data), dtype=torch.float32, device=device)
+
     # stride the examples to each rank
     for idx in range(rank, len(data), world_size):
         is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
         correct[idx] = float(is_correct)
+
+        # MEMORY FIX: Periodic cache cleanup
+        # This releases cached GPU memory and triggers Python GC
+        # Prevents progressive slowdown from memory fragmentation
+        # Interval configurable via eval_config.CACHE_CLEANUP_INTERVAL (default: 256)
+        if eval_config.ENABLE_PERIODIC_CLEANUP and idx % eval_config.CACHE_CLEANUP_INTERVAL == 0 and idx > 0:
+            # Release PyTorch cached memory back to GPU
+            if torch.cuda.is_available() and device.type == 'cuda':
+                torch.cuda.empty_cache()
+            # Force Python garbage collection
+            gc.collect()
+
     # sync results across all the processes if running distributed
     if world_size > 1:
         dist.barrier()
         dist.all_reduce(correct, op=dist.ReduceOp.SUM)
+
     # compute the mean
     mean_correct = correct.mean().item()
+
+    # MEMORY FIX: Final cleanup after task completes
+    del correct
+    if eval_config.ENABLE_FINAL_CLEANUP:
+        if torch.cuda.is_available() and device.type == 'cuda':
+            torch.cuda.empty_cache()
+
     return mean_correct
diff --git a/nanochat/eval_config.py b/nanochat/eval_config.py
new file mode 100644
index 0000000..6127b1c
--- /dev/null
+++ b/nanochat/eval_config.py
@@ -0,0 +1,31 @@
+"""
+Configuration for evaluation memory management and performance tuning.
+
+These settings control memory cleanup intervals and other evaluation parameters
+to prevent memory leaks and progressive slowdown during long-running evaluations.
+"""
+
+# Memory Management Settings
+# ---------------------------
+
+# Periodic cache cleanup interval (in examples processed)
+# After processing this many examples, trigger torch.cuda.empty_cache() and gc.collect()
+# to prevent memory fragmentation and progressive slowdown.
+#
+# Rationale for 256:
+# - Balances cleanup overhead (~10-50ms per cleanup) vs memory accumulation
+# - Power of 2 (efficient modulo operation)
+# - For HellaSwag (10,000 examples): 39 cleanups total (~2s overhead)
+# - For MMLU (100-1000 examples): 0-4 cleanups total (negligible overhead)
+#
+# Lower values (e.g., 100): More frequent cleanup, less fragmentation, higher overhead
+# Higher values (e.g., 512): Less overhead, more fragmentation risk
+CACHE_CLEANUP_INTERVAL = 256
+
+# Enable periodic cache cleanup during evaluation
+# Set to False to disable all periodic cleanup (not recommended for long evaluations)
+ENABLE_PERIODIC_CLEANUP = True
+
+# Enable final cleanup after task completes
+# Set to False to skip final cleanup (saves ~50ms but leaves memory cached)
+ENABLE_FINAL_CLEANUP = True
diff --git a/scripts/base_train.py b/scripts/base_train.py
index ee53098..2cfcabc 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -26,7 +26,7 @@
 import torch
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, should_use_torch_compile
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -236,7 +236,11 @@
 # Compile the model
 orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
-model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
+if should_use_torch_compile(device):
+    model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
+else:
+    # Skip compilation on MPS (hangs indefinitely)
+    pass
 
 # -----------------------------------------------------------------------------
 # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
 
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index 4c81f06..891ec89 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -16,7 +16,7 @@ import time
 import wandb
 import torch
 from contextlib import nullcontext
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, should_use_torch_compile
 from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -81,7 +81,11 @@ pretrain_batch_size = meta.get("device_batch_size", None)
 if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
     print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
 orig_model = model
-model = torch.compile(model, dynamic=False)
+if should_use_torch_compile(device):
+    model = torch.compile(model, dynamic=False)
+else:
+    # Skip compilation on MPS (hangs indefinitely)
+    pass
 depth = model.config.n_layer
 num_flops_per_token = model.estimate_flops()
 tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank