nanochat/scripts/chat_web.py
jasonacox 9a9b12b1be Address Review
- Refactor ChatRequest model to use Pydantic Field for default values
- Update documentation for system prompt exclusion.
- Remove duplicate await in exception
2025-12-11 00:06:49 -08:00

648 lines
24 KiB
Python
Raw Blame History

#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.
Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
a full copy of the model, and incoming requests are distributed to available workers.
Launch examples:
- single available GPU (default)
python -m scripts.chat_web
- 4 GPUs
python -m scripts.chat_web --num-gpus 4
To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
Endpoints:
GET / - Chat UI
POST /chat/completions - Chat API (streaming only, custom format)
POST /v1/chat/completions - OpenAI-compatible chat completions (streaming and non-streaming)
GET /v1/models - List available models (OpenAI-compatible)
GET /health - Health check with worker pool status
GET /stats - Worker pool statistics and GPU utilization
OpenAI API Compatibility:
- Supports both streaming and non-streaming responses
- Compatible with OpenAI Python SDK and other OpenAI-compatible clients
- Implements standard OpenAI request/response format
Example usage with OpenAI SDK:
```python
from openai import OpenAI
client = OpenAI(
api_key="not-needed",
base_url="http://localhost:8000/v1"
)
response = client.chat.completions.create(
model="nanochat",
messages=[
{"role": "user", "content": "Hello!"}
],
temperature=0.8,
max_tokens=512,
stream=True
)
for chunk in response:
if chunk.choices[0].delta.content:
print(chunk.choices[0].delta.content, end="", flush=True)
```
Abuse Prevention:
- Maximum 500 messages per request
- Maximum 8000 characters per message
- Maximum 32000 characters total conversation length
- Temperature clamped to 0.0-2.0
- Top-k clamped to 1-200
- Max tokens clamped to 1-4096
"""
import argparse
import json
import os
import torch
import asyncio
import logging
import random
import time
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel, Field
from typing import List, Optional, AsyncGenerator, Literal
from dataclasses import dataclass
from contextlib import nullcontext
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
# Abuse prevention limits
MAX_MESSAGES_PER_REQUEST = 500
MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
MIN_TEMPERATURE = 0.0
MAX_TEMPERATURE = 2.0
MIN_TOP_K = 1
MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
# Configure logging for conversation traffic
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
@dataclass
class Worker:
"""A worker with a model loaded on a specific GPU."""
gpu_id: int
device: torch.device
engine: Engine
tokenizer: object
autocast_ctx: torch.amp.autocast
class WorkerPool:
"""Pool of workers, each with a model replica on a different GPU."""
def __init__(self, num_gpus: Optional[int] = None):
if num_gpus is None:
if device_type == "cuda":
num_gpus = torch.cuda.device_count()
else:
num_gpus = 1 # e.g. cpu|mps
self.num_gpus = num_gpus
self.workers: List[Worker] = []
self.available_workers: asyncio.Queue = asyncio.Queue()
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
"""Load model on each GPU."""
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
if self.num_gpus > 1:
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
for gpu_id in range(self.num_gpus):
if device_type == "cuda":
device = torch.device(f"cuda:{gpu_id}")
print(f"Loading model on GPU {gpu_id}...")
else:
device = torch.device(device_type) # e.g. cpu|mps
print(f"Loading model on {device_type}...")
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
worker = Worker(
gpu_id=gpu_id,
device=device,
engine=engine,
tokenizer=tokenizer,
autocast_ctx=autocast_ctx
)
self.workers.append(worker)
await self.available_workers.put(worker)
print(f"All {self.num_gpus} workers initialized!")
async def acquire_worker(self) -> Worker:
"""Get an available worker from the pool."""
return await self.available_workers.get()
async def release_worker(self, worker: Worker):
"""Return a worker to the pool."""
await self.available_workers.put(worker)
class ChatMessage(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
messages: List[ChatMessage]
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_k: Optional[int] = None
model: str = Field(default="nanochat") # For OpenAI compatibility
stream: Optional[bool] = False # For OpenAI compatibility
top_p: Optional[float] = None # Ignored, for compatibility
# OpenAI API Models
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: str
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: dict
finish_reason: Optional[str] = None
class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
choices: List[ChatCompletionResponseChoice]
usage: Usage
class ChatCompletionStreamResponse(BaseModel):
id: str
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
choices: List[ChatCompletionResponseStreamChoice]
class Model(BaseModel):
id: str
object: Literal["model"] = "model"
created: int
owned_by: str
class ModelList(BaseModel):
object: Literal["list"] = "list"
data: List[Model]
def validate_chat_request(request: ChatRequest):
"""Validate chat request to prevent abuse."""
# Check number of messages
if len(request.messages) == 0:
raise HTTPException(status_code=400, detail="At least one message is required")
if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
raise HTTPException(
status_code=400,
detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
)
# Check individual message lengths and total conversation length
total_length = 0
for i, message in enumerate(request.messages):
if not message.content:
raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
msg_length = len(message.content)
if msg_length > MAX_MESSAGE_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
)
total_length += msg_length
if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
)
# Validate role values
valid_roles = ["user", "assistant", "system"]
for i, message in enumerate(request.messages):
if message.role not in valid_roles:
raise HTTPException(
status_code=400,
detail=f"Message {i} has invalid role '{message.role}'. Must be one of: {', '.join(valid_roles)}"
)
# Validate temperature
if request.temperature is not None:
if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
raise HTTPException(
status_code=400,
detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
)
# Validate top_k
if request.top_k is not None:
if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
raise HTTPException(
status_code=400,
detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
)
# Validate max_tokens
if request.max_tokens is not None:
if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
raise HTTPException(
status_code=400,
detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load models on all GPUs on startup."""
print("Loading nanochat models across GPUs...")
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
print(f"Server ready at http://localhost:{args.port}")
yield
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
"""Serve the chat UI."""
ui_html_path = os.path.join("nanochat", "ui.html")
with open(ui_html_path, "r", encoding="utf-8") as f:
html_content = f.read()
# Replace the API_URL to use the same origin
html_content = html_content.replace(
"const API_URL = `http://${window.location.hostname}:8000`;",
"const API_URL = '';"
)
return HTMLResponse(content=html_content)
@app.get("/logo.svg")
async def logo():
"""Serve the NanoChat logo for favicon and header."""
logo_path = os.path.join("nanochat", "logo.svg")
return FileResponse(logo_path, media_type="image/svg+xml")
async def generate_stream(
worker: Worker,
tokens,
temperature=None,
max_new_tokens=None,
top_k=None
) -> AsyncGenerator[tuple[str, int], None]:
"""Generate assistant response with streaming. Returns (text, token_count) tuples."""
temperature = temperature if temperature is not None else args.temperature
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
top_k = top_k if top_k is not None else args.top_k
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
bos = worker.tokenizer.get_bos_token_id()
# Accumulate tokens to properly handle multi-byte UTF-8 characters (like emojis)
accumulated_tokens = []
# Track the last complete UTF-8 string (without replacement characters)
last_clean_text = ""
completion_tokens = 0
with worker.autocast_ctx:
for token_column, token_masks in worker.engine.generate(
tokens,
num_samples=1,
max_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
seed=random.randint(0, 2**31 - 1)
):
token = token_column[0]
# Stopping criteria
if token == assistant_end or token == bos:
break
# Append the token to sequence
accumulated_tokens.append(token)
completion_tokens += 1
# Decode all accumulated tokens to get proper UTF-8 handling
# Note that decode is a quite efficient operation, basically table lookup and string concat
current_text = worker.tokenizer.decode(accumulated_tokens)
# Only emit text if it doesn't end with a replacement character
# This ensures we don't emit incomplete UTF-8 sequences
if not current_text.endswith('<EFBFBD>'):
# Extract only the new text since last clean decode
new_text = current_text[len(last_clean_text):]
if new_text: # Only yield if there's new content
yield (new_text, completion_tokens)
last_clean_text = current_text
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
# Basic validation to prevent abuse
validate_chat_request(request)
# Log incoming conversation to console
logger.info("="*20)
for i, message in enumerate(request.messages):
logger.info(f"[{message.role.upper()}]: {message.content}")
logger.info("-"*20)
# Acquire a worker from the pool (will wait if all are busy)
worker_pool = app.state.worker_pool
worker = await worker_pool.acquire_worker()
try:
# Build conversation tokens
bos = worker.tokenizer.get_bos_token_id()
user_start = worker.tokenizer.encode_special("<|user_start|>")
user_end = worker.tokenizer.encode_special("<|user_end|>")
assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
conversation_tokens = [bos]
for message in request.messages:
if message.role == "user":
conversation_tokens.append(user_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(user_end)
elif message.role == "assistant":
conversation_tokens.append(assistant_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(assistant_end)
conversation_tokens.append(assistant_start)
# Streaming response with worker release after completion
response_tokens = []
async def stream_and_release():
try:
async for text_chunk, token_count in generate_stream(
worker,
conversation_tokens,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
):
# Accumulate response for logging
response_tokens.append(text_chunk)
# Format as custom JSON for web UI
yield f"data: {json.dumps({'token': text_chunk, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
# Send done message
yield f"data: {json.dumps({'done': True})}\n\n"
finally:
# Log the assistant response to console
full_response = "".join(response_tokens)
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
logger.info("="*20)
# Release worker back to pool after streaming is done
await worker_pool.release_worker(worker)
return StreamingResponse(
stream_and_release(),
media_type="text/event-stream"
)
except Exception as e:
# Make sure to release worker even on error
await worker_pool.release_worker(worker)
raise e
@app.post("/v1/chat/completions")
async def openai_chat_completions(request: ChatRequest):
"""OpenAI-compatible chat completion endpoint (supports streaming and non-streaming).
Note: System role messages are accepted but ignored as nanochat doesn't support system prompts.
"""
validate_chat_request(request)
# Log incoming request
logger.info(f"OpenAI API Request: model={request.model}, stream={request.stream}, messages={len(request.messages)}")
worker_pool = app.state.worker_pool
worker = await worker_pool.acquire_worker()
request_id = f"chatcmpl-{random.randint(1000000, 9999999)}"
created = int(time.time())
try:
# Build conversation tokens
bos = worker.tokenizer.get_bos_token_id()
user_start = worker.tokenizer.encode_special("<|user_start|>")
user_end = worker.tokenizer.encode_special("<|user_end|>")
assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
conversation_tokens = [bos]
for message in request.messages:
if message.role == "user":
conversation_tokens.append(user_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(user_end)
elif message.role == "assistant":
conversation_tokens.append(assistant_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(assistant_end)
elif message.role == "system":
# Note: system messages are ignored as nanochat doesn't have system role support
logger.warning(f"System role message ignored (not supported by this model): {message.content[:100]}...")
conversation_tokens.append(assistant_start)
prompt_tokens = len(conversation_tokens)
if request.stream:
# Streaming response
async def stream_response():
try:
completion_text = ""
completion_tokens = 0
async for text_chunk, token_count in generate_stream(
worker,
conversation_tokens,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
):
completion_text += text_chunk
completion_tokens = token_count
chunk = ChatCompletionStreamResponse(
id=request_id,
created=created,
model=request.model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta={"content": text_chunk},
finish_reason=None
)
]
)
yield f"data: {json.dumps(chunk.model_dump(), ensure_ascii=False)}\n\n"
# Final chunk with finish_reason
final_chunk = ChatCompletionStreamResponse(
id=request_id,
created=created,
model=request.model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta={},
finish_reason="stop"
)
]
)
yield f"data: {json.dumps(final_chunk.model_dump(), ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
logger.info(f"OpenAI API Response (GPU {worker.gpu_id}): {completion_tokens} tokens")
finally:
await worker_pool.release_worker(worker)
return StreamingResponse(
stream_response(),
media_type="text/event-stream"
)
else:
# Non-streaming response
try:
completion_text = ""
completion_tokens = 0
async for text_chunk, token_count in generate_stream(
worker,
conversation_tokens,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
):
completion_text += text_chunk
completion_tokens = token_count
response = ChatCompletionResponse(
id=request_id,
created=created,
model=request.model,
choices=[
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=completion_text),
finish_reason="stop"
)
],
usage=Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
)
logger.info(f"OpenAI API Response (GPU {worker.gpu_id}): {completion_tokens} tokens")
return response
finally:
await worker_pool.release_worker(worker)
except Exception as e:
logger.error(f"Error: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/models")
async def list_models() -> ModelList:
"""List available models (OpenAI-compatible endpoint)."""
return ModelList(
data=[
Model(
id="nanochat",
created=int(time.time()),
owned_by="nanochat"
)
]
)
@app.get("/health")
async def health():
"""Health check endpoint."""
worker_pool = getattr(app.state, 'worker_pool', None)
return {
"status": "ok",
"ready": worker_pool is not None and len(worker_pool.workers) > 0,
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
}
@app.get("/stats")
async def stats():
"""Get worker pool statistics."""
worker_pool = app.state.worker_pool
return {
"total_workers": len(worker_pool.workers),
"available_workers": worker_pool.available_workers.qsize(),
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
"workers": [
{
"gpu_id": w.gpu_id,
"device": str(w.device)
} for w in worker_pool.workers
]
}
if __name__ == "__main__":
import uvicorn
print("Starting NanoChat Web Server")
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
uvicorn.run(app, host=args.host, port=args.port)