mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-19 12:39:10 +00:00
Add MPS device detection and memory monitoring
Add is_mps_device() and should_use_torch_compile() to nanochat/common.py. Disable torch.compile on macOS MPS devices (prevents indefinite hanging). Add conditional torch.compile in base_train.py and chat_sft.py. Add memory monitoring with 32GB inference / 96GB training limits. Reference: Task-20, Task-18, Task-19, Task-28, Task-39.
This commit is contained in:
parent
1ec0a34779
commit
143dc98c76
|
|
@ -4,6 +4,8 @@ Common utilities for nanochat.
|
|||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import platform
|
||||
import logging
|
||||
import urllib.request
|
||||
import torch
|
||||
|
|
@ -150,6 +152,39 @@ def autodetect_device_type():
|
|||
print0(f"Autodetected device type: {device_type}")
|
||||
return device_type
|
||||
|
||||
def is_mps_device(device):
    """Check if device is MPS (Apple Metal Performance Shaders).

    Args:
        device: Either a device string (e.g. "mps", "mps:0", "cuda") or an
            object exposing a ``type`` attribute (such as ``torch.device``).

    Returns:
        True if the device refers to MPS, False otherwise (including for
        objects that are neither strings nor carry a ``type`` attribute).
    """
    if isinstance(device, str):
        # Device strings may carry an ordinal suffix ("mps:0"); only the part
        # before the colon names the device type.
        return device.partition(":")[0] == "mps"
    # Objects without a `type` attribute (e.g. None) are simply "not MPS".
    return getattr(device, "type", None) == "mps"
|
||||
|
||||
def should_use_torch_compile(device):
    """
    Decide whether the caller should wrap its model in torch.compile.

    Returns False whenever `device` is an MPS device (as judged by
    is_mps_device) and logs a warning explaining why; returns True for every
    other device. The original authors report that torch.compile hangs
    indefinitely on MPS devices (macOS).
    Reference: https://github.com/karpathy/nanochat/pull/319
    """
    # Anything that is not MPS is free to compile.
    if not is_mps_device(device):
        return True

    # MPS on non-macOS platform (shouldn't happen, but be defensive)
    if platform.system() != "Darwin":
        logger.warning("WARNING: MPS device detected on non-macOS platform - disabling torch.compile")
        return False

    # macOS + MPS: emit the multi-line banner, one warning record per line.
    rule = "=" * 80
    banner_lines = (
        rule,
        "WARNING: torch.compile is disabled on macOS with MPS (Apple Metal)",
        "Platform: macOS (Darwin)",
        "Device: MPS (Metal Performance Shaders)",
        "Reason: torch.compile hangs indefinitely on MPS devices",
        "Reference: https://github.com/karpathy/nanochat/pull/319",
        "Using eager mode instead (no performance impact on evaluation)",
        rule,
    )
    for text in banner_lines:
        logger.warning(text)
    return False
|
||||
|
||||
def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import torch
|
|||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, should_use_torch_compile
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -234,7 +234,11 @@ def disable_fp8(model):
|
|||
# Compile the model
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
if should_use_torch_compile(device):
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
else:
|
||||
# Skip compilation on MPS (hangs indefinitely)
|
||||
pass
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import time
|
|||
import wandb
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, should_use_torch_compile
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -81,7 +81,11 @@ pretrain_batch_size = meta.get("device_batch_size", None)
|
|||
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
if should_use_torch_compile(device):
|
||||
model = torch.compile(model, dynamic=False)
|
||||
else:
|
||||
# Skip compilation on MPS (hangs indefinitely)
|
||||
pass
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user