allow multiple GPUs to do inference in a data parallel way

Andrej Karpathy 2025-10-15 19:12:19 +00:00
parent 190d9515d0
commit 01fb290f53
2 changed files with 145 additions and 73 deletions

View File

@@ -327,7 +327,6 @@
},
body: JSON.stringify({
messages: messages,
stream: true,
temperature: 0.8,
max_tokens: 512
}),

View File

@@ -1,26 +1,46 @@
#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.
Run with: python web_chat.py
Then open http://localhost:8000 in your browser.
Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
a full copy of the model, and incoming requests are distributed to available workers.
Launch examples:
- single available GPU (default)
python -m scripts.chat_web
- 4 GPUs
python -m scripts.chat_web --num-gpus 4
To chat, open the URL printed in the console. (If running on a cloud box, make sure to use its public IP.)
Endpoints:
GET / - Chat UI
POST /chat/completions - Chat API (streaming only)
GET /health - Health check with worker pool status
GET /stats - Worker pool statistics and GPU utilization
"""
import argparse
import json
import os
import torch
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
@@ -32,7 +52,55 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
args = parser.parse_args()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
@dataclass
class Worker:
"""A worker with a model loaded on a specific GPU."""
gpu_id: int
device: torch.device
engine: Engine
tokenizer: object
autocast_ctx: torch.amp.autocast
class WorkerPool:
"""Pool of workers, each with a model replica on a different GPU."""
def __init__(self, num_gpus: Optional[int] = None):
self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
self.workers: List[Worker] = []
self.available_workers: asyncio.Queue = asyncio.Queue()
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
"""Load model on each GPU."""
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
for gpu_id in range(self.num_gpus):
device = torch.device(f"cuda:{gpu_id}")
print(f"Loading model on GPU {gpu_id}...")
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer)
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
worker = Worker(
gpu_id=gpu_id,
device=device,
engine=engine,
tokenizer=tokenizer,
autocast_ctx=autocast_ctx
)
self.workers.append(worker)
await self.available_workers.put(worker)
print(f"All {self.num_gpus} workers initialized!")
async def acquire_worker(self) -> Worker:
"""Get an available worker from the pool."""
return await self.available_workers.get()
async def release_worker(self, worker: Worker):
"""Return a worker to the pool."""
await self.available_workers.put(worker)
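The scheduling here is plain queue backpressure: `acquire_worker()` awaits until some GPU is free, and `release_worker()` hands it back. A standalone toy sketch of that behaviour with dummy workers (no GPUs or models involved, all names invented for the illustration):

```python
# Standalone toy of the queue-based scheduling used by WorkerPool: with two
# dummy "workers" and four concurrent requests, requests 3 and 4 simply wait
# until a worker is released. asyncio.sleep stands in for engine.generate().
import asyncio

async def handle_request(req_id, pool):
    worker = await pool.get()        # acquire_worker(): waits while all are busy
    try:
        print(f"request {req_id} running on worker {worker}")
        await asyncio.sleep(0.5)     # pretend generation happens here
    finally:
        await pool.put(worker)       # release_worker(): always hand it back

async def main():
    pool = asyncio.Queue()
    for fake_gpu_id in range(2):     # two fake GPUs
        await pool.put(fake_gpu_id)
    await asyncio.gather(*(handle_request(i, pool) for i in range(4)))

asyncio.run(main())
```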
class ChatMessage(BaseModel):
role: str
@@ -43,14 +111,13 @@ class ChatRequest(BaseModel):
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_k: Optional[int] = None
stream: Optional[bool] = True
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load model on startup."""
print("Loading nanochat model...")
app.state.model, app.state.tokenizer, _ = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
app.state.engine = Engine(app.state.model, app.state.tokenizer)
"""Load models on all GPUs on startup."""
print("Loading nanochat models across GPUs...")
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
print(f"Server ready at http://localhost:{args.port}")
yield
@@ -85,8 +152,7 @@ async def logo():
return FileResponse(logo_path, media_type="image/svg+xml")
async def generate_stream(
engine,
tokenizer,
worker: Worker,
tokens,
temperature=None,
max_new_tokens=None,
@@ -97,11 +163,11 @@ async def generate_stream(
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
top_k = top_k if top_k is not None else args.top_k
assistant_end = tokenizer.encode_special("<|assistant_end|>")
bos = tokenizer.get_bos_token_id()
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
bos = worker.tokenizer.get_bos_token_id()
with autocast_ctx:
for token_column, token_masks in engine.generate(
with worker.autocast_ctx:
for token_column, token_masks in worker.engine.generate(
tokens,
num_samples=1,
max_tokens=max_new_tokens,
@@ -113,82 +179,89 @@ async def generate_stream(
if token == assistant_end or token == bos:
break
token_text = tokenizer.decode([token])
yield f"data: {json.dumps({'token': token_text})}\n\n"
token_text = worker.tokenizer.decode([token])
yield f"data: {json.dumps({'token': token_text, 'gpu': worker.gpu_id})}\n\n"
yield f"data: {json.dumps({'done': True})}\n\n"
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
"""Chat completion endpoint with streaming."""
engine = app.state.engine
tokenizer = app.state.tokenizer
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
worker_pool = app.state.worker_pool
# Build conversation tokens
bos = tokenizer.get_bos_token_id()
user_start = tokenizer.encode_special("<|user_start|>")
user_end = tokenizer.encode_special("<|user_end|>")
assistant_start = tokenizer.encode_special("<|assistant_start|>")
assistant_end = tokenizer.encode_special("<|assistant_end|>")
# Acquire a worker from the pool (will wait if all are busy)
worker = await worker_pool.acquire_worker()
conversation_tokens = [bos]
for message in request.messages:
if message.role == "user":
conversation_tokens.append(user_start)
conversation_tokens.extend(tokenizer.encode(message.content))
conversation_tokens.append(user_end)
elif message.role == "assistant":
conversation_tokens.append(assistant_start)
conversation_tokens.extend(tokenizer.encode(message.content))
conversation_tokens.append(assistant_end)
try:
# Build conversation tokens
bos = worker.tokenizer.get_bos_token_id()
user_start = worker.tokenizer.encode_special("<|user_start|>")
user_end = worker.tokenizer.encode_special("<|user_end|>")
assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
conversation_tokens.append(assistant_start)
conversation_tokens = [bos]
for message in request.messages:
if message.role == "user":
conversation_tokens.append(user_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(user_end)
elif message.role == "assistant":
conversation_tokens.append(assistant_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(assistant_end)
conversation_tokens.append(assistant_start)
# Streaming response with worker release after completion
async def stream_and_release():
try:
async for chunk in generate_stream(
worker,
conversation_tokens,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
):
yield chunk
finally:
# Release worker back to pool after streaming is done
await worker_pool.release_worker(worker)
if request.stream:
return StreamingResponse(
generate_stream(
engine,
tokenizer,
conversation_tokens,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
),
stream_and_release(),
media_type="text/event-stream"
)
else:
# Non-streaming response
temperature = request.temperature if request.temperature is not None else args.temperature
max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
top_k = request.top_k if request.top_k is not None else args.top_k
with autocast_ctx:
result_tokens, masks = engine.generate_batch(
conversation_tokens,
num_samples=1,
max_tokens=max_tokens,
temperature=temperature,
top_k=top_k
)[0]
response_tokens = result_tokens[len(conversation_tokens):]
response_text = tokenizer.decode(response_tokens)
return {
"choices": [{
"message": {
"role": "assistant",
"content": response_text
},
"finish_reason": "stop"
}]
}
except Exception as e:
# Make sure to release worker even on error
await worker_pool.release_worker(worker)
raise e
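For reference, the prompt the handler assembles before generation has the following shape (the two-turn conversation content is an invented example; the special-token names are the ones encoded above, and generation resumes after the final assistant_start):

```
<|bos|>
<|user_start|>Why is the sky blue?<|user_end|>
<|assistant_start|>Mostly Rayleigh scattering.<|assistant_end|>
<|user_start|>And at sunset?<|user_end|>
<|assistant_start|>            <-- generation continues from here
```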
@app.get("/health")
async def health():
"""Health check endpoint."""
worker_pool = getattr(app.state, 'worker_pool', None)
return {
"status": "ok",
"ready": hasattr(app.state, 'model') and app.state.model is not None
"ready": worker_pool is not None and len(worker_pool.workers) > 0,
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
}
@app.get("/stats")
async def stats():
"""Get worker pool statistics."""
worker_pool = app.state.worker_pool
return {
"total_workers": len(worker_pool.workers),
"available_workers": worker_pool.available_workers.qsize(),
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
"workers": [
{
"gpu_id": w.gpu_id,
"device": str(w.device)
} for w in worker_pool.workers
]
}
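A small, hypothetical polling loop against these two endpoints can be used to watch how requests spread across GPUs while load-testing; the host, port, and one-second interval are assumptions, while the field names match the responses defined above.

```python
# Hypothetical monitoring loop: poll /health and /stats once per second to
# watch worker availability. Host, port, and interval are assumptions.
import time
import requests

BASE = "http://localhost:8000"

while True:
    health = requests.get(f"{BASE}/health").json()
    stats = requests.get(f"{BASE}/stats").json()
    print(f"ready={health['ready']} "
          f"busy={stats['busy_workers']}/{stats['total_workers']} "
          f"available={stats['available_workers']}")
    time.sleep(1.0)
```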
if __name__ == "__main__":