commit 48a4d691fb
Jason Cox 2026-01-16 01:14:19 -08:00 committed by GitHub

@@ -16,10 +16,41 @@ python -m scripts.chat_web --num-gpus 4
To chat, open the URL printed in the console. (If on a cloud box, make sure to use the public IP.)
Endpoints:
GET / - Chat UI
POST /chat/completions - Chat API (streaming only)
GET /health - Health check with worker pool status
GET /stats - Worker pool statistics and GPU utilization
GET / - Chat UI
POST /chat/completions - Chat API (streaming only, custom format)
POST /v1/chat/completions - OpenAI-compatible chat completions (streaming and non-streaming)
GET /v1/models - List available models (OpenAI-compatible)
GET /health - Health check with worker pool status
GET /stats - Worker pool statistics and GPU utilization
OpenAI API Compatibility:
- Supports both streaming and non-streaming responses
- Compatible with OpenAI Python SDK and other OpenAI-compatible clients
- Implements standard OpenAI request/response format
Example usage with OpenAI SDK:
```python
from openai import OpenAI
client = OpenAI(
api_key="not-needed",
base_url="http://localhost:8000/v1"
)
response = client.chat.completions.create(
model="nanochat",
messages=[
{"role": "user", "content": "Hello!"}
],
temperature=0.8,
max_tokens=512,
stream=True
)
for chunk in response:
if chunk.choices[0].delta.content:
print(chunk.choices[0].delta.content, end="", flush=True)
```
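For non-streaming use, a minimal sketch with the same client (same assumed host/port as above):
```python
from openai import OpenAI

client = OpenAI(api_key="not-needed", base_url="http://localhost:8000/v1")

response = client.chat.completions.create(
    model="nanochat",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=False,
)
# Non-streaming responses carry the full message plus token usage
print(response.choices[0].message.content)
print(response.usage.total_tokens)
```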
Abuse Prevention:
- Maximum 500 messages per request
@@ -37,12 +68,13 @@ import torch
import asyncio
import logging
import random
import time
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from pydantic import BaseModel, Field
from typing import List, Optional, AsyncGenerator, Literal
from dataclasses import dataclass
from contextlib import nullcontext
from nanochat.common import compute_init, autodetect_device_type
@@ -156,6 +188,50 @@ class ChatRequest(BaseModel):
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_k: Optional[int] = None
model: str = Field(default="nanochat") # For OpenAI compatibility
stream: Optional[bool] = False # For OpenAI compatibility
top_p: Optional[float] = None # Ignored, for compatibility
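# Example request body accepted by this endpoint (illustrative values):
# {"model": "nanochat", "messages": [{"role": "user", "content": "Hi"}],
#  "temperature": 0.8, "max_tokens": 512, "top_k": 50, "stream": true}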
# OpenAI API Models
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: str
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: dict
finish_reason: Optional[str] = None
class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
choices: List[ChatCompletionResponseChoice]
usage: Usage
class ChatCompletionStreamResponse(BaseModel):
id: str
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
choices: List[ChatCompletionResponseStreamChoice]
class Model(BaseModel):
id: str
object: Literal["model"] = "model"
created: int
owned_by: str
class ModelList(BaseModel):
object: Literal["list"] = "list"
data: List[Model]
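# For reference, one serialized stream chunk as sent over SSE looks like
# (illustrative id and timestamp):
# data: {"id": "chatcmpl-4821337", "object": "chat.completion.chunk",
#        "created": 1737000000, "model": "nanochat",
#        "choices": [{"index": 0, "delta": {"content": "Hi"}, "finish_reason": null}]}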
def validate_chat_request(request: ChatRequest):
"""Validate chat request to prevent abuse."""
@@ -189,11 +265,12 @@ def validate_chat_request(request: ChatRequest):
)
# Validate role values
valid_roles = ["user", "assistant", "system"]
for i, message in enumerate(request.messages):
if message.role not in ["user", "assistant"]:
if message.role not in valid_roles:
raise HTTPException(
status_code=400,
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
detail=f"Message {i} has invalid role '{message.role}'. Must be one of: {', '.join(valid_roles)}"
)
# Validate temperature
@@ -265,8 +342,8 @@ async def generate_stream(
temperature=None,
max_new_tokens=None,
top_k=None
) -> AsyncGenerator[str, None]:
"""Generate assistant response with streaming."""
) -> AsyncGenerator[tuple[str, int], None]:
"""Generate assistant response with streaming. Returns (text, token_count) tuples."""
temperature = temperature if temperature is not None else args.temperature
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
top_k = top_k if top_k is not None else args.top_k
@@ -278,6 +355,7 @@ async def generate_stream(
accumulated_tokens = []
# Track the last complete UTF-8 string (without replacement characters)
last_clean_text = ""
completion_tokens = 0
with worker.autocast_ctx:
for token_column, token_masks in worker.engine.generate(
@@ -296,6 +374,7 @@
# Append the token to sequence
accumulated_tokens.append(token)
completion_tokens += 1
# Decode all accumulated tokens to get proper UTF-8 handling
# Note that decode is quite an efficient operation, basically a table lookup and string concat
current_text = worker.tokenizer.decode(accumulated_tokens)
@@ -305,11 +384,9 @@
# Extract only the new text since last clean decode
new_text = current_text[len(last_clean_text):]
if new_text: # Only yield if there's new content
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
yield (new_text, completion_tokens)
last_clean_text = current_text
yield f"data: {json.dumps({'done': True})}\n\n"
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
@@ -352,7 +429,7 @@ async def chat_completions(request: ChatRequest):
response_tokens = []
async def stream_and_release():
try:
async for chunk in generate_stream(
async for text_chunk, token_count in generate_stream(
worker,
conversation_tokens,
temperature=request.temperature,
@@ -360,10 +437,11 @@
top_k=request.top_k
):
# Accumulate response for logging
chunk_data = json.loads(chunk.replace("data: ", "").strip())
if "token" in chunk_data:
response_tokens.append(chunk_data["token"])
yield chunk
response_tokens.append(text_chunk)
# Format as custom JSON for web UI
yield f"data: {json.dumps({'token': text_chunk, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
# Send done message
yield f"data: {json.dumps({'done': True})}\n\n"
finally:
# Log the assistant response to console
full_response = "".join(response_tokens)
@@ -381,6 +459,160 @@ async def chat_completions(request: ChatRequest):
await worker_pool.release_worker(worker)
raise e
@app.post("/v1/chat/completions")
async def openai_chat_completions(request: ChatRequest):
"""OpenAI-compatible chat completion endpoint (supports streaming and non-streaming).
Note: System role messages are accepted but ignored as nanochat doesn't support system prompts.
"""
validate_chat_request(request)
# Log incoming request
logger.info(f"OpenAI API Request: model={request.model}, stream={request.stream}, messages={len(request.messages)}")
worker_pool = app.state.worker_pool
worker = await worker_pool.acquire_worker()
request_id = f"chatcmpl-{random.randint(1000000, 9999999)}"
created = int(time.time())
try:
# Build conversation tokens
bos = worker.tokenizer.get_bos_token_id()
user_start = worker.tokenizer.encode_special("<|user_start|>")
user_end = worker.tokenizer.encode_special("<|user_end|>")
assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
conversation_tokens = [bos]
for message in request.messages:
if message.role == "user":
conversation_tokens.append(user_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(user_end)
elif message.role == "assistant":
conversation_tokens.append(assistant_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(assistant_end)
elif message.role == "system":
# Note: system messages are ignored as nanochat doesn't have system role support
logger.warning(f"System role message ignored (not supported by this model): {message.content[:100]}...")
conversation_tokens.append(assistant_start)
prompt_tokens = len(conversation_tokens)
if request.stream:
# Streaming response
async def stream_response():
try:
completion_text = ""
completion_tokens = 0
async for text_chunk, token_count in generate_stream(
worker,
conversation_tokens,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
):
completion_text += text_chunk
completion_tokens = token_count
chunk = ChatCompletionStreamResponse(
id=request_id,
created=created,
model=request.model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta={"content": text_chunk},
finish_reason=None
)
]
)
yield f"data: {json.dumps(chunk.model_dump(), ensure_ascii=False)}\n\n"
# Final chunk with finish_reason
final_chunk = ChatCompletionStreamResponse(
id=request_id,
created=created,
model=request.model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta={},
finish_reason="stop"
)
]
)
yield f"data: {json.dumps(final_chunk.model_dump(), ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
logger.info(f"OpenAI API Response (GPU {worker.gpu_id}): {completion_tokens} tokens")
finally:
await worker_pool.release_worker(worker)
return StreamingResponse(
stream_response(),
media_type="text/event-stream"
)
else:
# Non-streaming response
try:
completion_text = ""
completion_tokens = 0
async for text_chunk, token_count in generate_stream(
worker,
conversation_tokens,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
):
completion_text += text_chunk
completion_tokens = token_count
response = ChatCompletionResponse(
id=request_id,
created=created,
model=request.model,
choices=[
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=completion_text),
finish_reason="stop"
)
],
usage=Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
)
logger.info(f"OpenAI API Response (GPU {worker.gpu_id}): {completion_tokens} tokens")
return response
finally:
await worker_pool.release_worker(worker)
except Exception as e:
logger.error(f"Error: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
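For clients that don't use the OpenAI SDK, the stream can also be consumed directly as server-sent events. A minimal sketch using `requests` (hypothetical host/port, matching the SSE framing emitted above):
```python
import json
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "nanochat",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    # Each event is a single "data: ..." line; blank lines separate events
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":  # terminal sentinel emitted after the final chunk
        break
    delta = json.loads(payload)["choices"][0]["delta"]
    if "content" in delta:
        print(delta["content"], end="", flush=True)
```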
@app.get("/v1/models")
async def list_models() -> ModelList:
"""List available models (OpenAI-compatible endpoint)."""
return ModelList(
data=[
Model(
id="nanochat",
created=int(time.time()),
owned_by="nanochat"
)
]
)
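# The response serializes to the standard OpenAI model-list shape, e.g.
# (illustrative timestamp):
# {"object": "list", "data": [{"id": "nanochat", "object": "model",
#   "created": 1737000000, "owned_by": "nanochat"}]}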
@app.get("/health")
async def health():
"""Health check endpoint."""
@@ -410,6 +642,6 @@ async def stats():
if __name__ == "__main__":
import uvicorn
print(f"Starting NanoChat Web Server")
print("Starting NanoChat Web Server")
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
uvicorn.run(app, host=args.host, port=args.port)