Add cross-platform inference (CUDA-trained checkpoints -> MPS/CPU inference)

This commit is contained in:
Shivam Ojha 2025-10-18 18:36:57 -05:00 committed by GitHub
parent 423bf31dbe
commit 9767451f7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 8 deletions

View File

@ -42,12 +42,14 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data)
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
# Load the model state
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
model_data = torch.load(model_path, map_location=device)
# Load to CPU first to avoid CUDA device validation issues when loading on MPS/CPU
map_location = 'cpu' if str(device).startswith('mps') else device
model_data = torch.load(model_path, map_location=map_location)
# Load the optimizer state if requested
optimizer_data = None
if load_optimizer:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
optimizer_data = torch.load(optimizer_path, map_location=device)
optimizer_data = torch.load(optimizer_path, map_location=map_location)
# Load the metadata
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "r") as f:
@ -67,6 +69,9 @@ def build_model(checkpoint_dir, step, device, phase):
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
# Move model_data tensors to target device if needed (for MPS/CPU loading CUDA checkpoints)
if any(v.device.type != device.type for v in model_data.values() if isinstance(v, torch.Tensor)):
model_data = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in model_data.items()}
model_config_kwargs = meta_data["model_config"]
log0(f"Building model with config: {model_config_kwargs}")
model_config = GPTConfig(**model_config_kwargs)

View File

@ -37,7 +37,7 @@ import torch
import asyncio
import logging
import random
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, nullcontext
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
@ -95,7 +95,8 @@ class Worker:
class WorkerPool:
"""Pool of workers, each with a model replica on a different GPU."""
def __init__(self, num_gpus: Optional[int] = None):
def __init__(self, num_gpus: Optional[int] = None, device_type: str = "cuda"):
self.device_type = device_type
self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
self.workers: List[Worker] = []
self.available_workers: asyncio.Queue = asyncio.Queue()
@ -105,13 +106,18 @@ class WorkerPool:
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
for gpu_id in range(self.num_gpus):
device = torch.device(f"cuda:{gpu_id}")
if self.device_type == "cuda":
device = torch.device(f"cuda:{gpu_id}")
else:
device = torch.device(self.device_type) # mps or cpu
print(f"Loading model on GPU {gpu_id}...")
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer)
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
if self.device_type == "cuda":
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
else:
autocast_ctx = nullcontext() # default precision for MPS/CPU
worker = Worker(
gpu_id=gpu_id,
device=device,
@ -209,7 +215,7 @@ def validate_chat_request(request: ChatRequest):
async def lifespan(app: FastAPI):
"""Load models on all GPUs on startup."""
print("Loading nanochat models across GPUs...")
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus, device_type=device_type)
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
print(f"Server ready at http://localhost:{args.port}")
yield