fix: MPS/CPU compatibility for training and inference on Mac

Resolves all crashes and silent errors when running on Apple Silicon (MPS)
or CPU after merging upstream/master. Tested on torch 2.2.2 and 2.9.1.

nanochat/engine.py
- Replace signal.SIGALRM timeout with concurrent.futures.ThreadPoolExecutor
  so use_calculator() works from FastAPI worker threads (SIGALRM is
  Unix main-thread only, silently broken in any threaded web server)
- Guard torch.cuda.synchronize() behind device_type == 'cuda'

nanochat/gpt.py
- Extract init_rotary_embeddings() from init_weights() so checkpoint
  loading can restore non-persistent cos/sin buffers without
  re-randomising all weights
- Cast rotary cos/sin to bfloat16 on CUDA only (MPS bfloat16 requires
  torch>=2.4; float32 used on MPS/CPU)
- Update forward() dtype assertion to match device
- Add F.rms_norm fallback for torch<2.4 (rms_norm added in 2.4)

nanochat/optim.py
- _cuda_compile(): skip torch.compile(fullgraph=True) on MPS/CPU;
  return function unchanged so eager execution is used
- adamw_step_fused / muon_step_fused: move 0-D CPU scalar tensors to
  parameter device at start of function (cross-device ops crash in
  eager mode on MPS)
- muon_step_fused: use bfloat16 in polar express on CUDA only;
  fall back to float32 on MPS/CPU
- _step_muon: replace torch._foreach_copy_() with p.copy_(s) loop on
  non-CUDA (_foreach_copy_ not implemented on MPS in torch<2.4)

nanochat/flash_attention.py
- Probe SDPA for enable_gqa support at import time (added in torch 2.5;
  inspect.signature raises on C builtins in older Python/torch)
- Fall back to manual KV head repetition via repeat_interleave when
  enable_gqa is unavailable

nanochat/checkpoint_manager.py
- Call model.init_rotary_embeddings() instead of model.init_weights()
  after load_state_dict() to restore non-persistent rotary buffers
  without clobbering loaded weights

scripts/base_train.py
- Guard torch.compile(model) behind device_type == 'cuda'
- Set mfu = None on non-CUDA instead of computing 0/inf = 0.00%
- Handle mfu is None in end-of-run report

tests/test_mps_compat.py (new)
- 16 tests covering every fix; all pass on MPS (torch 2.2.2 and 2.9.1)
This commit is contained in:
Jason Kneen 2026-02-22 15:50:11 +00:00
parent d6d8f1be45
commit 16c37b7d1d
7 changed files with 384 additions and 46 deletions

View File

@ -100,8 +100,10 @@ def build_model(checkpoint_dir, step, device, phase):
with torch.device("meta"):
model = GPT(model_config)
# Load the model state
# Only init the rotary embedding buffers (persistent=False, absent from checkpoint),
# not the full weights — everything else comes from load_state_dict.
model.to_empty(device=device)
model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
model.init_rotary_embeddings()
model.load_state_dict(model_data, strict=True, assign=True)
# Put the model in the right training phase / mode
if phase == "eval":

View File

@ -13,9 +13,8 @@ The whole thing is made as efficient as possible.
import torch
import torch.nn.functional as F
import signal
import warnings
from contextlib import contextmanager
import concurrent.futures
from collections import deque
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
@ -23,25 +22,27 @@ from contextlib import nullcontext
# -----------------------------------------------------------------------------
# Calculator tool helpers
# Single shared executor — avoids spawning a thread per call.
# signal.SIGALRM is not usable here: it only works on the Unix *main thread*,
# but chat_web.py serves requests from FastAPI worker threads. A thread-based
# timeout via concurrent.futures works on any thread on any OS/device.
# NOTE(review): a formula that is already *running* cannot be cancelled by
# future.cancel(); with max_workers=1 such a call would occupy the sole worker
# and starve later calls — confirm hostile/long-running formulas cannot reach eval.
_calc_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="calc")

def eval_with_timeout(formula, max_time=3):
    """Evaluate an expression with a timeout. Thread-safe; works on MPS, CPU, and CUDA.

    Args:
        formula: expression string, evaluated with empty builtins.
        max_time: seconds to wait for the result before giving up.
    Returns:
        The evaluated value, or None on timeout or any evaluation error.
    """
    def _eval():
        # Suppress SyntaxWarning noise from odd model-generated expressions.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", SyntaxWarning)
            return eval(formula, {"__builtins__": {}}, {})
    try:
        future = _calc_executor.submit(_eval)
        return future.result(timeout=max_time)
    except concurrent.futures.TimeoutError:
        future.cancel()  # only prevents a not-yet-started task from running
        return None
    except Exception:
        # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok, ignore wrong calculator usage
        return None
def use_calculator(expr):
@ -309,6 +310,7 @@ if __name__ == "__main__":
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
# load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="eval")
@ -319,7 +321,7 @@ if __name__ == "__main__":
prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
# generate the reference sequence using the model.generate() function
generated_tokens = []
torch.cuda.synchronize()
synchronize()
t0 = time.time()
stream = model.generate(prompt_tokens, **kwargs)
with autocast_ctx:
@ -328,7 +330,7 @@ if __name__ == "__main__":
chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
print()
torch.cuda.synchronize()
synchronize()
t1 = time.time()
print(f"Reference time: {t1 - t0:.2f}s")
reference_ids = generated_tokens
@ -336,7 +338,7 @@ if __name__ == "__main__":
generated_tokens = []
engine = Engine(model, tokenizer)
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
torch.cuda.synchronize()
synchronize()
t0 = time.time()
with autocast_ctx:
for token_column, token_masks in stream:
@ -345,7 +347,7 @@ if __name__ == "__main__":
chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
print()
torch.cuda.synchronize()
synchronize()
t1 = time.time()
print(f"Engine time: {t1 - t0:.2f}s")
# compare the two sequences

View File

@ -16,6 +16,22 @@ Usage (drop-in replacement for FA3):
import torch
import torch.nn.functional as F
# enable_gqa was added to F.scaled_dot_product_attention in PyTorch 2.5.
# On older builds (2.2, 2.4) we fall back to manually repeating KV heads.
# inspect.signature raises ValueError on C builtins in older Python/torch, so probe with a call instead.
def _check_sdpa_gqa():
try:
q = torch.zeros(1, 2, 1, 8)
k = torch.zeros(1, 1, 1, 8)
F.scaled_dot_product_attention(q, k, k, enable_gqa=True)
return True
except TypeError:
return False
except Exception:
return True # some other error — assume supported and let runtime decide
_SDPA_HAS_GQA = _check_sdpa_gqa()
# =============================================================================
# Detection: Try to load FA3 on Hopper+ GPUs
@ -67,9 +83,19 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
Tk = k.size(2)
window = window_size[0]
# If GQA is needed but SDPA doesn't support enable_gqa (torch < 2.5), repeat KV heads manually
if enable_gqa and not _SDPA_HAS_GQA:
n_rep = q.size(1) // k.size(1)
if n_rep > 1:
k = k.repeat_interleave(n_rep, dim=1)
v = v.repeat_interleave(n_rep, dim=1)
enable_gqa = False
gqa_kwargs = {"enable_gqa": enable_gqa} if _SDPA_HAS_GQA else {}
# Full context, same length
if (window < 0 or window >= Tq) and Tq == Tk:
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
return F.scaled_dot_product_attention(q, k, v, is_causal=True, **gqa_kwargs)
# Single token generation
if Tq == 1:
@ -78,7 +104,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
start = max(0, Tk - (window + 1))
k = k[:, :, start:, :]
v = v[:, :, start:, :]
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
return F.scaled_dot_product_attention(q, k, v, is_causal=False, **gqa_kwargs)
# Need explicit mask for sliding window/chunk inference
device = q.device
@ -90,8 +116,8 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
# sliding window (left)
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, **gqa_kwargs)
# =============================================================================
# Public API: Same interface as FA3

