diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f400d47..d0912d6 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -42,12 +42,14 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data) def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False): # Load the model state model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") - model_data = torch.load(model_path, map_location=device) + # Load to CPU first to avoid CUDA device validation issues when loading on MPS/CPU + map_location = 'cpu' if str(device).startswith('mps') else device + model_data = torch.load(model_path, map_location=map_location) # Load the optimizer state if requested optimizer_data = None if load_optimizer: optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") - optimizer_data = torch.load(optimizer_path, map_location=device) + optimizer_data = torch.load(optimizer_path, map_location=map_location) # Load the metadata meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") with open(meta_path, "r") as f: @@ -67,6 +69,9 @@ def build_model(checkpoint_dir, step, device, phase): model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) # Hack: fix torch compile issue, which prepends all keys with _orig_mod. 
model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()} + # Move model_data tensors to the target device (a CUDA checkpoint loaded on MPS/CPU arrives on CPU) + target_device = torch.device(device)  # normalize: callers may pass a str or a torch.device + model_data = {k: v.to(target_device) if isinstance(v, torch.Tensor) else v for k, v in model_data.items()} model_config_kwargs = meta_data["model_config"] log0(f"Building model with config: {model_config_kwargs}") model_config = GPTConfig(**model_config_kwargs) diff --git a/scripts/chat_web.py b/scripts/chat_web.py index 57afc9a..02dee09 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -37,7 +37,7 @@ import torch import asyncio import logging import random -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, nullcontext from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse @@ -95,7 +95,8 @@ class Worker: class WorkerPool: """Pool of workers, each with a model replica on a different GPU.""" - def __init__(self, num_gpus: Optional[int] = None): + def __init__(self, num_gpus: Optional[int] = None, device_type: str = "cuda"): + self.device_type = device_type self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count() self.workers: List[Worker] = [] self.available_workers: asyncio.Queue = asyncio.Queue() @@ -105,13 +106,18 @@ class WorkerPool: print(f"Initializing worker pool with {self.num_gpus} GPUs...") for gpu_id in range(self.num_gpus): - device = torch.device(f"cuda:{gpu_id}") + if self.device_type == "cuda": + device = torch.device(f"cuda:{gpu_id}") + else: + device = torch.device(self.device_type) # mps or cpu print(f"Loading model on GPU {gpu_id}...") model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step) engine = Engine(model, tokenizer) - autocast_ctx = 
torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) - + if self.device_type == "cuda": + autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) + else: + autocast_ctx = nullcontext() # default precision for MPS/CPU worker = Worker( gpu_id=gpu_id, device=device, @@ -209,7 +215,7 @@ def validate_chat_request(request: ChatRequest): async def lifespan(app: FastAPI): """Load models on all GPUs on startup.""" print("Loading nanochat models across GPUs...") - app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus) + app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus, device_type=device_type) await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step) print(f"Server ready at http://localhost:{args.port}") yield