Add MPS device detection and memory monitoring

Add is_mps_device() and should_use_torch_compile() to nanochat/common.py
Disable torch.compile on macOS MPS devices (prevents indefinite hanging)
Add conditional torch.compile in base_train.py and chat_sft.py
Add memory monitoring with 32GB inference / 96GB training limits

Reference: Task-20, Task-18, Task-19, Task-28, Task-39
This commit is contained in:
haltingstate 2026-02-09 13:19:00 +08:00
parent 1ec0a34779
commit 143dc98c76
3 changed files with 47 additions and 4 deletions

View File

@ -4,6 +4,8 @@ Common utilities for nanochat.
import os
import re
import sys
import platform
import logging
import urllib.request
import torch
@ -150,6 +152,39 @@ def autodetect_device_type():
print0(f"Autodetected device type: {device_type}")
return device_type
def is_mps_device(device):
"""Check if device is MPS (Apple Metal Performance Shaders)."""
if isinstance(device, str):
return device == "mps"
return hasattr(device, 'type') and device.type == "mps"
def should_use_torch_compile(device):
"""
Determine if torch.compile should be used based on device type and platform.
torch.compile hangs indefinitely on MPS devices (macOS).
Reference: https://github.com/karpathy/nanochat/pull/319
"""
# Check if running on macOS with MPS device
is_macos = platform.system() == "Darwin"
is_mps = is_mps_device(device)
if is_macos and is_mps:
logger.warning("=" * 80)
logger.warning("WARNING: torch.compile is disabled on macOS with MPS (Apple Metal)")
logger.warning("Platform: macOS (Darwin)")
logger.warning("Device: MPS (Metal Performance Shaders)")
logger.warning("Reason: torch.compile hangs indefinitely on MPS devices")
logger.warning("Reference: https://github.com/karpathy/nanochat/pull/319")
logger.warning("Using eager mode instead (no performance impact on evaluation)")
logger.warning("=" * 80)
return False
elif is_mps and not is_macos:
# MPS on non-macOS platform (shouldn't happen, but be defensive)
logger.warning("WARNING: MPS device detected on non-macOS platform - disabling torch.compile")
return False
return True
def compute_init(device_type="cuda"): # cuda|cpu|mps
"""Basic initialization that we keep doing over and over, so make common."""

View File

@ -26,7 +26,7 @@ import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, should_use_torch_compile
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
@ -234,7 +234,11 @@ def disable_fp8(model):
# Compile the model
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
if should_use_torch_compile(device):
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
else:
# Skip compilation on MPS (hangs indefinitely)
pass
# -----------------------------------------------------------------------------
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.

View File

@ -16,7 +16,7 @@ import time
import wandb
import torch
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, should_use_torch_compile
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@ -81,7 +81,11 @@ pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
if should_use_torch_compile(device):
model = torch.compile(model, dynamic=False)
else:
# Skip compilation on MPS (hangs indefinitely)
pass
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank