This commit is contained in:
haltingstate 2026-02-10 14:41:19 -05:00 committed by GitHub
commit df77f21819
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 113 additions and 4 deletions

View File

@ -4,6 +4,8 @@ Common utilities for nanochat.
import os
import re
import sys
import platform
import logging
import urllib.request
import torch
@ -150,6 +152,39 @@ def autodetect_device_type():
print0(f"Autodetected device type: {device_type}")
return device_type
def is_mps_device(device):
"""Check if device is MPS (Apple Metal Performance Shaders)."""
if isinstance(device, str):
return device == "mps"
return hasattr(device, 'type') and device.type == "mps"
def should_use_torch_compile(device):
"""
Determine if torch.compile should be used based on device type and platform.
torch.compile hangs indefinitely on MPS devices (macOS).
Reference: https://github.com/karpathy/nanochat/pull/319
"""
# Check if running on macOS with MPS device
is_macos = platform.system() == "Darwin"
is_mps = is_mps_device(device)
if is_macos and is_mps:
logger.warning("=" * 80)
logger.warning("WARNING: torch.compile is disabled on macOS with MPS (Apple Metal)")
logger.warning("Platform: macOS (Darwin)")
logger.warning("Device: MPS (Metal Performance Shaders)")
logger.warning("Reason: torch.compile hangs indefinitely on MPS devices")
logger.warning("Reference: https://github.com/karpathy/nanochat/pull/319")
logger.warning("Using eager mode instead (no performance impact on evaluation)")
logger.warning("=" * 80)
return False
elif is_mps and not is_macos:
# MPS on non-macOS platform (shouldn't happen, but be defensive)
logger.warning("WARNING: MPS device detected on non-macOS platform - disabling torch.compile")
return False
return True
def compute_init(device_type="cuda"): # cuda|cpu|mps
"""Basic initialization that we keep doing over and over, so make common."""

View File

@ -11,6 +11,8 @@ from jinja2 import Template
import torch
import torch.distributed as dist
from nanochat import eval_config
# -----------------------------------------------------------------------------
# Prompt rendering utilities
@ -146,6 +148,8 @@ def forward_model(model, input_ids):
"""
Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
The last column of losses is set to nan because we don't have autoregressive targets there.
MEMORY FIX: Explicitly cleanup intermediate tensors to prevent GPU memory accumulation.
"""
batch_size, seq_len = input_ids.size()
outputs = model(input_ids)
@ -161,6 +165,9 @@ def forward_model(model, input_ids):
losses[:, -1] = float('nan')
# Get the argmax predictions at each position
predictions = outputs.argmax(dim=-1)
# MEMORY FIX: Explicitly free large intermediate tensors
del outputs # outputs is largest tensor (B×T×V, ~GB for large models)
del target_ids # target_ids is B×T
return losses, predictions
@ -238,6 +245,9 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
else:
raise ValueError(f"Unsupported task type: {task_type}")
# MEMORY FIX: Explicitly free tensors after extracting scalar result
del losses, predictions, input_ids
return is_correct
@ -245,18 +255,43 @@ def evaluate_task(model, tokenizer, data, device, task_meta):
"""
This function is responsible for evaluating one task across many examples.
It also handles dispatch to all processes if the script is run with torchrun.
MEMORY FIX: Added periodic cache cleanup to prevent memory accumulation.
"""
import gc # For explicit garbage collection
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
# stride the examples to each rank
for idx in range(rank, len(data), world_size):
is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
correct[idx] = float(is_correct)
# MEMORY FIX: Periodic cache cleanup
# This releases cached GPU memory and triggers Python GC
# Prevents progressive slowdown from memory fragmentation
# Interval configurable via eval_config.CACHE_CLEANUP_INTERVAL (default: 256)
if eval_config.ENABLE_PERIODIC_CLEANUP and idx % eval_config.CACHE_CLEANUP_INTERVAL == 0 and idx > 0:
# Release PyTorch cached memory back to GPU
if torch.cuda.is_available() and device.type == 'cuda':
torch.cuda.empty_cache()
# Force Python garbage collection
gc.collect()
# sync results across all the processes if running distributed
if world_size > 1:
dist.barrier()
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
# compute the mean
mean_correct = correct.mean().item()
# MEMORY FIX: Final cleanup after task completes
del correct
if eval_config.ENABLE_FINAL_CLEANUP:
if torch.cuda.is_available() and device.type == 'cuda':
torch.cuda.empty_cache()
return mean_correct

31
nanochat/eval_config.py Normal file
View File

@ -0,0 +1,31 @@
"""
Configuration for evaluation memory management and performance tuning.
These settings control memory cleanup intervals and other evaluation parameters
to prevent memory leaks and progressive slowdown during long-running evaluations.
"""
# Memory Management Settings
# ---------------------------
# Periodic cache cleanup interval (in examples processed)
# After processing this many examples, trigger torch.cuda.empty_cache() and gc.collect()
# to prevent memory fragmentation and progressive slowdown.
#
# Rationale for 256:
# - Balances cleanup overhead (~10-50ms per cleanup) vs memory accumulation
# - Power of 2 (efficient modulo operation)
# - For HellaSwag (10,000 examples): 39 cleanups total (~2s overhead)
# - For MMLU (100-1000 examples): 0-4 cleanups total (negligible overhead)
#
# Lower values (e.g., 100): More frequent cleanup, less fragmentation, higher overhead
# Higher values (e.g., 512): Less overhead, more fragmentation risk
CACHE_CLEANUP_INTERVAL = 256
# Enable periodic cache cleanup during evaluation
# Set to False to disable all periodic cleanup (not recommended for long evaluations)
ENABLE_PERIODIC_CLEANUP = True
# Enable final cleanup after task completes
# Set to False to skip final cleanup (saves ~50ms but leaves memory cached)
ENABLE_FINAL_CLEANUP = True

View File

@ -26,7 +26,7 @@ import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, should_use_torch_compile
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
@ -236,7 +236,11 @@ def disable_fp8(model):
# Compile the model
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
if should_use_torch_compile(device):
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
else:
# Skip compilation on MPS (hangs indefinitely)
pass
# -----------------------------------------------------------------------------
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.

View File

@ -16,7 +16,7 @@ import time
import wandb
import torch
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, should_use_torch_compile
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@ -81,7 +81,11 @@ pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
if should_use_torch_compile(device):
model = torch.compile(model, dynamic=False)
else:
# Skip compilation on MPS (hangs indefinitely)
pass
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank