From 01fb290f539743992c6c41e99c67f5e4ff79ba2e Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Wed, 15 Oct 2025 19:12:19 +0000
Subject: [PATCH] allow multiple GPUs to do inference in a data parallel way

---
 nanochat/ui.html    |   1 -
 scripts/chat_web.py | 217 +++++++++++++++++++++++++++++---------------
 2 files changed, 145 insertions(+), 73 deletions(-)

diff --git a/nanochat/ui.html b/nanochat/ui.html
index 39e608f..264a654 100644
--- a/nanochat/ui.html
+++ b/nanochat/ui.html
@@ -327,7 +327,6 @@
                 },
                 body: JSON.stringify({
                     messages: messages,
-                    stream: true,
                     temperature: 0.8,
                     max_tokens: 512
                 }),
diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 1a4cfe2..2643417 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -1,26 +1,46 @@
 #!/usr/bin/env python3
 """
 Unified web chat server - serves both UI and API from a single FastAPI instance.
-Run with: python web_chat.py
-Then open http://localhost:8000 in your browser.
+
+Uses data parallelism to serve requests across multiple GPUs. Each GPU loads a
+full copy of the model, and incoming requests are dispatched to whichever worker is free.
+
+Launch examples:
+
+- single available GPU (default)
+python -m scripts.chat_web
+
+- 4 GPUs
+python -m scripts.chat_web --num-gpus 4
+
+To chat, open the URL printed to the console. (If running on a cloud box, make sure to use its public IP.)
+
+Endpoints:
+    GET  /                 - Chat UI
+    POST /chat/completions - Chat API (streaming only)
+    GET  /health           - Health check with worker pool status
+    GET  /stats            - Worker pool statistics and GPU utilization
 """
 import argparse
 import json
 import os
 import torch
+import asyncio
 from contextlib import asynccontextmanager
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
 from typing import List, Optional, AsyncGenerator
+from dataclasses import dataclass
 
 from nanochat.common import compute_init
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine
 
 parser = argparse.ArgumentParser(description='NanoChat Web Server')
+parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
 parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
 parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
 parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
@@ -32,7 +52,55 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
 args = parser.parse_args()
 
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+
+@dataclass
+class Worker:
+    """A worker with a model loaded on a specific GPU."""
+    gpu_id: int
+    device: torch.device
+    engine: Engine
+    tokenizer: object
+    autocast_ctx: torch.amp.autocast
+
+class WorkerPool:
+    """Pool of workers, each with a model replica on a different GPU."""
+
+    def __init__(self, num_gpus: Optional[int] = None):
+        self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
+        self.workers: List[Worker] = []
+        self.available_workers: asyncio.Queue = asyncio.Queue()
+
+    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
+        """Load a model replica on each GPU."""
+        print(f"Initializing worker pool with {self.num_gpus} GPUs...")
+
+        for gpu_id in range(self.num_gpus):
+            device = torch.device(f"cuda:{gpu_id}")
+            print(f"Loading model on GPU {gpu_id}...")
+
+            model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
+            engine = Engine(model, tokenizer)
+            autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+
+            worker = Worker(
+                gpu_id=gpu_id,
+                device=device,
+                engine=engine,
+                tokenizer=tokenizer,
+                autocast_ctx=autocast_ctx
+            )
+            self.workers.append(worker)
+            await self.available_workers.put(worker)
+
+        print(f"All {self.num_gpus} workers initialized!")
+
+    async def acquire_worker(self) -> Worker:
+        """Get an available worker from the pool."""
+        return await self.available_workers.get()
+
+    async def release_worker(self, worker: Worker):
+        """Return a worker to the pool."""
+        await self.available_workers.put(worker)
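+
+# NOTE: the pool is a plain asyncio.Queue of Worker objects: acquire_worker()
+# suspends the request coroutine (without blocking the event loop) until a GPU
+# is free, so excess concurrent requests simply wait for the next free worker.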
 
 class ChatMessage(BaseModel):
     role: str
@@ -43,14 +111,13 @@ class ChatRequest(BaseModel):
     temperature: Optional[float] = None
     max_tokens: Optional[int] = None
     top_k: Optional[int] = None
-    stream: Optional[bool] = True
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    """Load model on startup."""
-    print("Loading nanochat model...")
-    app.state.model, app.state.tokenizer, _ = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
-    app.state.engine = Engine(app.state.model, app.state.tokenizer)
+    """Load models on all GPUs on startup."""
+    print("Loading nanochat models across GPUs...")
+    app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
+    await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
     print(f"Server ready at http://localhost:{args.port}")
     yield
 
@@ -85,8 +152,7 @@ async def logo():
     return FileResponse(logo_path, media_type="image/svg+xml")
 
 async def generate_stream(
-    engine,
-    tokenizer,
+    worker: Worker,
     tokens,
     temperature=None,
     max_new_tokens=None,
@@ -97,11 +163,11 @@ async def generate_stream(
     max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
     top_k = top_k if top_k is not None else args.top_k
 
-    assistant_end = tokenizer.encode_special("<|assistant_end|>")
-    bos = tokenizer.get_bos_token_id()
+    assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
+    bos = worker.tokenizer.get_bos_token_id()
 
-    with autocast_ctx:
-        for token_column, token_masks in engine.generate(
+    with worker.autocast_ctx:
+        for token_column, token_masks in worker.engine.generate(
             tokens,
             num_samples=1,
             max_tokens=max_new_tokens,
@@ -113,82 +179,89 @@
             if token == assistant_end or token == bos:
                 break
 
-            token_text = tokenizer.decode([token])
-            yield f"data: {json.dumps({'token': token_text})}\n\n"
+            token_text = worker.tokenizer.decode([token])
+            yield f"data: {json.dumps({'token': token_text, 'gpu': worker.gpu_id})}\n\n"
 
     yield f"data: {json.dumps({'done': True})}\n\n"
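+
+# Wire format of the stream above: one SSE "data: {...}" event per token,
+# carrying the decoded "token" text and the "gpu" id of the worker that
+# produced it, terminated by a single {"done": true} event.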
 
 @app.post("/chat/completions")
 async def chat_completions(request: ChatRequest):
-    """Chat completion endpoint with streaming."""
-    engine = app.state.engine
-    tokenizer = app.state.tokenizer
+    """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
+    worker_pool = app.state.worker_pool
 
-    # Build conversation tokens
-    bos = tokenizer.get_bos_token_id()
-    user_start = tokenizer.encode_special("<|user_start|>")
-    user_end = tokenizer.encode_special("<|user_end|>")
-    assistant_start = tokenizer.encode_special("<|assistant_start|>")
-    assistant_end = tokenizer.encode_special("<|assistant_end|>")
+    # Acquire a worker from the pool (will wait if all are busy)
+    worker = await worker_pool.acquire_worker()
 
-    conversation_tokens = [bos]
-    for message in request.messages:
-        if message.role == "user":
-            conversation_tokens.append(user_start)
-            conversation_tokens.extend(tokenizer.encode(message.content))
-            conversation_tokens.append(user_end)
-        elif message.role == "assistant":
-            conversation_tokens.append(assistant_start)
-            conversation_tokens.extend(tokenizer.encode(message.content))
-            conversation_tokens.append(assistant_end)
+    try:
+        # Build conversation tokens
+        bos = worker.tokenizer.get_bos_token_id()
+        user_start = worker.tokenizer.encode_special("<|user_start|>")
+        user_end = worker.tokenizer.encode_special("<|user_end|>")
+        assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
+        assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
 
-    conversation_tokens.append(assistant_start)
+        conversation_tokens = [bos]
+        for message in request.messages:
+            if message.role == "user":
+                conversation_tokens.append(user_start)
+                conversation_tokens.extend(worker.tokenizer.encode(message.content))
+                conversation_tokens.append(user_end)
+            elif message.role == "assistant":
+                conversation_tokens.append(assistant_start)
+                conversation_tokens.extend(worker.tokenizer.encode(message.content))
+                conversation_tokens.append(assistant_end)
+
+        conversation_tokens.append(assistant_start)
+
+        # Streaming response with worker release after completion
+        async def stream_and_release():
+            try:
+                async for chunk in generate_stream(
+                    worker,
+                    conversation_tokens,
+                    temperature=request.temperature,
+                    max_new_tokens=request.max_tokens,
+                    top_k=request.top_k
+                ):
+                    yield chunk
+            finally:
+                # Release worker back to pool after streaming is done
+                await worker_pool.release_worker(worker)
 
-    if request.stream:
         return StreamingResponse(
-            generate_stream(
-                engine,
-                tokenizer,
-                conversation_tokens,
-                temperature=request.temperature,
-                max_new_tokens=request.max_tokens,
-                top_k=request.top_k
-            ),
+            stream_and_release(),
             media_type="text/event-stream"
         )
-    else:
-        # Non-streaming response
-        temperature = request.temperature if request.temperature is not None else args.temperature
-        max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
-        top_k = request.top_k if request.top_k is not None else args.top_k
-
-        with autocast_ctx:
-            result_tokens, masks = engine.generate_batch(
-                conversation_tokens,
-                num_samples=1,
-                max_tokens=max_tokens,
-                temperature=temperature,
-                top_k=top_k
-            )[0]
-
-        response_tokens = result_tokens[len(conversation_tokens):]
-        response_text = tokenizer.decode(response_tokens)
-        return {
-            "choices": [{
-                "message": {
-                    "role": "assistant",
-                    "content": response_text
-                },
-                "finish_reason": "stop"
-            }]
-        }
+    except Exception:
+        # Release the worker if anything fails before streaming starts
+        await worker_pool.release_worker(worker)
+        raise
 
 @app.get("/health")
 async def health():
     """Health check endpoint."""
+    worker_pool = getattr(app.state, 'worker_pool', None)
     return {
         "status": "ok",
-        "ready": hasattr(app.state, 'model') and app.state.model is not None
+        "ready": worker_pool is not None and len(worker_pool.workers) > 0,
+        "num_gpus": worker_pool.num_gpus if worker_pool else 0,
+        "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
     }
+
+@app.get("/stats")
+async def stats():
+    """Get worker pool statistics."""
+    worker_pool = app.state.worker_pool
+    return {
+        "total_workers": len(worker_pool.workers),
+        "available_workers": worker_pool.available_workers.qsize(),
+        "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
+        "workers": [
+            {
+                "gpu_id": w.gpu_id,
+                "device": str(w.device)
+            } for w in worker_pool.workers
+        ]
+    }
 
 if __name__ == "__main__":
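
For reference, the streaming endpoint can be exercised with a minimal client along these lines (a sketch: it assumes the server is reachable at the default http://localhost:8000 from the original docstring and that the third-party requests library is installed). Each SSE event is a "data: {...}" line carrying "token" and "gpu" fields, with a final {"done": true} event:

    import json
    import requests  # third-party: pip install requests

    payload = {
        "messages": [{"role": "user", "content": "Why is the sky blue?"}],
        "temperature": 0.8,
        "max_tokens": 256,
    }
    with requests.post("http://localhost:8000/chat/completions", json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line:
                continue  # blank lines separate SSE events
            event = json.loads(line.decode().removeprefix("data: "))
            if event.get("done"):
                break
            print(event["token"], end="", flush=True)
    print()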
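
To see the data parallelism itself, one rough check (same assumptions as above, plus a server started with --num-gpus 2 or more; served_by is a hypothetical helper) is to fire a few requests concurrently and read back the gpu id that each stream reports:

    import json
    from concurrent.futures import ThreadPoolExecutor

    import requests  # third-party: pip install requests

    def served_by(prompt: str) -> int:
        """Send one chat request and return the gpu id of the worker that handled it."""
        payload = {"messages": [{"role": "user", "content": prompt}], "max_tokens": 16}
        with requests.post("http://localhost:8000/chat/completions", json=payload, stream=True) as resp:
            for line in resp.iter_lines():
                if line:
                    event = json.loads(line.decode().removeprefix("data: "))
                    if "gpu" in event:
                        return event["gpu"]  # the first token event identifies the worker
        return -1

    with ThreadPoolExecutor(max_workers=4) as pool:
        print(list(pool.map(served_by, [f"request {i}" for i in range(4)])))  # e.g. [0, 1, 0, 1]

Returning after the first token closes the connection early; the server should still reclaim the worker, since stream_and_release() releases it in a finally block when the response generator is shut down.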