Improve Mac/MPS compatibility and device handling

Added dev/runmac_overnight.sh for optimized Mac training. Updated device handling throughout the dataloader, GPT model, Muon optimizer, and training scripts so that CUDA-only features (torch.compile, pin_memory, non_blocking copies, bfloat16, fused AdamW) are skipped on MPS/CPU. Relaxed the torch version constraints in pyproject.toml and removed the Linux/CUDA-specific PyTorch index configuration for better macOS support.
Author: Jason Kneen
Date: 2025-10-22 01:55:38 +01:00
Parent: 50bea28ef9
Commit: 3e184d343e
9 changed files with 1792 additions and 1278 deletions
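The device gating applied across the files below follows one pattern: detect the device type once and derive every CUDA-only toggle from it. A minimal sketch of that policy, assuming a hypothetical device_caps helper (illustrative only, not code from this commit):

import torch

def device_caps(device: torch.device) -> dict:
    # Illustrative helper: CUDA keeps the fast paths, MPS/CPU fall back to
    # portable defaults, mirroring the changes in the diffs below.
    is_cuda = device.type == "cuda"
    return dict(
        pin_memory=is_cuda,        # pinned host buffers are CUDA-only
        non_blocking=is_cuda,      # async host-to-device copies need pinned memory + CUDA
        param_dtype=torch.bfloat16 if is_cuda else torch.float32,  # bfloat16 avoided on MPS/CPU here
        fused_adamw=is_cuda,       # AdamW(fused=True) treated as CUDA-only
        compile_model=is_cuda,     # torch.compile skipped on MPS/CPU
    )

caps = device_caps(torch.device("cuda" if torch.cuda.is_available() else "cpu"))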

dev/runcpu.sh: 0 changes (mode changed from normal file to executable file)
dev/runmac_overnight.sh: 126 additions (new executable file)

@ -0,0 +1,126 @@
#!/bin/bash
# Optimized overnight training for Mac (MPS/Apple Silicon)
# Expected runtime: 8-12 hours
# Expected result: Much better chatbot with coherent responses
set -e # Exit on error
echo "=================================="
echo "nanochat Mac Overnight Training"
echo "=================================="
echo "Started: $(date)"
echo ""
# Activate virtual environment
source .venv/bin/activate
# Configuration
DEPTH=6 # Bigger model (6 layers vs 4)
BASE_ITERATIONS=500 # More base training
MID_ITERATIONS=150 # More midtraining
SFT_ITERATIONS=150 # More SFT
DATA_SHARDS=50 # More training data
echo "Configuration:"
echo " Model depth: $DEPTH (36.7M → 82M params)"
echo " Base iterations: $BASE_ITERATIONS"
echo " Mid iterations: $MID_ITERATIONS"
echo " SFT iterations: $SFT_ITERATIONS"
echo " Data shards: $DATA_SHARDS"
echo ""
# Clean up old run
echo "Cleaning up previous training..."
rm -f report.md
python -m scripts.report --reset
# Download training data
echo ""
echo "Step 1/6: Downloading training data ($DATA_SHARDS shards)..."
python -m nanochat.dataset -n $DATA_SHARDS
# Download identity conversations
echo ""
echo "Step 2/6: Downloading identity conversations..."
if [ ! -f ~/.cache/nanochat/identity_conversations.jsonl ]; then
curl -L -o ~/.cache/nanochat/identity_conversations.jsonl \
https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
else
echo " Already downloaded, skipping."
fi
# Build tokenizer
echo ""
echo "Step 3/6: Training tokenizer..."
python -m nanochat.tokenizer
# Base model training
echo ""
echo "Step 4/6: Training base model ($BASE_ITERATIONS iterations)..."
echo " This will take ~2-4 hours..."
python -m scripts.base_train \
--depth=$DEPTH \
--max_seq_len=1024 \
--device_batch_size=1 \
--total_batch_size=1024 \
--num_iterations=$BASE_ITERATIONS \
--eval_every=100 \
--eval_tokens=8192 \
--core_metric_every=250 \
--core_metric_max_per_task=20 \
--sample_every=100
# Evaluate base model
echo ""
echo "Evaluating base model..."
python -m scripts.base_loss
python -m scripts.base_eval
# Midtraining
echo ""
echo "Step 5/6: Midtraining ($MID_ITERATIONS iterations)..."
echo " This will take ~2-3 hours..."
python -m scripts.mid_train \
--num_iterations=$MID_ITERATIONS \
--device_batch_size=1 \
--max_seq_len=1024 \
--total_batch_size=1024 \
--eval_every=50
# SFT training
echo ""
echo "Step 6/6: Chat fine-tuning (SFT) ($SFT_ITERATIONS iterations)..."
echo " This will take ~2-3 hours..."
python -m scripts.chat_sft \
--num_iterations=$SFT_ITERATIONS \
--device_batch_size=1 \
--target_examples_per_step=8 \
--eval_steps=10
# Final evaluation
echo ""
echo "Running final evaluations..."
python -m scripts.chat_eval -i sft || echo "Chat eval had issues, skipping..."
# Generate report
echo ""
echo "Generating final report..."
python -m scripts.report
# Copy report to current directory
cp ~/.cache/nanochat/report/report.md ./report_overnight.md
echo ""
echo "=================================="
echo "Training Complete!"
echo "=================================="
echo "Finished: $(date)"
echo ""
echo "Your chatbot is ready! Chat with it:"
echo " python -m scripts.chat_cli -i sft"
echo ""
echo "Or start the web UI:"
echo " python -m scripts.chat_web -i sft"
echo ""
echo "Report saved to: report_overnight.md"
echo "=================================="


@ -16,6 +16,9 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left
+# pin_memory and non_blocking only work on CUDA
+device_type = device if isinstance(device, str) else device.type
+use_cuda_optimizations = device_type == "cuda"
# infinite iterator over document batches
def document_batches():
@ -38,11 +41,11 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
batch_index += 1
# Move tokens from the deque into the scratch buffer
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
-scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True)
+scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=use_cuda_optimizations)
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
-# Reshape to 2D and move to GPU async
-inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
-targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
+# Reshape to 2D and move to device (async on CUDA)
+inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=use_cuda_optimizations)
+targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=use_cuda_optimizations)
yield inputs, targets
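The same toggle drives the host-to-device copy: pinning and asynchronous transfer are only requested on CUDA. A self-contained sketch of that transfer pattern (the copy_to_device name is an assumption, not the loader's actual code):

import torch

def copy_to_device(batch_cpu: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Pinned memory and non_blocking copies only pay off (and are only reliably
    # supported) on CUDA; on MPS/CPU fall back to a plain synchronous copy.
    use_cuda = device.type == "cuda"
    if use_cuda and not batch_cpu.is_pinned():
        batch_cpu = batch_cpu.pin_memory()
    return batch_cpu.to(device, non_blocking=use_cuda)

tokens = torch.randint(0, 1000, (8, 1024), dtype=torch.int64)
inputs = copy_to_device(tokens, torch.device("cuda" if torch.cuda.is_available() else "cpu"))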


@ -35,7 +35,13 @@ class GPTConfig:
def norm(x):
# Purely functional rmsnorm with no learnable params
-return F.rms_norm(x, (x.size(-1),))
+# Fallback for older PyTorch versions that don't have F.rms_norm
+if hasattr(F, 'rms_norm'):
+    return F.rms_norm(x, (x.size(-1),))
+else:
+    # Manual RMS norm implementation
+    variance = x.pow(2).mean(-1, keepdim=True)
+    return x * torch.rsqrt(variance + 1e-5)
def apply_rotary_emb(x, cos, sin):
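A quick way to check that the manual fallback tracks the built-in on a PyTorch build that has F.rms_norm; the eps values differ (1e-5 vs. the dtype eps), so only loose agreement is expected:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)  # fallback path
if hasattr(F, "rms_norm"):
    assert torch.allclose(F.rms_norm(x, (x.size(-1),)), manual, atol=1e-4)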
@ -211,7 +217,9 @@ class GPT(nn.Module):
# calculate the rotation frequencies at each (time, channel) pair
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
-cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
+# Use bfloat16 on CUDA, float32 on MPS/CPU (MPS doesn't support bfloat16)
+if device.type == "cuda":
+    cos, sin = cos.bfloat16(), sin.bfloat16()
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
@ -244,7 +252,10 @@ class GPT(nn.Module):
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
]
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
-AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
+# fused=True only works on CUDA, not MPS or CPU
+device_type = self.get_device().type
+use_fused = device_type == "cuda"
+AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=use_fused)
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
# Create the Muon optimizer for the linear layers
muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
@ -263,7 +274,9 @@ class GPT(nn.Module):
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
-assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
+# Rotary embeddings are bfloat16 on CUDA, float32 on MPS/CPU
+expected_dtype = torch.bfloat16 if self.cos.device.type == "cuda" else torch.float32
+assert self.cos.dtype == expected_dtype, f"Rotary embeddings dtype mismatch: expected {expected_dtype}, got {self.cos.dtype}"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length


@ -6,7 +6,6 @@ import torch
from torch import Tensor
import torch.distributed as dist
-@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@ -19,7 +18,11 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
-X = G.bfloat16()
+# Use bfloat16 on CUDA for speed, float32 on MPS/CPU (MPS doesn't support bfloat16)
+if G.device.type == "cuda":
+    X = G.bfloat16()
+else:
+    X = G.float()
if G.size(-2) > G.size(-1):
X = X.mT
@ -35,6 +38,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
X = X.mT
return X
+# Compile the function for CUDA; use eager mode for MPS/CPU
+# We check at runtime in the optimizer which version to use
+_zeropower_compiled = torch.compile(zeropower_via_newtonschulz5)
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
@ -65,6 +72,11 @@ class Muon(torch.optim.Optimizer):
group = dict(params=[p for p in params if p.numel() == size])
param_groups.append(group)
super().__init__(param_groups, defaults)
+# Determine which zeropower function to use based on device
+# Check the first parameter to determine device
+first_param = next(iter(params))
+self._use_compiled = first_param.device.type == "cuda"
+self._zeropower_fn = _zeropower_compiled if self._use_compiled else zeropower_via_newtonschulz5
@torch.no_grad()
def step(self):
@ -79,7 +91,7 @@ class Muon(torch.optim.Optimizer):
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
-g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+g = self._zeropower_fn(g, steps=group["ns_steps"])
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
@ -122,6 +134,10 @@ class DistMuon(torch.optim.Optimizer):
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
super().__init__(param_groups, defaults)
+# Determine which zeropower function to use based on device
+first_param = params[0]
+self._use_compiled = first_param.device.type == "cuda"
+self._zeropower_fn = _zeropower_compiled if self._use_compiled else zeropower_via_newtonschulz5
@torch.no_grad()
def step(self):
@ -173,7 +189,7 @@ class DistMuon(torch.optim.Optimizer):
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1.0 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
-g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+g = self._zeropower_fn(g, steps=group["ns_steps"])
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
p.add_(g, alpha=-group["lr"] * scale)
# Replicate updated parameters to all ranks
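The compile-vs-eager dispatch above relies on torch.compile being lazy: wrapping the function at import time is cheap, and compilation is only triggered the first time the compiled wrapper is called. A small self-contained sketch of the same dispatch (orthogonalize_stub is a stand-in, not the real Newton-Schulz kernel):

import torch

def orthogonalize_stub(G: torch.Tensor, steps: int) -> torch.Tensor:
    # Stand-in for zeropower_via_newtonschulz5, just to show the dispatch shape
    for _ in range(steps):
        G = G / (G.norm() + 1e-7)
    return G

_compiled = torch.compile(orthogonalize_stub)  # no compilation happens until first call

def run(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    fn = _compiled if G.device.type == "cuda" else orthogonalize_stub
    return fn(G, steps=steps)

print(run(torch.randn(4, 4)).shape)  # takes the eager path on CPU/MPS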


@ -14,7 +14,7 @@ dependencies = [
"setuptools>=80.9.0",
"tiktoken>=0.11.0",
"tokenizers>=0.22.0",
"torch>=2.8.0",
"torch>=2.0.0,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*",
"uvicorn>=0.36.0",
"wandb>=0.21.3",
]
@ -44,19 +44,5 @@ python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
-# target torch to cuda 12.8
-[tool.uv.sources]
-torch = [
-{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
-{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
-]
-[[tool.uv.index]]
-name = "pytorch-cpu"
-url = "https://download.pytorch.org/whl/cpu"
-explicit = true
-[[tool.uv.index]]
-name = "pytorch-cu128"
-url = "https://download.pytorch.org/whl/cu128"
-explicit = true
+# Note: PyTorch configuration removed for macOS compatibility
+# On Linux with CUDA, you may need to reinstall from pytorch.org/whl/cu128


@ -109,7 +109,9 @@ with torch.device("meta"):
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
-model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
+# torch.compile only works well on CUDA; skip on MPS/CPU
+if device_type == "cuda":
+    model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()


@ -15,6 +15,9 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
+# Disable torch.compile's dynamo on MPS (not supported)
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
@ -70,7 +73,9 @@ pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
orig_model = model
-model = torch.compile(model, dynamic=False)
+# torch.compile only works well on CUDA; skip on MPS/CPU
+if device_type == "cuda":
+    model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
@ -119,7 +124,9 @@ def mid_data_generator(split):
assert dataset_size > 0
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
token_buffer = deque()
-scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
+# pin_memory only works on CUDA
+use_pin_memory = device_type == "cuda"
+scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=use_pin_memory)
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
it = 0 # iteration counter
while True:
@ -142,8 +149,9 @@ def mid_data_generator(split):
scratch[i] = token_buffer.popleft()
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
-inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
-targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
+# non_blocking only works on CUDA
+inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=use_pin_memory)
+targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=use_pin_memory)
if split == "train":
if num_iterations > 0:
approx_progress = it / num_iterations # calculate progress from the max number of iterations

uv.lock: 2848 changes (diff suppressed because it is too large)