Mirror of https://github.com/karpathy/nanochat.git, synced 2026-04-05 15:15:48 +00:00.
Merge c4a183dfef into e569b59f92
This commit is contained in:
commit
df77f21819
|
|
@ -4,6 +4,8 @@ Common utilities for nanochat.
|
|||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import platform
|
||||
import logging
|
||||
import urllib.request
|
||||
import torch
|
||||
|
|
@ -150,6 +152,39 @@ def autodetect_device_type():
|
|||
print0(f"Autodetected device type: {device_type}")
|
||||
return device_type
|
||||
|
||||
def is_mps_device(device):
    """Return True if *device* refers to Apple's MPS (Metal Performance Shaders) backend.

    Accepts either a plain string (e.g. "mps") or a device-like object
    exposing a ``.type`` attribute (e.g. ``torch.device``).
    """
    if isinstance(device, str):
        return device == "mps"
    # Device object: read .type defensively so objects without it count as non-MPS.
    return getattr(device, "type", None) == "mps"
|
||||
|
||||
def should_use_torch_compile(device):
    """
    Decide whether torch.compile is safe for this device/platform combination.

    torch.compile hangs indefinitely on MPS devices (macOS).
    Reference: https://github.com/karpathy/nanochat/pull/319
    """
    on_macos = platform.system() == "Darwin"
    # Inline MPS detection: accept a "mps" string or a device object with .type == "mps".
    if isinstance(device, str):
        on_mps = device == "mps"
    else:
        on_mps = hasattr(device, 'type') and device.type == "mps"

    # Non-MPS devices (cuda/cpu) compile fine everywhere.
    if not on_mps:
        return True

    if on_macos:
        logger.warning("=" * 80)
        logger.warning("WARNING: torch.compile is disabled on macOS with MPS (Apple Metal)")
        logger.warning("Platform: macOS (Darwin)")
        logger.warning("Device: MPS (Metal Performance Shaders)")
        logger.warning("Reason: torch.compile hangs indefinitely on MPS devices")
        logger.warning("Reference: https://github.com/karpathy/nanochat/pull/319")
        logger.warning("Using eager mode instead (no performance impact on evaluation)")
        logger.warning("=" * 80)
    else:
        # MPS on non-macOS platform (shouldn't happen, but be defensive)
        logger.warning("WARNING: MPS device detected on non-macOS platform - disabling torch.compile")
    return False
|
||||
|
||||
def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from jinja2 import Template
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat import eval_config
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Prompt rendering utilities
|
||||
|
||||
|
|
@ -146,6 +148,8 @@ def forward_model(model, input_ids):
|
|||
"""
|
||||
Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
|
||||
The last column of losses is set to nan because we don't have autoregressive targets there.
|
||||
|
||||
MEMORY FIX: Explicitly cleanup intermediate tensors to prevent GPU memory accumulation.
|
||||
"""
|
||||
batch_size, seq_len = input_ids.size()
|
||||
outputs = model(input_ids)
|
||||
|
|
@ -161,6 +165,9 @@ def forward_model(model, input_ids):
|
|||
losses[:, -1] = float('nan')
|
||||
# Get the argmax predictions at each position
|
||||
predictions = outputs.argmax(dim=-1)
|
||||
# MEMORY FIX: Explicitly free large intermediate tensors
|
||||
del outputs # outputs is largest tensor (B×T×V, ~GB for large models)
|
||||
del target_ids # target_ids is B×T
|
||||
return losses, predictions
|
||||
|
||||
|
||||
|
|
@ -238,6 +245,9 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
|||
else:
|
||||
raise ValueError(f"Unsupported task type: {task_type}")
|
||||
|
||||
# MEMORY FIX: Explicitly free tensors after extracting scalar result
|
||||
del losses, predictions, input_ids
|
||||
|
||||
return is_correct
|
||||
|
||||
|
||||
|
|
def evaluate_task(model, tokenizer, data, device, task_meta):
    """
    Evaluate one task across all of its examples.

    When launched with torchrun, examples are strided across ranks and the
    per-example correctness flags are summed back together with an all-reduce.

    MEMORY FIX: Added periodic cache cleanup to prevent memory accumulation.
    """
    import gc  # For explicit garbage collection

    distributed = dist.is_initialized()
    rank = dist.get_rank() if distributed else 0
    world_size = dist.get_world_size() if distributed else 1
    correct = torch.zeros(len(data), dtype=torch.float32, device=device)

    # stride the examples to each rank
    for idx in range(rank, len(data), world_size):
        correct[idx] = float(evaluate_example(idx, model, tokenizer, data, device, task_meta))

        # MEMORY FIX: Periodic cache cleanup.
        # Releases cached GPU memory and triggers Python GC, which prevents
        # progressive slowdown from memory fragmentation. The interval is
        # configurable via eval_config.CACHE_CLEANUP_INTERVAL (default: 256).
        cleanup_due = (
            eval_config.ENABLE_PERIODIC_CLEANUP
            and idx > 0
            and idx % eval_config.CACHE_CLEANUP_INTERVAL == 0
        )
        if cleanup_due:
            # Release PyTorch cached memory back to the GPU
            if torch.cuda.is_available() and device.type == 'cuda':
                torch.cuda.empty_cache()
            # Force Python garbage collection
            gc.collect()

    # sync results across all the processes if running distributed
    if world_size > 1:
        dist.barrier()
        dist.all_reduce(correct, op=dist.ReduceOp.SUM)

    # compute the mean
    mean_correct = correct.mean().item()

    # MEMORY FIX: Final cleanup after the task completes
    del correct
    if eval_config.ENABLE_FINAL_CLEANUP:
        if torch.cuda.is_available() and device.type == 'cuda':
            torch.cuda.empty_cache()

    return mean_correct