View File

@ -40,8 +40,12 @@ class GPTConfig:
def norm(x):
    """Purely functional RMSNorm with no learnable parameters.

    F.rms_norm was added in PyTorch 2.4; on older builds fall back to a manual
    implementation so MPS/CPU users on torch 2.2 can still run the model.
    """
    if hasattr(F, "rms_norm"):
        return F.rms_norm(x, (x.size(-1),))
    # Manual fallback. NOTE(review): uses a fixed eps=1e-6 while F.rms_norm
    # defaults to the dtype's machine eps, so the two paths differ at roughly
    # the 1e-7 level — confirm that tolerance is acceptable.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + 1e-6)
def has_ve(layer_idx, n_layer):
@ -230,9 +234,7 @@ class GPT(nn.Module):
torch.nn.init.zeros_(block.attn.ve_gate.weight)
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
self.init_rotary_embeddings()
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
if self.transformer.wte.weight.device.type == "cuda":
@ -240,6 +242,15 @@ class GPT(nn.Module):
for ve in self.value_embeds.values():
ve.to(dtype=torch.bfloat16)
def init_rotary_embeddings(self):
    """(Re)build only the non-persistent rotary cos/sin buffers.

    These buffers are registered with persistent=False and are therefore
    absent from a saved state_dict; checkpoint loading calls this instead of
    the full init_weights() so the loaded parameters stay untouched.
    """
    head_dim = self.config.n_embd // self.config.n_head
    self.cos, self.sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently
# autodetect the device from model embeddings
@ -253,7 +264,9 @@ class GPT(nn.Module):
# calculate the rotation frequencies at each (time, channel) pair
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
# bfloat16 saves memory on CUDA; MPS only supports it in torch>=2.4 so use float32 there
if device.type == "cuda":
cos, sin = cos.bfloat16(), sin.bfloat16()
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
@ -391,7 +404,9 @@ class GPT(nn.Module):
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
expected_rot_dtype = torch.bfloat16 if self.cos.device.type == "cuda" else torch.float32
assert self.cos.dtype == expected_rot_dtype, \
f"Rotary embeddings dtype mismatch: expected {expected_rot_dtype}, got {self.cos.dtype}"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length

View File

