mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-06 07:35:32 +00:00
add cross-platform inference (cuda training -> mps inference)
This commit is contained in:
parent
423bf31dbe
commit
9767451f7e
|
|
@@ -42,12 +42,14 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data)
|
|||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
|
||||
# Load the model state
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
model_data = torch.load(model_path, map_location=device)
|
||||
# Load to CPU first to avoid CUDA device validation issues when loading on MPS/CPU
|
||||
map_location = 'cpu' if str(device).startswith('mps') else device
|
||||
model_data = torch.load(model_path, map_location=map_location)
|
||||
# Load the optimizer state if requested
|
||||
optimizer_data = None
|
||||
if load_optimizer:
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
optimizer_data = torch.load(optimizer_path, map_location=map_location)
|
||||
# Load the metadata
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
with open(meta_path, "r") as f:
|
||||
|
|
@@ -67,6 +69,9 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
|
||||
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
|
||||
model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
|
||||
# Move model_data tensors to target device if needed (for MPS/CPU loading CUDA checkpoints)
|
||||
if any(v.device.type != device.type for v in model_data.values() if isinstance(v, torch.Tensor)):
|
||||
model_data = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in model_data.items()}
|
||||
model_config_kwargs = meta_data["model_config"]
|
||||
log0(f"Building model with config: {model_config_kwargs}")
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
|
|
|
|||
|
|
@@ -37,7 +37,7 @@ import torch
|
|||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
from contextlib import asynccontextmanager
|
||||
from contextlib import asynccontextmanager, nullcontext
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||||
|
|
@@ -95,7 +95,8 @@ class Worker:
|
|||
class WorkerPool:
|
||||
"""Pool of workers, each with a model replica on a different GPU."""
|
||||
|
||||
def __init__(self, num_gpus: Optional[int] = None):
|
||||
def __init__(self, num_gpus: Optional[int] = None, device_type: str = "cuda"):
|
||||
self.device_type = device_type
|
||||
self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
|
||||
self.workers: List[Worker] = []
|
||||
self.available_workers: asyncio.Queue = asyncio.Queue()
|
||||
|
|
@@ -105,13 +106,18 @@ class WorkerPool:
|
|||
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
|
||||
|
||||
for gpu_id in range(self.num_gpus):
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
if self.device_type == "cuda":
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
else:
|
||||
device = torch.device(self.device_type) # mps or cpu
|
||||
print(f"Loading model on GPU {gpu_id}...")
|
||||
|
||||
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
engine = Engine(model, tokenizer)
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
|
||||
if self.device_type == "cuda":
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
else:
|
||||
autocast_ctx = nullcontext() # default precision for MPS/CPU
|
||||
worker = Worker(
|
||||
gpu_id=gpu_id,
|
||||
device=device,
|
||||
|
|
@@ -209,7 +215,7 @@ def validate_chat_request(request: ChatRequest):
|
|||
async def lifespan(app: FastAPI):
|
||||
"""Load models on all GPUs on startup."""
|
||||
print("Loading nanochat models across GPUs...")
|
||||
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
|
||||
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus, device_type=device_type)
|
||||
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
|
||||
print(f"Server ready at http://localhost:{args.port}")
|
||||
yield
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user