Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06)
Improve Mac/MPS compatibility and device handling
Added dev/runmac_overnight.sh for optimized Mac training. Updated device-specific logic throughout dataloader, GPT, Muon optimizer, and training scripts to avoid CUDA-only features on MPS/CPU (e.g., torch.compile, pin_memory, non_blocking, bfloat16). Relaxed torch version constraints in pyproject.toml and removed Linux/CUDA-specific PyTorch config for better macOS support.
This commit is contained in:
parent 50bea28ef9
commit 3e184d343e

dev/runcpu.sh            |   0   Normal file → Executable file
dev/runmac_overnight.sh  | 126   Executable file (new)
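The common thread across the changes below is a single device-type check that gates every CUDA-only feature (torch.compile, pin_memory, non_blocking host-to-device copies, bfloat16, fused AdamW). A minimal sketch of that pattern, assuming a detection helper along the lines of the autodetect_device_type imported in the training scripts (the helper body here is an illustration, not nanochat's implementation):

import torch

def autodetect_device_type() -> str:
    # Illustrative stand-in: prefer CUDA, then Apple-silicon MPS, else CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

device_type = autodetect_device_type()
use_cuda = device_type == "cuda"

# CUDA-only fast paths are enabled conditionally; the same code still runs on MPS/CPU.
scratch = torch.empty(1024, dtype=torch.int64, pin_memory=use_cuda)   # pinned host memory is CUDA-only
batch = scratch.to(device_type, non_blocking=use_cuda)                # async copies only pay off on CUDA
rope_dtype = torch.bfloat16 if use_cuda else torch.float32            # MPS lacks bfloat16 support

model = torch.nn.Linear(16, 16).to(device_type)
if use_cuda:
    model = torch.compile(model, dynamic=False)                       # skip compile on MPS/CPU
    optimizer = torch.optim.AdamW(model.parameters(), fused=True)     # fused kernels are CUDA-only
else:
    optimizer = torch.optim.AdamW(model.parameters())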
dev/runmac_overnight.sh (new executable file)
@ -0,0 +1,126 @@
#!/bin/bash
# Optimized overnight training for Mac (MPS/Apple Silicon)
# Expected runtime: 8-12 hours
# Expected result: Much better chatbot with coherent responses

set -e  # Exit on error

echo "=================================="
echo "nanochat Mac Overnight Training"
echo "=================================="
echo "Started: $(date)"
echo ""

# Activate virtual environment
source .venv/bin/activate

# Configuration
DEPTH=6              # Bigger model (6 layers vs 4)
BASE_ITERATIONS=500  # More base training
MID_ITERATIONS=150   # More midtraining
SFT_ITERATIONS=150   # More SFT
DATA_SHARDS=50       # More training data

echo "Configuration:"
echo "  Model depth: $DEPTH (36.7M → 82M params)"
echo "  Base iterations: $BASE_ITERATIONS"
echo "  Mid iterations: $MID_ITERATIONS"
echo "  SFT iterations: $SFT_ITERATIONS"
echo "  Data shards: $DATA_SHARDS"
echo ""

# Clean up old run
echo "Cleaning up previous training..."
rm -f report.md
python -m scripts.report --reset

# Download training data
echo ""
echo "Step 1/6: Downloading training data ($DATA_SHARDS shards)..."
python -m nanochat.dataset -n $DATA_SHARDS

# Download identity conversations
echo ""
echo "Step 2/6: Downloading identity conversations..."
if [ ! -f ~/.cache/nanochat/identity_conversations.jsonl ]; then
    curl -L -o ~/.cache/nanochat/identity_conversations.jsonl \
        https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
else
    echo "  Already downloaded, skipping."
fi

# Build tokenizer
echo ""
echo "Step 3/6: Training tokenizer..."
python -m nanochat.tokenizer

# Base model training
echo ""
echo "Step 4/6: Training base model ($BASE_ITERATIONS iterations)..."
echo "  This will take ~2-4 hours..."
python -m scripts.base_train \
    --depth=$DEPTH \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --total_batch_size=1024 \
    --num_iterations=$BASE_ITERATIONS \
    --eval_every=100 \
    --eval_tokens=8192 \
    --core_metric_every=250 \
    --core_metric_max_per_task=20 \
    --sample_every=100

# Evaluate base model
echo ""
echo "Evaluating base model..."
python -m scripts.base_loss
python -m scripts.base_eval

# Midtraining
echo ""
echo "Step 5/6: Midtraining ($MID_ITERATIONS iterations)..."
echo "  This will take ~2-3 hours..."
python -m scripts.mid_train \
    --num_iterations=$MID_ITERATIONS \
    --device_batch_size=1 \
    --max_seq_len=1024 \
    --total_batch_size=1024 \
    --eval_every=50

# SFT training
echo ""
echo "Step 6/6: Chat fine-tuning (SFT) ($SFT_ITERATIONS iterations)..."
echo "  This will take ~2-3 hours..."
python -m scripts.chat_sft \
    --num_iterations=$SFT_ITERATIONS \
    --device_batch_size=1 \
    --target_examples_per_step=8 \
    --eval_steps=10

# Final evaluation
echo ""
echo "Running final evaluations..."
python -m scripts.chat_eval -i sft || echo "Chat eval had issues, skipping..."

# Generate report
echo ""
echo "Generating final report..."
python -m scripts.report

# Copy report to current directory
cp ~/.cache/nanochat/report/report.md ./report_overnight.md

echo ""
echo "=================================="
echo "Training Complete!"
echo "=================================="
echo "Finished: $(date)"
echo ""
echo "Your chatbot is ready! Chat with it:"
echo "  python -m scripts.chat_cli -i sft"
echo ""
echo "Or start the web UI:"
echo "  python -m scripts.chat_web -i sft"
echo ""
echo "Report saved to: report_overnight.md"
echo "=================================="
nanochat/dataloader.py
@ -16,6 +16,9 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
     bos_token = tokenizer.get_bos_token_id()
     # scratch buffer holds the tokens for one iteration
     token_buffer = deque() # we stream tokens on the right and pop from the left
+    # pin_memory and non_blocking only work on CUDA
+    device_type = device if isinstance(device, str) else device.type
+    use_cuda_optimizations = device_type == "cuda"

     # infinite iterator over document batches
     def document_batches():
@ -38,11 +41,11 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
             batch_index += 1
         # Move tokens from the deque into the scratch buffer
         tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
-        scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True)
+        scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=use_cuda_optimizations)
         # Create the inputs/targets as 1D tensors
         inputs_cpu = scratch[:-1].to(dtype=torch.int32)
         targets_cpu = scratch[1:]
-        # Reshape to 2D and move to GPU async
-        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
-        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
+        # Reshape to 2D and move to device (async on CUDA)
+        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=use_cuda_optimizations)
+        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=use_cuda_optimizations)
         yield inputs, targets
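The shape bookkeeping in this hunk is easy to check in isolation: B*T + 1 tokens from the stream yield inputs and targets that are the same (B, T) block shifted by one token. A minimal standalone sketch (not nanochat code):

import torch

B, T = 2, 4
needed_tokens = B * T + 1

# Pretend these came off the token deque
scratch = torch.arange(needed_tokens, dtype=torch.int64)   # flat 1D buffer of B*T+1 tokens

inputs = scratch[:-1].view(B, T)    # first B*T tokens
targets = scratch[1:].view(B, T)    # same tokens shifted left by one

print(inputs)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])
print(targets)
# tensor([[1, 2, 3, 4],
#         [5, 6, 7, 8]])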
nanochat/gpt.py
@ -35,7 +35,13 @@ class GPTConfig:

 def norm(x):
     # Purely functional rmsnorm with no learnable params
-    return F.rms_norm(x, (x.size(-1),))
+    # Fallback for older PyTorch versions that don't have F.rms_norm
+    if hasattr(F, 'rms_norm'):
+        return F.rms_norm(x, (x.size(-1),))
+    else:
+        # Manual RMS norm implementation
+        variance = x.pow(2).mean(-1, keepdim=True)
+        return x * torch.rsqrt(variance + 1e-5)


 def apply_rotary_emb(x, cos, sin):
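As a sanity check on the fallback branch, the manual expression reproduces F.rms_norm when the same epsilon is passed explicitly (the diff itself leaves F.rms_norm at its default eps, so the two branches can differ by a tiny numerical epsilon). A standalone sketch:

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 64)

# Library path (F.rms_norm is available in newer PyTorch)
ref = F.rms_norm(x, (x.size(-1),), eps=1e-5)

# Manual fallback path, as in the diff above
variance = x.pow(2).mean(-1, keepdim=True)
manual = x * torch.rsqrt(variance + 1e-5)

print(torch.allclose(ref, manual, atol=1e-6))  # True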
@ -211,7 +217,9 @@ class GPT(nn.Module):
         # calculate the rotation frequencies at each (time, channel) pair
         freqs = torch.outer(t, inv_freq)
         cos, sin = freqs.cos(), freqs.sin()
-        cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
+        # Use bfloat16 on CUDA, float32 on MPS/CPU (MPS doesn't support bfloat16)
+        if device.type == "cuda":
+            cos, sin = cos.bfloat16(), sin.bfloat16()
         cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
         return cos, sin

@ -244,7 +252,10 @@ class GPT(nn.Module):
             dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
         ]
         adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
-        AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
+        # fused=True only works on CUDA, not MPS or CPU
+        device_type = self.get_device().type
+        use_fused = device_type == "cuda"
+        AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=use_fused)
         adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
         # Create the Muon optimizer for the linear layers
         muon_kwargs = dict(lr=matrix_lr, momentum=0.95)

@ -263,7 +274,9 @@ class GPT(nn.Module):
         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
         assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
         assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
-        assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
+        # Rotary embeddings are bfloat16 on CUDA, float32 on MPS/CPU
+        expected_dtype = torch.bfloat16 if self.cos.device.type == "cuda" else torch.float32
+        assert self.cos.dtype == expected_dtype, f"Rotary embeddings dtype mismatch: expected {expected_dtype}, got {self.cos.dtype}"
         # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
         T0 = 0 if kv_cache is None else kv_cache.get_pos()
         cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
nanochat/muon.py
@ -6,7 +6,6 @@ import torch
 from torch import Tensor
 import torch.distributed as dist

-@torch.compile
 def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
     """
     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a

@ -19,7 +18,11 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
     """
     assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
     a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.bfloat16()
+    # Use bfloat16 on CUDA for speed, float32 on MPS/CPU (MPS doesn't support bfloat16)
+    if G.device.type == "cuda":
+        X = G.bfloat16()
+    else:
+        X = G.float()
     if G.size(-2) > G.size(-1):
         X = X.mT

@ -35,6 +38,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
         X = X.mT
     return X

+# Compile the function for CUDA; use eager mode for MPS/CPU
+# We check at runtime in the optimizer which version to use
+_zeropower_compiled = torch.compile(zeropower_via_newtonschulz5)
+
 class Muon(torch.optim.Optimizer):
     """
     Muon - MomentUm Orthogonalized by Newton-schulz
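One detail worth noting about the hunk above: torch.compile only wraps the function at definition time; tracing and compilation happen on the first call. That is why _zeropower_compiled can be created unconditionally at import time and simply never invoked on MPS/CPU. A minimal standalone sketch of the same dispatch pattern (not nanochat code):

import torch

def orthogonalize_ish(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for zeropower_via_newtonschulz5: any tensor-to-tensor function works here.
    return x / (x.norm() + 1e-7)

# Wrapping is cheap and lazy; nothing is compiled until the first call.
_compiled = torch.compile(orthogonalize_ish)

def pick_fn(p: torch.Tensor):
    # Same rule as Muon/DistMuon: compiled kernel on CUDA, plain eager elsewhere.
    return _compiled if p.device.type == "cuda" else orthogonalize_ish

p = torch.randn(8, 8)    # CPU (or MPS) tensor
g = pick_fn(p)(p)        # takes the eager path, no compilation attempted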
@ -65,6 +72,11 @@ class Muon(torch.optim.Optimizer):
             group = dict(params=[p for p in params if p.numel() == size])
             param_groups.append(group)
         super().__init__(param_groups, defaults)
+        # Determine which zeropower function to use based on device
+        # Check the first parameter to determine device
+        first_param = next(iter(params))
+        self._use_compiled = first_param.device.type == "cuda"
+        self._zeropower_fn = _zeropower_compiled if self._use_compiled else zeropower_via_newtonschulz5

     @torch.no_grad()
     def step(self):

@ -79,7 +91,7 @@ class Muon(torch.optim.Optimizer):
                 buf: Tensor = state["momentum_buffer"]
                 buf.lerp_(g, 1 - group["momentum"])
                 g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
-                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+                g = self._zeropower_fn(g, steps=group["ns_steps"])
                 p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)

@ -122,6 +134,10 @@ class DistMuon(torch.optim.Optimizer):
             print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
             param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
         super().__init__(param_groups, defaults)
+        # Determine which zeropower function to use based on device
+        first_param = params[0]
+        self._use_compiled = first_param.device.type == "cuda"
+        self._zeropower_fn = _zeropower_compiled if self._use_compiled else zeropower_via_newtonschulz5

     @torch.no_grad()
     def step(self):

@ -173,7 +189,7 @@ class DistMuon(torch.optim.Optimizer):
                 buf: Tensor = state["momentum_buffer"]
                 buf.lerp_(g, 1.0 - group["momentum"])
                 g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
-                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+                g = self._zeropower_fn(g, steps=group["ns_steps"])
                 scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                 p.add_(g, alpha=-group["lr"] * scale)
                 # Replicate updated parameters to all ranks
pyproject.toml
@ -14,7 +14,7 @@ dependencies = [
     "setuptools>=80.9.0",
     "tiktoken>=0.11.0",
     "tokenizers>=0.22.0",
-    "torch>=2.8.0",
+    "torch>=2.0.0,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",
 ]

@ -44,19 +44,5 @@ python_files = ["test_*.py"]
 python_classes = ["Test*"]
 python_functions = ["test_*"]

-# target torch to cuda 12.8
-[tool.uv.sources]
-torch = [
-    { index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
-    { index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
-]
-
-[[tool.uv.index]]
-name = "pytorch-cpu"
-url = "https://download.pytorch.org/whl/cpu"
-explicit = true
-
-[[tool.uv.index]]
-name = "pytorch-cu128"
-url = "https://download.pytorch.org/whl/cu128"
-explicit = true
+# Note: PyTorch configuration removed for macOS compatibility
+# On Linux with CUDA, you may need to reinstall from pytorch.org/whl/cu128
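Since the uv index pins are gone, which PyTorch build gets installed now depends on the platform's default wheel. A quick way to check what you actually got (a hedged sketch; exact version strings will vary):

import torch

print(torch.__version__)                   # e.g. a plain version (CPU/MPS wheel) or one with a "+cu128" suffix (CUDA wheel)
print(torch.version.cuda)                  # CUDA toolkit the wheel was built against, or None
print(torch.backends.mps.is_available())   # True on Apple Silicon Macs with MPS
print(torch.cuda.is_available())           # True only with a working CUDA build and GPU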
scripts/base_train.py
@ -109,7 +109,9 @@ with torch.device("meta"):
 model.to_empty(device=device)
 model.init_weights()
 orig_model = model # original, uncompiled model, for saving raw model state_dict
-model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
+# torch.compile only works well on CUDA; skip on MPS/CPU
+if device_type == "cuda":
+    model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
 num_params = sum(p.numel() for p in model.parameters())
 print0(f"Number of parameters: {num_params:,}")
 num_flops_per_token = model.estimate_flops()
scripts/mid_train.py
@ -15,6 +15,9 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import time
 import wandb
 import torch
+# Disable torch.compile's dynamo on MPS (not supported)
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
 from contextlib import nullcontext
 from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_token_bytes
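For context on the two added lines: with suppress_errors=True, TorchDynamo catches backend compilation failures and falls back to eager execution instead of raising, so any torch.compile-wrapped code that remains degrades gracefully on MPS. A standalone sketch (not nanochat code):

import torch
import torch._dynamo

# Fall back to eager execution when a graph cannot be compiled (e.g. unsupported backend ops on MPS).
torch._dynamo.config.suppress_errors = True

@torch.compile
def gelu_like(x):
    return 0.5 * x * (1.0 + torch.tanh(x))

x = torch.randn(4)
print(gelu_like(x))  # runs compiled where possible, otherwise silently falls back to eager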
@ -70,7 +73,9 @@ pretrain_batch_size = meta.get("device_batch_size", None)
 if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
     print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
 orig_model = model
-model = torch.compile(model, dynamic=False)
+# torch.compile only works well on CUDA; skip on MPS/CPU
+if device_type == "cuda":
+    model = torch.compile(model, dynamic=False)
 depth = model.config.n_layer
 num_flops_per_token = model.estimate_flops()
 tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank

@ -119,7 +124,9 @@ def mid_data_generator(split):
     assert dataset_size > 0
     needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
     token_buffer = deque()
-    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
+    # pin_memory only works on CUDA
+    use_pin_memory = device_type == "cuda"
+    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=use_pin_memory)
     cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
     it = 0 # iteration counter
     while True:

@ -142,8 +149,9 @@
             scratch[i] = token_buffer.popleft()
         inputs_cpu = scratch[:-1].to(dtype=torch.int32)
         targets_cpu = scratch[1:]
-        inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
-        targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
+        # non_blocking only works on CUDA
+        inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=use_pin_memory)
+        targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=use_pin_memory)
         if split == "train":
             if num_iterations > 0:
                 approx_progress = it / num_iterations # calculate progress from the max number of iterations