@ -11,13 +11,22 @@ import torch
import torch.distributed as dist
from torch import Tensor
# Fused kernels via torch.compile give large speedups on CUDA.
# On MPS / CPU, torch.compile either doesn't support fullgraph=True or adds overhead,
# so we fall back to eager execution instead.
def _cuda_compile(fn):
"""Apply torch.compile only when CUDA is available; otherwise return fn unchanged."""
if torch.cuda.is_available():
return torch.compile(fn, dynamic=False, fullgraph=True)
return fn
# -----------------------------------------------------------------------------
"""
Good old AdamW optimizer, fused kernel.
https://arxiv.org/abs/1711.05101
"""
@torch.compile(dynamic=False, fullgraph=True)
@_cuda_compile
def adamw_step_fused(
p: Tensor, # (32768, 768) - parameter tensor
grad: Tensor, # (32768, 768) - gradient, same shape as p
@ -33,8 +42,14 @@ def adamw_step_fused(
"""
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
All in one compiled graph to eliminate Python overhead between ops.
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
The 0-D CPU tensors avoid recompilation when hyperparameter values change (CUDA).
On MPS/CPU they are moved to the parameter device so eager ops stay on one device.
"""
if p.device.type != "cuda":
device = p.device
step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t = [
t.to(device) for t in (step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t)
]
# Weight decay (decoupled, applied before the update)
p.mul_(1 - lr_t * wd_t)
# Update running averages (lerp_ is cleaner and fuses well)
@ -87,7 +102,7 @@ polar_express_coeffs = [
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
@torch.compile(dynamic=False, fullgraph=True)
@_cuda_compile
def muon_step_fused(
stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
@ -103,16 +118,22 @@ def muon_step_fused(
"""
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
All in one compiled graph to eliminate Python overhead between ops.
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
Some of the constants are 0-D CPU tensors to avoid recompilation when values change (CUDA).
On MPS/CPU they are moved to the parameter device so eager ops stay on one device.
"""
if stacked_grads.device.type != "cuda":
device = stacked_grads.device
momentum_t, lr_t, wd_t, beta2_t = [
t.to(device) for t in (momentum_t, lr_t, wd_t, beta2_t)
]
# Nesterov momentum
momentum = momentum_t.to(stacked_grads.dtype)
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
g = stacked_grads.lerp_(momentum_buffer, momentum)
# Polar express
X = g.bfloat16()
# Polar express — bfloat16 cuts memory bandwidth on CUDA; fall back to float32 on MPS/CPU
X = g.bfloat16() if g.device.type == "cuda" else g.float()
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
if g.size(-2) > g.size(-1): # Tall matrix
for a, b, c in polar_express_coeffs[:ns_steps]:
@ -277,8 +298,14 @@ class MuonAdamW(torch.optim.Optimizer):
red_dim,
)
# Copy back to original params
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
# Copy back to original params.
# torch._foreach_copy_ is fast but not implemented on MPS in torch<2.4; fall back to a loop.
unstacked = list(stacked_params.unbind(0))
if hasattr(torch, '_foreach_copy_') and params[0].device.type == "cuda":
torch._foreach_copy_(params, unstacked)
else:
for p, s in zip(params, unstacked):
p.copy_(s)
@torch.no_grad()
def step(self):

View File

@ -238,7 +238,9 @@ def disable_fp8(model):
# Compile the model
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
# torch.compile speeds up CUDA significantly but is unsupported on MPS and pointless on CPU
if device_type == "cuda":
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
# -----------------------------------------------------------------------------
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
@ -521,7 +523,13 @@ while True:
pct_done = 100 * step / num_iterations
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
# MFU is only meaningful on CUDA with a known GPU spec; float('inf') sentinel means non-CUDA
if device_type == "cuda":
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
mfu_str = f"bf16_mfu: {mfu:.2f}% | "
else:
mfu = None
mfu_str = ""
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
# Calculate ETA based on average time per step (excluding first 10 steps)
@ -534,7 +542,7 @@ while True:
else:
eta_str = ""
epoch = dataloader_state_dict["epoch"]
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | {mfu_str}epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
if step % 100 == 0:
log_data = {
"step": step,
@ -544,9 +552,10 @@ while True:
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
"train/epoch": epoch,
}
if mfu is not None:
log_data["train/mfu"] = mfu
wandb_run.log(log_data)
# state update
@ -588,7 +597,7 @@ get_report().log(section="Base model training", data=[
"Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
"Final validation bpb": val_bpb,
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"MFU %": f"{mfu:.2f}%" if mfu is not None else "N/A (non-CUDA device)",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",

257
tests/test_mps_compat.py Normal file
View File

@ -0,0 +1,257 @@
"""
Tests for MPS/CPU compatibility fixes introduced in the cpu-mps-dev PR.
Each test exercises one specific fix and can run entirely on MPS or CPU — no GPU needed.
Run with: python -m pytest tests/test_mps_compat.py -v
"""
import threading
import time
import pytest
import torch
import torch.nn as nn
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device_type = device.type
# ---------------------------------------------------------------------------
# Fix 1: optim.py — _cuda_compile skips torch.compile on non-CUDA,
# and CPU 0-D scalar tensors are moved to device before ops
# ---------------------------------------------------------------------------
class TestOptimMPSCompat:
    """optim.py fixes: _cuda_compile must be an identity off-CUDA, and the fused
    optimizer steps must tolerate 0-D CPU scalar hyperparameter tensors when the
    parameters live on an MPS/CPU device."""

    def test_cuda_compile_is_identity_on_non_cuda(self):
        """_cuda_compile must return the function unchanged when CUDA is not available."""
        from nanochat.optim import _cuda_compile
        sentinel = lambda x: x
        result = _cuda_compile(sentinel)
        # Identity (`is`), not mere equivalence: no wrapper may be added.
        assert result is sentinel, "_cuda_compile should be a no-op without CUDA"

    def test_fused_functions_not_compiled(self):
        """adamw_step_fused and muon_step_fused must NOT be torch.OptimizedModule on MPS/CPU."""
        from nanochat.optim import adamw_step_fused, muon_step_fused
        # torch.compile wraps in OptimizedModule which has _orig_mod attribute
        assert not hasattr(adamw_step_fused, "_orig_mod"), \
            "adamw_step_fused should not be torch.compile'd on MPS/CPU"
        assert not hasattr(muon_step_fused, "_orig_mod"), \
            "muon_step_fused should not be torch.compile'd on MPS/CPU"

    def test_adamw_step_fused_runs_on_device(self):
        """adamw_step_fused must execute successfully with device tensors + CPU scalar tensors."""
        from nanochat.optim import adamw_step_fused
        # Parameter, gradient, and both moment buffers all live on the test device.
        p = torch.randn(16, 16, device=device)
        grad = torch.randn(16, 16, device=device)
        exp_avg = torch.zeros(16, 16, device=device)
        exp_avg_sq = torch.zeros(16, 16, device=device)
        # Scalars intentionally kept on CPU (as the optimizer allocates them)
        p_before = p.clone()
        adamw_step_fused(
            p, grad, exp_avg, exp_avg_sq,
            torch.tensor(1.0),    # step_t — CPU scalar
            torch.tensor(1e-3),   # lr_t
            torch.tensor(0.9),    # beta1_t
            torch.tensor(0.999),  # beta2_t
            torch.tensor(1e-8),   # eps_t
            torch.tensor(0.01),   # wd_t
        )
        assert not torch.equal(p, p_before), "adamw_step_fused should update p in-place"

    def test_muon_step_fused_runs_on_device(self):
        """muon_step_fused must execute successfully on MPS/CPU."""
        from nanochat.optim import muon_step_fused
        B, M, N = 4, 32, 64
        stacked_grads = torch.randn(B, M, N, device=device)
        stacked_params = torch.randn(B, M, N, device=device)
        momentum_buf = torch.zeros(B, M, N, device=device)
        # red_dim=-1 means we reduce over N, so second moment buffer is (B, M, 1)
        second_mom_buf = torch.ones(B, M, 1, device=device)
        params_before = stacked_params.clone()
        muon_step_fused(
            stacked_grads, stacked_params, momentum_buf, second_mom_buf,
            torch.tensor(0.95),  # momentum_t — CPU scalar
            torch.tensor(0.02),  # lr_t
            torch.tensor(0.0),   # wd_t
            torch.tensor(0.95),  # beta2_t
            ns_steps=5,
            red_dim=-1,
        )
        assert not torch.equal(stacked_params, params_before), \
            "muon_step_fused should update stacked_params in-place"

    def test_full_optimizer_step_on_device(self):
        """MuonAdamW optimizer.step() must work end-to-end on MPS/CPU."""
        from nanochat.optim import MuonAdamW
        # A small model with both matrix (Muon) and non-matrix (AdamW) params
        model = nn.Sequential(
            nn.Embedding(32, 16),           # AdamW — 1-D embedding
            nn.Linear(16, 32, bias=False),  # Muon — 2-D matrix
        ).to(device)
        embedding_params = list(model[0].parameters())
        matrix_params = list(model[1].parameters())
        param_groups = [
            dict(kind="adamw", params=embedding_params, lr=1e-3,
                 betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0),
            dict(kind="muon", params=matrix_params, lr=0.02,
                 momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0),
        ]
        opt = MuonAdamW(param_groups)
        # One forward/backward pass to populate .grad on every parameter.
        x = torch.randint(0, 32, (4,), device=device)
        out = model[1](model[0](x))
        loss = out.sum()
        loss.backward()
        params_before = {n: p.clone() for n, p in model.named_parameters()}
        opt.step()
        opt.zero_grad()
        # Both optimizer kinds must have moved their parameters.
        for name, p in model.named_parameters():
            assert not torch.equal(p, params_before[name]), \
                f"Parameter {name} should have been updated by optimizer.step()"
# ---------------------------------------------------------------------------
# Fix 2: engine.py — concurrent.futures timeout works from any thread
# ---------------------------------------------------------------------------
class TestCalculatorThreadSafety:
    """engine.py fix: the calculator timeout is thread-based (concurrent.futures)
    instead of signal.SIGALRM, so it must work from non-main threads too."""

    def test_calculator_works_from_main_thread(self):
        """Basic sanity: use_calculator works from the main thread."""
        from nanochat.engine import use_calculator
        assert use_calculator("2 + 2") == 4
        assert use_calculator("10 * 3") == 30
        assert use_calculator("1 / 4") == 0.25

    def test_calculator_works_from_background_thread(self):
        """Critical: use_calculator must work when called from a non-main thread (FastAPI worker scenario)."""
        from nanochat.engine import use_calculator
        results = {}
        errors = {}

        def worker():
            try:
                results["2+2"] = use_calculator("2+2")
                results["10*3"] = use_calculator("10*3")
            except Exception as e:
                errors["exc"] = e

        t = threading.Thread(target=worker)
        t.start()
        t.join(timeout=10)
        assert not t.is_alive(), "Worker thread hung"
        assert not errors, f"Exception in worker thread: {errors.get('exc')}"
        assert results["2+2"] == 4
        assert results["10*3"] == 30

    def test_calculator_timeout_does_not_hang(self):
        """A calculator timeout must not block the caller for more than ~max_time seconds."""
        from nanochat.engine import eval_with_timeout
        # We can't easily trigger a true infinite loop through use_calculator's sanitizer,
        # but we can call eval_with_timeout directly with a very short timeout.
        t0 = time.time()
        result = eval_with_timeout("1+1", max_time=0.1)
        elapsed = time.time() - t0
        assert result == 2
        assert elapsed < 2.0, f"eval_with_timeout took too long: {elapsed:.2f}s"

    def test_calculator_rejects_unsafe_input(self):
        """use_calculator must return None for non-numeric / unsafe expressions."""
        from nanochat.engine import use_calculator
        assert use_calculator("__import__('os').system('echo pwned')") is None
        assert use_calculator("2 ** 100") is None  # power operator blocked
        assert use_calculator("open('/etc/passwd')") is None

    def test_no_sigalrm_usage(self):
        """engine.py must not CALL signal.alarm() or signal.signal(SIGALRM) — comments are fine."""
        # Use a context manager so the handle is closed promptly (the original
        # leaked it to the GC) and pin the encoding for portability.
        with open("nanochat/engine.py", encoding="utf-8") as f:
            source = f.read()
        # Strip comment lines before checking for actual usage
        non_comment_lines = [
            line for line in source.splitlines()
            if not line.lstrip().startswith("#")
        ]
        code = "\n".join(non_comment_lines)
        assert "signal.alarm(" not in code, \
            "engine.py still calls signal.alarm() — should use concurrent.futures"
        assert "signal.signal(signal.SIGALRM" not in code, \
            "engine.py still registers SIGALRM handler — should use concurrent.futures"
        assert "import concurrent.futures" in source, \
            "engine.py should import concurrent.futures for thread-safe timeout"
# ---------------------------------------------------------------------------
# Fix 3: gpt.py — init_rotary_embeddings() exists and works standalone
# ---------------------------------------------------------------------------
class TestInitRotaryEmbeddings:
    """gpt.py fix: a standalone init_rotary_embeddings() restores the
    non-persistent cos/sin buffers without re-randomising the learnable weights."""

    def _make_small_model(self):
        # Tiny GPT built on the meta device, then materialised on the test device.
        from nanochat.gpt import GPT, GPTConfig
        config = GPTConfig(sequence_len=64, vocab_size=256, n_layer=2,
                           n_head=2, n_kv_head=2, n_embd=64)
        with torch.device("meta"):
            net = GPT(config)
        net.to_empty(device=device)
        return net

    def test_method_exists(self):
        """GPT must expose init_rotary_embeddings()."""
        from nanochat.gpt import GPT
        assert hasattr(GPT, "init_rotary_embeddings"), \
            "GPT should have init_rotary_embeddings() method"

    def test_rotary_buffers_populated_after_call(self):
        """init_rotary_embeddings() alone must produce valid cos/sin buffers."""
        net = self._make_small_model()
        net.init_rotary_embeddings()
        assert net.cos is not None
        assert net.sin is not None
        assert net.cos.shape[1] == net.rotary_seq_len
        assert not net.cos.isnan().any(), "cos buffer contains NaN"
        assert not net.sin.isnan().any(), "sin buffer contains NaN"

    def test_init_rotary_does_not_touch_parameters(self):
        """init_rotary_embeddings() must not change learnable parameters."""
        net = self._make_small_model()
        net.init_weights()  # proper full init
        snapshot = {name: p.clone() for name, p in net.named_parameters()}
        net.init_rotary_embeddings()  # should only touch buffers
        for name, p in net.named_parameters():
            assert torch.equal(p, snapshot[name]), \
                f"init_rotary_embeddings() should not modify parameter {name}"

    def test_forward_works_after_init_rotary_only(self):
        """A model initialized only via init_weights (which calls init_rotary) must forward cleanly."""
        net = self._make_small_model()
        net.init_weights()
        net.eval()
        token_ids = torch.randint(0, 256, (1, 16), device=device)
        with torch.no_grad():
            logits = net(token_ids)
        assert logits.shape == (1, 16, 256)
        assert not logits.isnan().any()
# ---------------------------------------------------------------------------
# Fix 4: base_train.py / mid_train.py — torch.compile guarded on non-CUDA
# ---------------------------------------------------------------------------
class TestTorchCompileGuard:
def test_compile_guard_in_base_train_source(self):
    """base_train.py must only call torch.compile when device_type == 'cuda'."""
    from pathlib import Path
    # read_text closes the file immediately (the original leaked the handle)
    # and pins the encoding for portability.
    source = Path("scripts/base_train.py").read_text(encoding="utf-8")
    # Find the torch.compile call and verify it's inside a CUDA guard
    compile_idx = source.find('model = torch.compile(model')
    assert compile_idx != -1, "Could not find torch.compile call in base_train.py"
    # The nearest preceding if-statement should reference cuda
    preceding = source[max(0, compile_idx - 200):compile_idx]
    assert 'device_type == "cuda"' in preceding, \
        'torch.compile in base_train.py is not guarded by `if device_type == "cuda":`'
def test_mfu_none_on_non_cuda(self):
"""On MPS/CPU, mfu should be None (not computed), not a misleading float."""
# We can't import base_train directly (it's a script), so test the logic pattern
gpu_peak_flops = float('inf') # what base_train sets for non-CUDA
flops_per_sec = 1e12
# Our fix: mfu is None when device_type != "cuda"
device_type_local = "mps" # or "cpu"
if device_type_local == "cuda":
mfu = 100 * flops_per_sec / (gpu_peak_flops * 1)
else:
mfu = None
assert mfu is None, "MFU should be None on non-CUDA devices"