Mirror of https://github.com/karpathy/nanochat.git (synced 2026-02-07 11:09:55 +00:00)

Merge 9a9b12b1be into 50413d2d67
Commit 48a4d691fb
@@ -16,10 +16,41 @@ python -m scripts.chat_web --num-gpus 4
 To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)

 Endpoints:
-GET / - Chat UI
-POST /chat/completions - Chat API (streaming only)
-GET /health - Health check with worker pool status
-GET /stats - Worker pool statistics and GPU utilization
+GET / - Chat UI
+POST /chat/completions - Chat API (streaming only, custom format)
+POST /v1/chat/completions - OpenAI-compatible chat completions (streaming and non-streaming)
+GET /v1/models - List available models (OpenAI-compatible)
+GET /health - Health check with worker pool status
+GET /stats - Worker pool statistics and GPU utilization
+
+OpenAI API Compatibility:
+- Supports both streaming and non-streaming responses
+- Compatible with OpenAI Python SDK and other OpenAI-compatible clients
+- Implements standard OpenAI request/response format
+
+Example usage with OpenAI SDK:
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="not-needed",
+    base_url="http://localhost:8000/v1"
+)
+
+response = client.chat.completions.create(
+    model="nanochat",
+    messages=[
+        {"role": "user", "content": "Hello!"}
+    ],
+    temperature=0.8,
+    max_tokens=512,
+    stream=True
+)
+
+for chunk in response:
+    if chunk.choices[0].delta.content:
+        print(chunk.choices[0].delta.content, end="", flush=True)
+```

 Abuse Prevention:
 - Maximum 500 messages per request
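Editor's note: the docstring example above only exercises the streaming path. A companion sketch (not part of this diff) for the non-streaming path of `/v1/chat/completions`, under the same assumption that the server runs locally on port 8000:

```python
from openai import OpenAI

# Same local-server assumption as the docstring example above.
client = OpenAI(api_key="not-needed", base_url="http://localhost:8000/v1")

# With stream=False the endpoint returns a single chat.completion object,
# including the usage block filled in by the server.
response = client.chat.completions.create(
    model="nanochat",
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.8,
    max_tokens=512,
    stream=False,
)
print(response.choices[0].message.content)
print(response.usage.total_tokens)
```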
@@ -37,12 +68,13 @@ import torch
 import asyncio
 import logging
+import random
 import time
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
-from pydantic import BaseModel
-from typing import List, Optional, AsyncGenerator
+from pydantic import BaseModel, Field
+from typing import List, Optional, AsyncGenerator, Literal
 from dataclasses import dataclass
 from contextlib import nullcontext
 from nanochat.common import compute_init, autodetect_device_type
@@ -156,6 +188,50 @@ class ChatRequest(BaseModel):
     temperature: Optional[float] = None
     max_tokens: Optional[int] = None
     top_k: Optional[int] = None
+    model: str = Field(default="nanochat")  # For OpenAI compatibility
+    stream: Optional[bool] = False  # For OpenAI compatibility
+    top_p: Optional[float] = None  # Ignored, for compatibility
+
+# OpenAI API Models
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: str
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: dict
+    finish_reason: Optional[str] = None
+
+class Usage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion"] = "chat.completion"
+    created: int
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: Usage
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
+
+class Model(BaseModel):
+    id: str
+    object: Literal["model"] = "model"
+    created: int
+    owned_by: str
+
+class ModelList(BaseModel):
+    object: Literal["list"] = "list"
+    data: List[Model]

 def validate_chat_request(request: ChatRequest):
     """Validate chat request to prevent abuse."""
@@ -189,11 +265,12 @@ def validate_chat_request(request: ChatRequest):
         )

     # Validate role values
+    valid_roles = ["user", "assistant", "system"]
     for i, message in enumerate(request.messages):
-        if message.role not in ["user", "assistant"]:
+        if message.role not in valid_roles:
             raise HTTPException(
                 status_code=400,
-                detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
+                detail=f"Message {i} has invalid role '{message.role}'. Must be one of: {', '.join(valid_roles)}"
             )

     # Validate temperature
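Editor's note: the tightened role check now names the offending role in the 400 detail. A standalone sketch of just that loop (plain dicts stand in for request.messages; "tool" is deliberately an unsupported role):

```python
from fastapi import HTTPException

valid_roles = ["user", "assistant", "system"]
messages = [{"role": "user", "content": "hi"}, {"role": "tool", "content": "ignored"}]

try:
    for i, message in enumerate(messages):
        if message["role"] not in valid_roles:
            raise HTTPException(
                status_code=400,
                detail=f"Message {i} has invalid role '{message['role']}'. "
                       f"Must be one of: {', '.join(valid_roles)}",
            )
except HTTPException as exc:
    # -> 400 Message 1 has invalid role 'tool'. Must be one of: user, assistant, system
    print(exc.status_code, exc.detail)
```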
@@ -265,8 +342,8 @@ async def generate_stream(
     temperature=None,
     max_new_tokens=None,
     top_k=None
-) -> AsyncGenerator[str, None]:
-    """Generate assistant response with streaming."""
+) -> AsyncGenerator[tuple[str, int], None]:
+    """Generate assistant response with streaming. Returns (text, token_count) tuples."""
     temperature = temperature if temperature is not None else args.temperature
     max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
     top_k = top_k if top_k is not None else args.top_k
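Editor's note: because generate_stream now yields plain (text, token_count) tuples rather than pre-formatted SSE strings, each endpoint applies its own wire format. A minimal consumer sketch with a stub generator (the real one needs a worker, tokenizer, and engine, so a fake stands in here):

```python
import asyncio
from typing import AsyncGenerator

# Stub with the same shape as generate_stream: yields the newly decoded text
# plus the running completion-token count.
async def fake_generate_stream() -> AsyncGenerator[tuple[str, int], None]:
    for count, piece in enumerate(["Hel", "lo", "!"], start=1):
        yield piece, count

async def main():
    completion_text = ""
    completion_tokens = 0
    async for text_chunk, token_count in fake_generate_stream():
        completion_text += text_chunk
        completion_tokens = token_count  # a running total, not a per-chunk count
    print(completion_text, completion_tokens)  # -> Hello! 3

asyncio.run(main())
```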
@@ -278,6 +355,7 @@ async def generate_stream(
     accumulated_tokens = []
     # Track the last complete UTF-8 string (without replacement characters)
     last_clean_text = ""
+    completion_tokens = 0

     with worker.autocast_ctx:
         for token_column, token_masks in worker.engine.generate(
@@ -296,6 +374,7 @@ async def generate_stream(

             # Append the token to sequence
             accumulated_tokens.append(token)
+            completion_tokens += 1
             # Decode all accumulated tokens to get proper UTF-8 handling
             # Note that decode is a quite efficient operation, basically table lookup and string concat
             current_text = worker.tokenizer.decode(accumulated_tokens)
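Editor's note: the loop re-decodes the entire accumulated token list because a token boundary can fall in the middle of a multi-byte UTF-8 character; decoding only the new bytes would emit replacement characters. A tiny illustration with raw bytes, no tokenizer needed:

```python
# "é" is two bytes in UTF-8. Cutting the stream after the first of those bytes
# and decoding yields U+FFFD, which is exactly what the streaming loop avoids
# by always decoding the full accumulated sequence and yielding only clean text.
raw = "café".encode("utf-8")
print(raw[:-1].decode("utf-8", errors="replace"))  # caf\ufffd (replacement char)
print(raw.decode("utf-8"))                         # café
```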
@@ -305,11 +384,9 @@ async def generate_stream(
             # Extract only the new text since last clean decode
             new_text = current_text[len(last_clean_text):]
             if new_text: # Only yield if there's new content
-                yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
+                yield (new_text, completion_tokens)
             last_clean_text = current_text

-    yield f"data: {json.dumps({'done': True})}\n\n"
-
 @app.post("/chat/completions")
 async def chat_completions(request: ChatRequest):
     """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
@@ -352,7 +429,7 @@ async def chat_completions(request: ChatRequest):
     response_tokens = []
     async def stream_and_release():
         try:
-            async for chunk in generate_stream(
+            async for text_chunk, token_count in generate_stream(
                 worker,
                 conversation_tokens,
                 temperature=request.temperature,
@@ -360,10 +437,11 @@ async def chat_completions(request: ChatRequest):
                 top_k=request.top_k
             ):
                 # Accumulate response for logging
-                chunk_data = json.loads(chunk.replace("data: ", "").strip())
-                if "token" in chunk_data:
-                    response_tokens.append(chunk_data["token"])
-                yield chunk
+                response_tokens.append(text_chunk)
+                # Format as custom JSON for web UI
+                yield f"data: {json.dumps({'token': text_chunk, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
+            # Send done message
+            yield f"data: {json.dumps({'done': True})}\n\n"
         finally:
             # Log the assistant response to console
             full_response = "".join(response_tokens)
@@ -381,6 +459,160 @@ async def chat_completions(request: ChatRequest):
         await worker_pool.release_worker(worker)
         raise e

+@app.post("/v1/chat/completions")
+async def openai_chat_completions(request: ChatRequest):
+    """OpenAI-compatible chat completion endpoint (supports streaming and non-streaming).
+
+    Note: System role messages are accepted but ignored as nanochat doesn't support system prompts.
+    """
+
+    validate_chat_request(request)
+
+    # Log incoming request
+    logger.info(f"OpenAI API Request: model={request.model}, stream={request.stream}, messages={len(request.messages)}")
+
+    worker_pool = app.state.worker_pool
+    worker = await worker_pool.acquire_worker()
+
+    request_id = f"chatcmpl-{random.randint(1000000, 9999999)}"
+    created = int(time.time())
+
+    try:
+        # Build conversation tokens
+        bos = worker.tokenizer.get_bos_token_id()
+        user_start = worker.tokenizer.encode_special("<|user_start|>")
+        user_end = worker.tokenizer.encode_special("<|user_end|>")
+        assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
+        assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
+
+        conversation_tokens = [bos]
+        for message in request.messages:
+            if message.role == "user":
+                conversation_tokens.append(user_start)
+                conversation_tokens.extend(worker.tokenizer.encode(message.content))
+                conversation_tokens.append(user_end)
+            elif message.role == "assistant":
+                conversation_tokens.append(assistant_start)
+                conversation_tokens.extend(worker.tokenizer.encode(message.content))
+                conversation_tokens.append(assistant_end)
+            elif message.role == "system":
+                # Note: system messages are ignored as nanochat doesn't have system role support
+                logger.warning(f"System role message ignored (not supported by this model): {message.content[:100]}...")

+        conversation_tokens.append(assistant_start)
+        prompt_tokens = len(conversation_tokens)
+
+        if request.stream:
+            # Streaming response
+            async def stream_response():
+                try:
+                    completion_text = ""
+                    completion_tokens = 0
+
+                    async for text_chunk, token_count in generate_stream(
+                        worker,
+                        conversation_tokens,
+                        temperature=request.temperature,
+                        max_new_tokens=request.max_tokens,
+                        top_k=request.top_k
+                    ):
+                        completion_text += text_chunk
+                        completion_tokens = token_count
+
+                        chunk = ChatCompletionStreamResponse(
+                            id=request_id,
+                            created=created,
+                            model=request.model,
+                            choices=[
+                                ChatCompletionResponseStreamChoice(
+                                    index=0,
+                                    delta={"content": text_chunk},
+                                    finish_reason=None
+                                )
+                            ]
+                        )
+                        yield f"data: {json.dumps(chunk.model_dump(), ensure_ascii=False)}\n\n"
+
+                    # Final chunk with finish_reason
+                    final_chunk = ChatCompletionStreamResponse(
+                        id=request_id,
+                        created=created,
+                        model=request.model,
+                        choices=[
+                            ChatCompletionResponseStreamChoice(
+                                index=0,
+                                delta={},
+                                finish_reason="stop"
+                            )
+                        ]
+                    )
+                    yield f"data: {json.dumps(final_chunk.model_dump(), ensure_ascii=False)}\n\n"
+                    yield "data: [DONE]\n\n"
+
+                    logger.info(f"OpenAI API Response (GPU {worker.gpu_id}): {completion_tokens} tokens")
+                finally:
+                    await worker_pool.release_worker(worker)
+
+            return StreamingResponse(
+                stream_response(),
+                media_type="text/event-stream"
+            )
+        else:
+            # Non-streaming response
+            try:
+                completion_text = ""
+                completion_tokens = 0
+
+                async for text_chunk, token_count in generate_stream(
+                    worker,
+                    conversation_tokens,
+                    temperature=request.temperature,
+                    max_new_tokens=request.max_tokens,
+                    top_k=request.top_k
+                ):
+                    completion_text += text_chunk
+                    completion_tokens = token_count
+
+                response = ChatCompletionResponse(
+                    id=request_id,
+                    created=created,
+                    model=request.model,
+                    choices=[
+                        ChatCompletionResponseChoice(
+                            index=0,
+                            message=ChatMessage(role="assistant", content=completion_text),
+                            finish_reason="stop"
+                        )
+                    ],
+                    usage=Usage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens
+                    )
+                )
+
+                logger.info(f"OpenAI API Response (GPU {worker.gpu_id}): {completion_tokens} tokens")
+                return response
+            finally:
+                await worker_pool.release_worker(worker)
+
+    except Exception as e:
+        logger.error(f"Error: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/v1/models")
+async def list_models() -> ModelList:
+    """List available models (OpenAI-compatible endpoint)."""
+    return ModelList(
+        data=[
+            Model(
+                id="nanochat",
+                created=int(time.time()),
+                owned_by="nanochat"
+            )
+        ]
+    )
+
 @app.get("/health")
 async def health():
     """Health check endpoint."""
@@ -410,6 +642,6 @@ async def stats():

 if __name__ == "__main__":
     import uvicorn
-    print(f"Starting NanoChat Web Server")
+    print("Starting NanoChat Web Server")
     print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
     uvicorn.run(app, host=args.host, port=args.port)