diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 4b67b62..3b2d3b6 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -16,10 +16,41 @@ python -m scripts.chat_web --num-gpus 4
 To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
 
 Endpoints:
-    GET  /                   - Chat UI
-    POST /chat/completions   - Chat API (streaming only)
-    GET  /health             - Health check with worker pool status
-    GET  /stats              - Worker pool statistics and GPU utilization
+    GET  /                    - Chat UI
+    POST /chat/completions    - Chat API (streaming only, custom format)
+    POST /v1/chat/completions - OpenAI-compatible chat completions (streaming and non-streaming)
+    GET  /v1/models           - List available models (OpenAI-compatible)
+    GET  /health              - Health check with worker pool status
+    GET  /stats               - Worker pool statistics and GPU utilization
+
+OpenAI API Compatibility:
+    - Supports both streaming and non-streaming responses
+    - Compatible with OpenAI Python SDK and other OpenAI-compatible clients
+    - Implements standard OpenAI request/response format
+
+Example usage with OpenAI SDK:
+    ```python
+    from openai import OpenAI
+
+    client = OpenAI(
+        api_key="not-needed",
+        base_url="http://localhost:8000/v1"
+    )
+
+    response = client.chat.completions.create(
+        model="nanochat",
+        messages=[
+            {"role": "user", "content": "Hello!"}
+        ],
+        temperature=0.8,
+        max_tokens=512,
+        stream=True
+    )
+
+    for chunk in response:
+        if chunk.choices[0].delta.content:
+            print(chunk.choices[0].delta.content, end="", flush=True)
+    ```
 
 Abuse Prevention:
     - Maximum 500 messages per request
@@ -37,12 +68,13 @@ import torch
 import asyncio
 import logging
 import random
+import time
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
-from typing import List, Optional, AsyncGenerator
+from typing import List, Optional, AsyncGenerator, Literal
 from dataclasses import dataclass
 from contextlib import nullcontext
 from nanochat.common import compute_init, autodetect_device_type
@@ -156,6 +188,50 @@ class ChatRequest(BaseModel):
     temperature: Optional[float] = None
     max_tokens: Optional[int] = None
     top_k: Optional[int] = None
+    model: Optional[str] = "nanochat"  # For OpenAI compatibility
+    stream: Optional[bool] = None  # For OpenAI compatibility
+    top_p: Optional[float] = None  # Ignored, for compatibility
+
+# OpenAI API Models
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: str
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: dict
+    finish_reason: Optional[str] = None
+
+class Usage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion"] = "chat.completion"
+    created: int
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: Usage
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
+
+class Model(BaseModel):
+    id: str
+    object: Literal["model"] = "model"
+    created: int
+    owned_by: str
+
+class ModelList(BaseModel):
+    object: Literal["list"] = "list"
+    data: List[Model]
 
 def validate_chat_request(request: ChatRequest):
     """Validate chat request to prevent abuse."""
@@ -189,11 +265,12 @@ def validate_chat_request(request: ChatRequest):
         )
 
     # Validate role values
+    valid_roles = ["user", "assistant", "system"]
     for i, message in enumerate(request.messages):
-        if message.role not in ["user", "assistant"]:
+        if message.role not in valid_roles:
             raise HTTPException(
                 status_code=400,
-                detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
+                detail=f"Message {i} has invalid role '{message.role}'. Must be one of: {', '.join(valid_roles)}"
             )
 
     # Validate temperature
@@ -265,8 +342,8 @@ async def generate_stream(
     temperature=None,
     max_new_tokens=None,
     top_k=None
-) -> AsyncGenerator[str, None]:
-    """Generate assistant response with streaming."""
+) -> AsyncGenerator[tuple[str, int], None]:
+    """Generate assistant response with streaming. Returns (text, token_count) tuples."""
     temperature = temperature if temperature is not None else args.temperature
     max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
     top_k = top_k if top_k is not None else args.top_k
@@ -278,6 +355,7 @@ async def generate_stream(
     accumulated_tokens = []
     # Track the last complete UTF-8 string (without replacement characters)
     last_clean_text = ""
+    completion_tokens = 0
 
     with worker.autocast_ctx:
         for token_column, token_masks in worker.engine.generate(
@@ -296,6 +374,7 @@ async def generate_stream(
 
             # Append the token to sequence
             accumulated_tokens.append(token)
+            completion_tokens += 1
             # Decode all accumulated tokens to get proper UTF-8 handling
             # Note that decode is a quite efficient operation, basically table lookup and string concat
             current_text = worker.tokenizer.decode(accumulated_tokens)
@@ -305,11 +384,9 @@ async def generate_stream(
             # Extract only the new text since last clean decode
             new_text = current_text[len(last_clean_text):]
             if new_text:  # Only yield if there's new content
-                yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
+                yield (new_text, completion_tokens)
             last_clean_text = current_text
 
-    yield f"data: {json.dumps({'done': True})}\n\n"
-
 @app.post("/chat/completions")
 async def chat_completions(request: ChatRequest):
     """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
@@ -352,7 +429,7 @@ async def chat_completions(request: ChatRequest):
     response_tokens = []
     async def stream_and_release():
         try:
-            async for chunk in generate_stream(
+            async for text_chunk, token_count in generate_stream(
                 worker,
                 conversation_tokens,
                 temperature=request.temperature,
@@ -360,10 +437,11 @@ async def chat_completions(request: ChatRequest):
                 top_k=request.top_k
             ):
                 # Accumulate response for logging
-                chunk_data = json.loads(chunk.replace("data: ", "").strip())
-                if "token" in chunk_data:
-                    response_tokens.append(chunk_data["token"])
-                yield chunk
+                response_tokens.append(text_chunk)
+                # Format as custom JSON for web UI
+                yield f"data: {json.dumps({'token': text_chunk, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
+            # Send done message
+            yield f"data: {json.dumps({'done': True})}\n\n"
         finally:
             # Log the assistant response to console
             full_response = "".join(response_tokens)
@@ -381,6 +459,156 @@ async def chat_completions(request: ChatRequest):
         await worker_pool.release_worker(worker)
         raise e
 
+@app.post("/v1/chat/completions")
+async def openai_chat_completions(request: ChatRequest):
+    """OpenAI-compatible chat completion endpoint (supports streaming and non-streaming)."""
+
+    validate_chat_request(request)
+
+    # Log incoming request
+    logger.info(f"OpenAI API Request: model={request.model}, stream={request.stream}, messages={len(request.messages)}")
+
+    worker_pool = app.state.worker_pool
+    worker = await worker_pool.acquire_worker()
+
+    request_id = f"chatcmpl-{random.randint(1000000, 9999999)}"
+    created = int(time.time())
+
+    try:
+        # Build conversation tokens
+        bos = worker.tokenizer.get_bos_token_id()
+        user_start = worker.tokenizer.encode_special("<|user_start|>")
+        user_end = worker.tokenizer.encode_special("<|user_end|>")
+        assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
+        assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
+
+        conversation_tokens = [bos]
+        for message in request.messages:
+            if message.role == "user":
+                conversation_tokens.append(user_start)
+                conversation_tokens.extend(worker.tokenizer.encode(message.content))
+                conversation_tokens.append(user_end)
+            elif message.role == "assistant":
+                conversation_tokens.append(assistant_start)
+                conversation_tokens.extend(worker.tokenizer.encode(message.content))
+                conversation_tokens.append(assistant_end)
+            # Note: system messages are ignored as nanochat doesn't have system role support
+
+        conversation_tokens.append(assistant_start)
+        prompt_tokens = len(conversation_tokens)
+
+        if request.stream:
+            # Streaming response
+            async def stream_response():
+                try:
+                    completion_text = ""
+                    completion_tokens = 0
+
+                    async for text_chunk, token_count in generate_stream(
+                        worker,
+                        conversation_tokens,
+                        temperature=request.temperature,
+                        max_new_tokens=request.max_tokens,
+                        top_k=request.top_k
+                    ):
+                        completion_text += text_chunk
+                        completion_tokens = token_count
+
+                        chunk = ChatCompletionStreamResponse(
+                            id=request_id,
+                            created=created,
+                            model=request.model,
+                            choices=[
+                                ChatCompletionResponseStreamChoice(
+                                    index=0,
+                                    delta={"content": text_chunk},
+                                    finish_reason=None
+                                )
+                            ]
+                        )
+                        yield f"data: {json.dumps(chunk.model_dump(), ensure_ascii=False)}\n\n"
+
+                    # Final chunk with finish_reason
+                    final_chunk = ChatCompletionStreamResponse(
+                        id=request_id,
+                        created=created,
+                        model=request.model,
+                        choices=[
+                            ChatCompletionResponseStreamChoice(
+                                index=0,
+                                delta={},
+                                finish_reason="stop"
+                            )
+                        ]
+                    )
+                    yield f"data: {json.dumps(final_chunk.model_dump(), ensure_ascii=False)}\n\n"
+                    yield "data: [DONE]\n\n"
+
+                    logger.info(f"OpenAI API Response (GPU {worker.gpu_id}): {completion_tokens} tokens")
+                finally:
+                    await worker_pool.release_worker(worker)
+
+            return StreamingResponse(
+                stream_response(),
+                media_type="text/event-stream"
+            )
+        else:
+            # Non-streaming response
+            try:
+                completion_text = ""
+                completion_tokens = 0
+
+                async for text_chunk, token_count in generate_stream(
+                    worker,
+                    conversation_tokens,
+                    temperature=request.temperature,
+                    max_new_tokens=request.max_tokens,
+                    top_k=request.top_k
+                ):
+                    completion_text += text_chunk
+                    completion_tokens = token_count
+
+                response = ChatCompletionResponse(
+                    id=request_id,
+                    created=created,
+                    model=request.model,
+                    choices=[
+                        ChatCompletionResponseChoice(
+                            index=0,
+                            message=ChatMessage(role="assistant", content=completion_text),
+                            finish_reason="stop"
+                        )
+                    ],
+                    usage=Usage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens
+                    )
+                )
+
+                logger.info(f"OpenAI API Response (GPU {worker.gpu_id}): {completion_tokens} tokens")
+                return response
+            finally:
+                await worker_pool.release_worker(worker)
+
+    except Exception as e:
+        await worker_pool.release_worker(worker)
+        logger.error(f"Error: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/v1/models")
+async def list_models() -> ModelList:
+    """List available models (OpenAI-compatible endpoint)."""
+    return ModelList(
+        data=[
+            Model(
+                id="nanochat",
+                created=int(time.time()),
+                owned_by="nanochat"
+            )
+        ]
+    )
+
 @app.get("/health")
 async def health():
     """Health check endpoint."""
@@ -410,6 +638,6 @@ async def stats():
 
 if __name__ == "__main__":
     import uvicorn
-    print(f"Starting NanoChat Web Server")
+    print("Starting NanoChat Web Server")
     print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
-    uvicorn.run(app, host=args.host, port=args.port)
+    uvicorn.run(app, host=args.host, port=args.port, timeout_keep_alive=75)
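For completeness, here is a minimal client-side sketch that exercises the non-streaming path of the new `POST /v1/chat/completions` endpoint and the `GET /v1/models` listing. It is not part of the diff: it assumes the server is running locally on port 8000, mirroring the `base_url` and placeholder API key from the docstring example above, and uses plain OpenAI Python SDK calls.

```python
# Minimal sketch (not part of the patch): query the OpenAI-compatible endpoints
# added above, assuming the server is reachable at http://localhost:8000.
from openai import OpenAI

client = OpenAI(api_key="not-needed", base_url="http://localhost:8000/v1")

# GET /v1/models should list the single "nanochat" model.
for model in client.models.list():
    print(model.id)

# Non-streaming POST /v1/chat/completions: the full reply and the usage block
# (prompt_tokens / completion_tokens / total_tokens) arrive in one response.
response = client.chat.completions.create(
    model="nanochat",
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.8,
    max_tokens=512,
    stream=False,
)
print(response.choices[0].message.content)
print(response.usage)
```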