diff --git a/nanochat/common.py b/nanochat/common.py index 9bcd5dd1..ea609b28 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -4,6 +4,8 @@ Common utilities for nanochat. import os import re +import sys +import platform import logging import urllib.request import torch @@ -150,6 +152,39 @@ def autodetect_device_type(): print0(f"Autodetected device type: {device_type}") return device_type +def is_mps_device(device): + """Check if device is MPS (Apple Metal Performance Shaders).""" + if isinstance(device, str): + return device == "mps" + return hasattr(device, 'type') and device.type == "mps" + +def should_use_torch_compile(device): + """ + Determine if torch.compile should be used based on device type and platform. + torch.compile hangs indefinitely on MPS devices (macOS). + Reference: https://github.com/karpathy/nanochat/pull/319 + """ + # Check if running on macOS with MPS device + is_macos = platform.system() == "Darwin" + is_mps = is_mps_device(device) + + if is_macos and is_mps: + logger.warning("=" * 80) + logger.warning("WARNING: torch.compile is disabled on macOS with MPS (Apple Metal)") + logger.warning("Platform: macOS (Darwin)") + logger.warning("Device: MPS (Metal Performance Shaders)") + logger.warning("Reason: torch.compile hangs indefinitely on MPS devices") + logger.warning("Reference: https://github.com/karpathy/nanochat/pull/319") + logger.warning("Using eager mode instead (no performance impact on evaluation)") + logger.warning("=" * 80) + return False + elif is_mps and not is_macos: + # MPS on non-macOS platform (shouldn't happen, but be defensive) + logger.warning("WARNING: MPS device detected on non-macOS platform - disabling torch.compile") + return False + + return True + def compute_init(device_type="cuda"): # cuda|cpu|mps """Basic initialization that we keep doing over and over, so make common.""" diff --git a/scripts/base_train.py b/scripts/base_train.py index ccf35e64..2dbd00ca 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -26,7 +26,7 @@ import torch from nanochat.gpt import GPT, GPTConfig from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, should_use_torch_compile from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.loss_eval import evaluate_bpb @@ -234,7 +234,11 @@ def disable_fp8(model): # Compile the model orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) -model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +if should_use_torch_compile(device): + model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +else: + # Skip compilation on MPS (hangs indefinitely) + pass # ----------------------------------------------------------------------------- # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay. diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 4c81f065..891ec89a 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -16,7 +16,7 @@ import time import wandb import torch from contextlib import nullcontext -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, should_use_torch_compile from nanochat.tokenizer import get_token_bytes from nanochat.checkpoint_manager import save_checkpoint from nanochat.loss_eval import evaluate_bpb @@ -81,7 +81,11 @@ pretrain_batch_size = meta.get("device_batch_size", None) if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size: print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?") orig_model = model -model = torch.compile(model, dynamic=False) +if should_use_torch_compile(device): + model = torch.compile(model, dynamic=False) +else: + # Skip compilation on MPS (hangs indefinitely) + pass depth = model.config.n_layer num_flops_per_token = model.estimate_flops() tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank