nanochat/scripts/chat_web.py
santhoshravindran7 ac5927e158 security: add rate limiting, CORS fix, stats key guard, log redaction, macOS memory limits
H-2 (High) — scripts/chat_web.py
  Fix CORS misconfiguration: remove allow_credentials=True (incompatible with
  wildcard origin) and restrict allow_methods/allow_headers to the minimum
  required set (methods GET and POST; headers Content-Type and X-Stats-Key).

M-5 (Medium) — scripts/chat_web.py
  Add sliding-window rate limiter on /chat/completions keyed by client IP.
  Implemented without additional dependencies using asyncio + defaultdict.
  Configurable via NANOCHAT_RATE_LIMIT and NANOCHAT_RATE_WINDOW env vars
  (defaults: 10 requests per 60 seconds).
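
  As a quick illustration (a client-side sketch, not part of this change; it assumes a
  server already running locally via `python -m scripts.chat_web` with the default
  limits), the limiter can be exercised like this:

    # Hypothetical smoke test: expect HTTP 200 for the first 10 requests, then 429.
    import requests

    URL = "http://localhost:8000/chat/completions"
    payload = {"messages": [{"role": "user", "content": "hi"}], "max_tokens": 8}

    for i in range(12):  # two past the default NANOCHAT_RATE_LIMIT of 10
        resp = requests.post(URL, json=payload, stream=True)
        print(i, resp.status_code)
        resp.close()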

M-1 (Medium) — scripts/chat_web.py
  Protect /health and /stats with an optional API key dependency.
  When NANOCHAT_STATS_KEY env var is set, both endpoints require the value
  in the X-Stats-Key header. Uses secrets.compare_digest to prevent timing
  attacks. No-op when env var is unset (backwards compatible).
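
  For example (a minimal client sketch, assuming NANOCHAT_STATS_KEY was exported before
  the server was started and the server listens on localhost:8000):

    # Hypothetical check of the stats-key guard.
    import os
    import requests

    base = "http://localhost:8000"
    key = os.environ["NANOCHAT_STATS_KEY"]

    print(requests.get(f"{base}/health").status_code)   # 403: header missing
    print(requests.get(f"{base}/stats", headers={"X-Stats-Key": key}).json())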

M-4 (Medium) — scripts/chat_web.py
  Redact full conversation content from server logs.
  User message bodies are no longer logged at INFO level; only message count
  and a 120-char preview at DEBUG level. Assistant response logs now record
  character count only, not content.

L-2 (Low) — nanochat/execution.py
  Enforce memory limits on macOS in the code execution sandbox.
  Previously the entire resource limit block was skipped on Darwin with a
  comment 'seem to fail'. RLIMIT_AS is indeed unsupported on macOS, but
  RLIMIT_DATA is supported. Linux now uses both RLIMIT_AS and RLIMIT_DATA; macOS uses
  RLIMIT_DATA. Both paths are guarded by a None check.
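
  A minimal sketch of the intended guard (illustrative only; the function and variable
  names here are placeholders, not the actual nanochat/execution.py code):

    import resource
    import sys

    def apply_memory_limit(memory_limit_bytes):
        """Cap memory in the sandboxed child process, per platform."""
        if memory_limit_bytes is None:  # guard: no limit configured
            return
        limit = (memory_limit_bytes, memory_limit_bytes)
        if sys.platform == "darwin":
            # RLIMIT_AS is not enforced on macOS, but RLIMIT_DATA is.
            resource.setrlimit(resource.RLIMIT_DATA, limit)
        else:
            # Linux: cap both the address space and the data segment.
            resource.setrlimit(resource.RLIMIT_AS, limit)
            resource.setrlimit(resource.RLIMIT_DATA, limit)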

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-08 23:19:32 -07:00


#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.
Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
a full copy of the model, and incoming requests are distributed to available workers.
Launch examples:

- single available GPU (default)
  python -m scripts.chat_web

- 4 GPUs
  python -m scripts.chat_web --num-gpus 4

To chat, open the URL printed in the console. (If on a cloud box, make sure to use the public IP.)

Endpoints:
GET  /                 - Chat UI
POST /chat/completions - Chat API (streaming only)
GET  /health           - Health check with worker pool status
GET  /stats            - Worker pool statistics and GPU utilization

Abuse Prevention:
- Maximum 500 messages per request
- Maximum 8000 characters per message
- Maximum 32000 characters total conversation length
- Temperature clamped to 0.0-2.0
- Top-k clamped to 0-200 (0 disables top-k filtering, using full vocabulary)
- Max tokens clamped to 1-4096
- Sliding-window rate limiting per client IP (default: 10 req/60s)

Security Environment Variables:
NANOCHAT_RATE_LIMIT   Max requests per window per IP (default: 10)
NANOCHAT_RATE_WINDOW  Rate limit window in seconds (default: 60)
NANOCHAT_STATS_KEY    If set, /health and /stats require this value in the X-Stats-Key header
"""
import argparse
import json
import os
import time
import secrets
import torch
import asyncio
import logging
import random
from collections import defaultdict
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
# Abuse prevention limits
MAX_MESSAGES_PER_REQUEST = 500
MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
MIN_TEMPERATURE = 0.0
MAX_TEMPERATURE = 2.0
MIN_TOP_K = 0 # 0 disables top-k filtering, using full vocabulary
MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096
# Rate limiting: sliding-window per IP, no extra dependencies required
RATE_LIMIT_REQUESTS = int(os.environ.get("NANOCHAT_RATE_LIMIT", "10")) # max requests
RATE_LIMIT_WINDOW = int(os.environ.get("NANOCHAT_RATE_WINDOW", "60")) # per N seconds
_rate_limit_store: dict[str, list[float]] = defaultdict(list)
async def check_rate_limit(request: Request) -> None:
    """Sliding-window rate limiter keyed on client IP. Raises HTTP 429 when exceeded."""
    client_ip = request.client.host if request.client else "unknown"
    now = time.monotonic()
    window_start = now - RATE_LIMIT_WINDOW
    timestamps = _rate_limit_store[client_ip]
    # Purge timestamps outside the current window
    _rate_limit_store[client_ip] = [t for t in timestamps if t > window_start]
    if len(_rate_limit_store[client_ip]) >= RATE_LIMIT_REQUESTS:
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded: max {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW}s."
        )
    _rate_limit_store[client_ip].append(now)

# Optional API key guard for /health and /stats (set NANOCHAT_STATS_KEY env var to enable)
_STATS_KEY = os.environ.get("NANOCHAT_STATS_KEY", "")
_stats_key_header = APIKeyHeader(name="X-Stats-Key", auto_error=False)

async def check_stats_key(key: Optional[str] = Depends(_stats_key_header)) -> None:
    """If NANOCHAT_STATS_KEY is set, require it on /health and /stats requests."""
    if _STATS_KEY and not secrets.compare_digest(key or "", _STATS_KEY):
        raise HTTPException(status_code=403, detail="Forbidden: invalid or missing X-Stats-Key header.")
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
# Configure logging for conversation traffic
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
@dataclass
class Worker:
    """A worker with a model loaded on a specific GPU."""
    gpu_id: int
    device: torch.device
    engine: Engine
    tokenizer: object

class WorkerPool:
    """Pool of workers, each with a model replica on a different GPU."""

    def __init__(self, num_gpus: Optional[int] = None):
        if num_gpus is None:
            if device_type == "cuda":
                num_gpus = torch.cuda.device_count()
            else:
                num_gpus = 1 # e.g. cpu|mps
        self.num_gpus = num_gpus
        self.workers: List[Worker] = []
        self.available_workers: asyncio.Queue = asyncio.Queue()

    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
        """Load model on each GPU."""
        print(f"Initializing worker pool with {self.num_gpus} GPUs...")
        if self.num_gpus > 1:
            assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
        for gpu_id in range(self.num_gpus):
            if device_type == "cuda":
                device = torch.device(f"cuda:{gpu_id}")
                print(f"Loading model on GPU {gpu_id}...")
            else:
                device = torch.device(device_type) # e.g. cpu|mps
                print(f"Loading model on {device_type}...")
            model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
            engine = Engine(model, tokenizer)
            worker = Worker(
                gpu_id=gpu_id,
                device=device,
                engine=engine,
                tokenizer=tokenizer,
            )
            self.workers.append(worker)
            await self.available_workers.put(worker)
        print(f"All {self.num_gpus} workers initialized!")

    async def acquire_worker(self) -> Worker:
        """Get an available worker from the pool."""
        return await self.available_workers.get()

    async def release_worker(self, worker: Worker):
        """Return a worker to the pool."""
        await self.available_workers.put(worker)
class ChatMessage(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_k: Optional[int] = None
def validate_chat_request(request: ChatRequest):
    """Validate chat request to prevent abuse."""
    # Check number of messages
    if len(request.messages) == 0:
        raise HTTPException(status_code=400, detail="At least one message is required")
    if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
        raise HTTPException(
            status_code=400,
            detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
        )
    # Check individual message lengths and total conversation length
    total_length = 0
    for i, message in enumerate(request.messages):
        if not message.content:
            raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
        msg_length = len(message.content)
        if msg_length > MAX_MESSAGE_LENGTH:
            raise HTTPException(
                status_code=400,
                detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
            )
        total_length += msg_length
    if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
        )
    # Validate role values
    for i, message in enumerate(request.messages):
        if message.role not in ["user", "assistant"]:
            raise HTTPException(
                status_code=400,
                detail=f"Message {i} has invalid role '{message.role}'. Must be 'user' or 'assistant'."
            )
    # Validate temperature
    if request.temperature is not None:
        if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
            raise HTTPException(
                status_code=400,
                detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
            )
    # Validate top_k
    if request.top_k is not None:
        if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
            raise HTTPException(
                status_code=400,
                detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
            )
    # Validate max_tokens
    if request.max_tokens is not None:
        if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
            raise HTTPException(
                status_code=400,
                detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
            )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on all GPUs on startup."""
    print("Loading nanochat models across GPUs...")
    app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
    await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
    print(f"Server ready at http://localhost:{args.port}")
    yield
app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False, # credentials require explicit origin allowlist, not wildcard
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type", "X-Stats-Key"],
)
@app.get("/")
async def root():
"""Serve the chat UI."""
ui_html_path = os.path.join("nanochat", "ui.html")
with open(ui_html_path, "r", encoding="utf-8") as f:
html_content = f.read()
# Replace the API_URL to use the same origin
html_content = html_content.replace(
"const API_URL = `http://${window.location.hostname}:8000`;",
"const API_URL = '';"
)
return HTMLResponse(content=html_content)
@app.get("/logo.svg")
async def logo():
"""Serve the NanoChat logo for favicon and header."""
logo_path = os.path.join("nanochat", "logo.svg")
return FileResponse(logo_path, media_type="image/svg+xml")
async def generate_stream(
    worker: Worker,
    tokens,
    temperature=None,
    max_new_tokens=None,
    top_k=None
) -> AsyncGenerator[str, None]:
    """Generate assistant response with streaming."""
    temperature = temperature if temperature is not None else args.temperature
    max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
    top_k = top_k if top_k is not None else args.top_k
    assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
    bos = worker.tokenizer.get_bos_token_id()
    # Accumulate tokens to properly handle multi-byte UTF-8 characters (like emojis)
    accumulated_tokens = []
    # Track the last complete UTF-8 string (without replacement characters)
    last_clean_text = ""
    for token_column, token_masks in worker.engine.generate(
        tokens,
        num_samples=1,
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        seed=random.randint(0, 2**31 - 1)
    ):
        token = token_column[0]
        # Stopping criteria
        if token == assistant_end or token == bos:
            break
        # Append the token to sequence
        accumulated_tokens.append(token)
        # Decode all accumulated tokens to get proper UTF-8 handling
        # Note that decode is a quite efficient operation, basically table lookup and string concat
        current_text = worker.tokenizer.decode(accumulated_tokens)
        # Only emit text if it doesn't end with a replacement character
        # This ensures we don't emit incomplete UTF-8 sequences
        if not current_text.endswith('\ufffd'):
            # Extract only the new text since last clean decode
            new_text = current_text[len(last_clean_text):]
            if new_text: # Only yield if there's new content
                yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
            last_clean_text = current_text
    yield f"data: {json.dumps({'done': True})}\n\n"
@app.post("/chat/completions")
async def chat_completions(request: Request, chat_request: ChatRequest, _: None = Depends(check_rate_limit)):
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
# Basic validation to prevent abuse
validate_chat_request(chat_request)
# Log incoming conversation to console (truncated for privacy)
logger.info("="*20)
for i, message in enumerate(chat_request.messages):
        preview = message.content[:120] + ("..." if len(message.content) > 120 else "")
logger.debug(f"[{message.role.upper()}]: {preview}")
logger.info(f"Received request with {len(chat_request.messages)} message(s)")
logger.info("-"*20)
# Acquire a worker from the pool (will wait if all are busy)
worker_pool = app.state.worker_pool
worker = await worker_pool.acquire_worker()
try:
# Build conversation tokens
bos = worker.tokenizer.get_bos_token_id()
user_start = worker.tokenizer.encode_special("<|user_start|>")
user_end = worker.tokenizer.encode_special("<|user_end|>")
assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
conversation_tokens = [bos]
for message in chat_request.messages:
if message.role == "user":
conversation_tokens.append(user_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(user_end)
elif message.role == "assistant":
conversation_tokens.append(assistant_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(assistant_end)
conversation_tokens.append(assistant_start)
# Streaming response with worker release after completion
response_tokens = []
async def stream_and_release():
try:
async for chunk in generate_stream(
worker,
conversation_tokens,
temperature=chat_request.temperature,
max_new_tokens=chat_request.max_tokens,
top_k=chat_request.top_k
):
# Accumulate response for logging
chunk_data = json.loads(chunk.replace("data: ", "").strip())
if "token" in chunk_data:
response_tokens.append(chunk_data["token"])
yield chunk
finally:
# Log response length only (not content) to avoid storing user data in logs
full_response = "".join(response_tokens)
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {len(full_response)} chars generated")
logger.info("="*20)
# Release worker back to pool after streaming is done
await worker_pool.release_worker(worker)
return StreamingResponse(
stream_and_release(),
media_type="text/event-stream"
)
except Exception as e:
# Make sure to release worker even on error
await worker_pool.release_worker(worker)
raise e
@app.get("/health")
async def health(_: None = Depends(check_stats_key)):
"""Health check endpoint. Protected by X-Stats-Key header when NANOCHAT_STATS_KEY is set."""
worker_pool = getattr(app.state, 'worker_pool', None)
return {
"status": "ok",
"ready": worker_pool is not None and len(worker_pool.workers) > 0,
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
}
@app.get("/stats")
async def stats(_: None = Depends(check_stats_key)):
"""Get worker pool statistics. Protected by X-Stats-Key header when NANOCHAT_STATS_KEY is set."""
worker_pool = app.state.worker_pool
return {
"total_workers": len(worker_pool.workers),
"available_workers": worker_pool.available_workers.qsize(),
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
"workers": [
{
"gpu_id": w.gpu_id,
"device": str(w.device)
} for w in worker_pool.workers
]
}
if __name__ == "__main__":
import uvicorn
print(f"Starting NanoChat Web Server")
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
uvicorn.run(app, host=args.host, port=args.port)