|
||||
|
|
|
|||
31
nanochat/eval_config.py
Normal file
31
nanochat/eval_config.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""
|
||||
Configuration for evaluation memory management and performance tuning.
|
||||
|
||||
These settings control memory cleanup intervals and other evaluation parameters
|
||||
to prevent memory leaks and progressive slowdown during long-running evaluations.
|
||||
"""
|
||||
|
||||
# Memory Management Settings
|
||||
# ---------------------------
|
||||
|
||||
# Periodic cache cleanup interval (in examples processed)
|
||||
# After processing this many examples, trigger torch.cuda.empty_cache() and gc.collect()
|
||||
# to prevent memory fragmentation and progressive slowdown.
|
||||
#
|
||||
# Rationale for 256:
|
||||
# - Balances cleanup overhead (~10-50ms per cleanup) vs memory accumulation
|
||||
# - Power of 2 (efficient modulo operation)
|
||||
# - For HellaSwag (10,000 examples): 39 cleanups total (~2s overhead)
|
||||
# - For MMLU (100-1000 examples): 0-4 cleanups total (negligible overhead)
|
||||
#
|
||||
# Lower values (e.g., 100): More frequent cleanup, less fragmentation, higher overhead
|
||||
# Higher values (e.g., 512): Less overhead, more fragmentation risk
|
||||
CACHE_CLEANUP_INTERVAL = 256
|
||||
|
||||
# Enable periodic cache cleanup during evaluation
|
||||
# Set to False to disable all periodic cleanup (not recommended for long evaluations)
|
||||
ENABLE_PERIODIC_CLEANUP = True
|
||||
|
||||
# Enable final cleanup after task completes
|
||||
# Set to False to skip final cleanup (saves ~50ms but leaves memory cached)
|
||||
ENABLE_FINAL_CLEANUP = True
|
||||
|
|
@ -26,7 +26,7 @@ import torch
|
|||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, should_use_torch_compile
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -236,7 +236,11 @@ def disable_fp8(model):
|
|||
# Compile the model
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
if should_use_torch_compile(device):
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
else:
|
||||
# Skip compilation on MPS (hangs indefinitely)
|
||||
pass
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import time
|
|||
import wandb
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, should_use_torch_compile
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -81,7 +81,11 @@ pretrain_batch_size = meta.get("device_batch_size", None)
|
|||
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
if should_use_torch_compile(device):
|
||||
model = torch.compile(model, dynamic=False)
|
||||
else:
|
||||
# Skip compilation on MPS (hangs indefinitely)
|
||||
pass
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user