#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.
Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
a full copy of the model, and incoming requests are distributed to available workers.

Launch examples:

- single available GPU (default)
python -m scripts.chat_web

- 4 GPUs
python -m scripts.chat_web --num-gpus 4

To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)

Endpoints:
GET  /                  - Chat UI
POST /chat/completions  - Chat API (streaming only)
GET  /health            - Health check with worker pool status
GET  /stats             - Worker pool statistics and GPU utilization
"""

import argparse
import json
import os
import torch
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass

from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine

parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()

ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()

@dataclass
class Worker:
    """A worker with a model loaded on a specific GPU."""
    gpu_id: int
    device: torch.device
    engine: Engine
    tokenizer: object
    autocast_ctx: torch.amp.autocast

class WorkerPool:
    """Pool of workers, each with a model replica on a different GPU."""

    def __init__(self, num_gpus: Optional[int] = None):
        self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
        self.workers: List[Worker] = []
        self.available_workers: asyncio.Queue = asyncio.Queue()

    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
        """Load model on each GPU."""
        print(f"Initializing worker pool with {self.num_gpus} GPUs...")
        for gpu_id in range(self.num_gpus):
            device = torch.device(f"cuda:{gpu_id}")
            print(f"Loading model on GPU {gpu_id}...")
            model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
            engine = Engine(model, tokenizer)
            autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            worker = Worker(
                gpu_id=gpu_id,
                device=device,
                engine=engine,
                tokenizer=tokenizer,
                autocast_ctx=autocast_ctx
            )
            self.workers.append(worker)
            await self.available_workers.put(worker)
        print(f"All {self.num_gpus} workers initialized!")

    async def acquire_worker(self) -> Worker:
        """Get an available worker from the pool."""
        return await self.available_workers.get()

    async def release_worker(self, worker: Worker):
        """Return a worker to the pool."""
        await self.available_workers.put(worker)

class ChatMessage(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_k: Optional[int] = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on all GPUs on startup."""
    print("Loading nanochat models across GPUs...")
    app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
    await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
    print(f"Server ready at http://localhost:{args.port}")
    yield

app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def root():
    """Serve the chat UI."""
    ui_html_path = os.path.join("nanochat", "ui.html")
    with open(ui_html_path, "r") as f:
        html_content = f.read()
    # Replace the API_URL to use the same origin
    html_content = html_content.replace(
        "const API_URL = `http://${window.location.hostname}:8000`;",
        "const API_URL = '';"
    )
    return HTMLResponse(content=html_content)

@app.get("/logo.svg")
async def logo():
    """Serve the NanoChat logo for favicon and header."""
    logo_path = os.path.join("nanochat", "logo.svg")
    return FileResponse(logo_path, media_type="image/svg+xml")

async def generate_stream(
    worker: Worker,
    tokens,
    temperature=None,
    max_new_tokens=None,
    top_k=None
) -> AsyncGenerator[str, None]:
    """Generate assistant response with streaming."""
    temperature = temperature if temperature is not None else args.temperature
    max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
    top_k = top_k if top_k is not None else args.top_k

    assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
    bos = worker.tokenizer.get_bos_token_id()

    with worker.autocast_ctx:
        for token_column, token_masks in worker.engine.generate(
            tokens,
            num_samples=1,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k
        ):
            token = token_column[0]
            if token == assistant_end or token == bos:
                break
            token_text = worker.tokenizer.decode([token])
            yield f"data: {json.dumps({'token': token_text, 'gpu': worker.gpu_id})}\n\n"

    yield f"data: {json.dumps({'done': True})}\n\n"

@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
    worker_pool = app.state.worker_pool

    # Acquire a worker from the pool (will wait if all are busy)
    worker = await worker_pool.acquire_worker()

    try:
        # Build conversation tokens
        bos = worker.tokenizer.get_bos_token_id()
        user_start = worker.tokenizer.encode_special("<|user_start|>")
        user_end = worker.tokenizer.encode_special("<|user_end|>")
        assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
        assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")

        conversation_tokens = [bos]
        for message in request.messages:
            if message.role == "user":
                conversation_tokens.append(user_start)
                conversation_tokens.extend(worker.tokenizer.encode(message.content))
                conversation_tokens.append(user_end)
            elif message.role == "assistant":
                conversation_tokens.append(assistant_start)
                conversation_tokens.extend(worker.tokenizer.encode(message.content))
                conversation_tokens.append(assistant_end)
        conversation_tokens.append(assistant_start)

        # Streaming response with worker release after completion
        async def stream_and_release():
            try:
                async for chunk in generate_stream(
                    worker,
                    conversation_tokens,
                    temperature=request.temperature,
                    max_new_tokens=request.max_tokens,
                    top_k=request.top_k
                ):
                    yield chunk
            finally:
                # Release worker back to pool after streaming is done
                await worker_pool.release_worker(worker)

        return StreamingResponse(
            stream_and_release(),
            media_type="text/event-stream"
        )
    except Exception as e:
        # Make sure to release worker even on error
        await worker_pool.release_worker(worker)
        raise e

@app.get("/health")
async def health():
    """Health check endpoint."""
    worker_pool = getattr(app.state, 'worker_pool', None)
    return {
        "status": "ok",
        "ready": worker_pool is not None and len(worker_pool.workers) > 0,
        "num_gpus": worker_pool.num_gpus if worker_pool else 0,
        "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
    }

@app.get("/stats")
async def stats():
    """Get worker pool statistics."""
    worker_pool = app.state.worker_pool
    return {
        "total_workers": len(worker_pool.workers),
        "available_workers": worker_pool.available_workers.qsize(),
        "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
        "workers": [
            {
                "gpu_id": w.gpu_id,
                "device": str(w.device)
            }
            for w in worker_pool.workers
        ]
    }

if __name__ == "__main__":
    import uvicorn
    print("Starting NanoChat Web Server")
    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
    uvicorn.run(app, host=args.host, port=args.port)
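
# ----------------------------------------------------------------------------
# Example client (illustrative sketch, not part of the server, kept commented
# out so it never runs on import). The /chat/completions endpoint emits
# Server-Sent Events: each line is "data: {...}" JSON carrying either a
# "token" chunk (with the serving "gpu" id) or a final {"done": true} marker.
# The snippet assumes the third-party `requests` package is installed and the
# server is reachable at http://localhost:8000; adjust the URL and payload to
# your setup.
#
#   import json
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/chat/completions",
#       json={"messages": [{"role": "user", "content": "Hello!"}]},
#       stream=True,
#   )
#   for line in resp.iter_lines(decode_unicode=True):
#       if not line or not line.startswith("data: "):
#           continue  # skip SSE keep-alive blank lines
#       event = json.loads(line[len("data: "):])
#       if event.get("done"):
#           break
#       print(event["token"], end="", flush=True)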