mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-19 12:39:10 +00:00
Merge 9a9b12b1be into 8b4849d548
This commit is contained in:
commit
021341e20e
|
|
@ -16,10 +16,41 @@ python -m scripts.chat_web --num-gpus 4
|
||||||
To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
|
To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
|
||||||
|
|
||||||
Endpoints:
|
Endpoints:
|
||||||
GET / - Chat UI
|
GET / - Chat UI
|
||||||
POST /chat/completions - Chat API (streaming only)
|
POST /chat/completions - Chat API (streaming only, custom format)
|
||||||
GET /health - Health check with worker pool status
|
POST /v1/chat/completions - OpenAI-compatible chat completions (streaming and non-streaming)
|
||||||
GET /stats - Worker pool statistics and GPU utilization
|
GET /v1/models - List available models (OpenAI-compatible)
|
||||||
|
GET /health - Health check with worker pool status
|
||||||
|
GET /stats - Worker pool statistics and GPU utilization
|
||||||
|
|
||||||
|
OpenAI API Compatibility:
|
||||||
|
- Supports both streaming and non-streaming responses
|
||||||
|
- Compatible with OpenAI Python SDK and other OpenAI-compatible clients
|
||||||
|
- Implements standard OpenAI request/response format
|
||||||
|
|
||||||
|
Example usage with OpenAI SDK:
|
||||||
|
```python
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
api_key="not-needed",
|
||||||
|
base_url="http://localhost:8000/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="nanochat",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "Hello!"}
|
||||||
|
],
|
||||||
|
temperature=0.8,
|
||||||
|
max_tokens=512,
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
if chunk.choices[0].delta.content:
|
||||||
|
print(chunk.choices[0].delta.content, end="", flush=True)
|
||||||
|
```
|
||||||
|
|
||||||
Abuse Prevention:
|
Abuse Prevention:
|
||||||
- Maximum 500 messages per request
|
- Maximum 500 messages per request
|
||||||
|
|
@ -37,12 +68,13 @@ import torch
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Optional, AsyncGenerator
|
from typing import List, Optional, AsyncGenerator, Literal
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from nanochat.common import compute_init, autodetect_device_type
|
from nanochat.common import compute_init, autodetect_device_type
|
||||||
|
|
@ -156,6 +188,50 @@ class ChatRequest(BaseModel):
|
||||||
temperature: Optional[float] = None
|
temperature: Optional[float] = None
|
||||||
max_tokens: Optional[int] = None
|
max_tokens: Optional[int] = None
|
||||||
top_k: Optional[int] = None
|
top_k: Optional[int] = None
|
||||||
|
model: str = Field(default="nanochat") # For OpenAI compatibility
|
||||||
|
stream: Optional[bool] = False # For OpenAI compatibility
|
||||||
|
top_p: Optional[float] = None # Ignored, for compatibility
|
||||||
|
|
||||||
|
# OpenAI API Models
|
||||||
|
class ChatCompletionResponseChoice(BaseModel):
    # A single finished answer carried in the `choices` list of a
    # non-streaming ChatCompletionResponse.

    # Choice number; the /v1/chat/completions handler always passes 0.
    index: int
    # The complete assistant message for this choice.
    message: ChatMessage
    # Why generation ended; the handler in this file always sends "stop".
    finish_reason: str
|
||||||
|
|
||||||
|
class ChatCompletionResponseStreamChoice(BaseModel):
    # A single incremental update carried in the `choices` list of a
    # streaming ChatCompletionStreamResponse chunk.

    # Choice number; the /v1/chat/completions handler always passes 0.
    index: int
    # Incremental payload: {"content": "<new text>"} for text chunks and an
    # empty dict on the closing chunk.
    delta: dict
    # Left as None while text is still flowing; set to "stop" on the closing chunk.
    finish_reason: Optional[str] = None
|
||||||
|
|
||||||
|
class Usage(BaseModel):
    # Token accounting attached to a non-streaming ChatCompletionResponse.

    # Length of the rendered prompt, in tokens.
    prompt_tokens: int
    # Number of tokens generated for the reply.
    completion_tokens: int
    # Filled by the caller; the handler in this file passes prompt + completion.
    total_tokens: int
|
||||||
|
|
||||||
|
class ChatCompletionResponse(BaseModel):
    # Body returned by POST /v1/chat/completions when `stream` is false.

    # Request identifier; the handler generates it as "chatcmpl-<number>".
    id: str
    # Constant type tag for this payload.
    object: Literal["chat.completion"] = "chat.completion"
    # Creation time; the handler passes int(time.time()).
    created: int
    # Model name, echoed back from the request.
    model: str
    # Generated answers; the handler always supplies exactly one.
    choices: List[ChatCompletionResponseChoice]
    # Prompt / completion token counts for this request.
    usage: Usage
|
||||||
|
|
||||||
|
class ChatCompletionStreamResponse(BaseModel):
    # One server-sent-event chunk emitted by POST /v1/chat/completions when
    # `stream` is true. Every chunk of a response shares `id` and `created`.

    # Request identifier, identical across all chunks of one response.
    id: str
    # Constant type tag for this payload.
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Creation time; the handler passes int(time.time()).
    created: int
    # Model name, echoed back from the request.
    model: str
    # Incremental updates; the handler always supplies exactly one.
    choices: List[ChatCompletionResponseStreamChoice]
|
||||||
|
|
||||||
|
class Model(BaseModel):
    # One entry of the model catalogue served by GET /v1/models.

    # Model identifier clients pass as `model` in requests.
    id: str
    # Constant type tag for this payload.
    object: Literal["model"] = "model"
    # Timestamp field; the endpoint fills it with int(time.time()).
    created: int
    # Owner label shown to clients.
    owned_by: str
|
||||||
|
|
||||||
|
class ModelList(BaseModel):
    # Envelope returned by GET /v1/models.

    # Constant type tag for this payload.
    object: Literal["list"] = "list"
    # The available models.
    data: List[Model]
|
||||||
|
|
||||||
def validate_chat_request(request: ChatRequest):
|
def validate_chat_request(request: ChatRequest):
|
||||||
"""Validate chat request to prevent abuse."""
|
"""Validate chat request to prevent abuse."""
|
||||||
|
|
@ -189,11 +265,12 @@ def validate_chat_request(request: ChatRequest):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate role values
|
# Validate role values
|
||||||
|
valid_roles = ["user", "assistant", "system"]
|
||||||
for i, message in enumerate(request.messages):
|
for i, message in enumerate(request.messages):
|
||||||
if message.role not in ["user", "assistant"]:
|
if message.role not in valid_roles:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
|
detail=f"Message {i} has invalid role '{message.role}'. Must be one of: {', '.join(valid_roles)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate temperature
|
# Validate temperature
|
||||||
|
|
@ -265,8 +342,8 @@ async def generate_stream(
|
||||||
temperature=None,
|
temperature=None,
|
||||||
max_new_tokens=None,
|
max_new_tokens=None,
|
||||||
top_k=None
|
top_k=None
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[tuple[str, int], None]:
|
||||||
"""Generate assistant response with streaming."""
|
"""Generate assistant response with streaming. Returns (text, token_count) tuples."""
|
||||||
temperature = temperature if temperature is not None else args.temperature
|
temperature = temperature if temperature is not None else args.temperature
|
||||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
|
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
|
||||||
top_k = top_k if top_k is not None else args.top_k
|
top_k = top_k if top_k is not None else args.top_k
|
||||||
|
|
@ -278,6 +355,7 @@ async def generate_stream(
|
||||||
accumulated_tokens = []
|
accumulated_tokens = []
|
||||||
# Track the last complete UTF-8 string (without replacement characters)
|
# Track the last complete UTF-8 string (without replacement characters)
|
||||||
last_clean_text = ""
|
last_clean_text = ""
|
||||||
|
completion_tokens = 0
|
||||||
|
|
||||||
with worker.autocast_ctx:
|
with worker.autocast_ctx:
|
||||||
for token_column, token_masks in worker.engine.generate(
|
for token_column, token_masks in worker.engine.generate(
|
||||||
|
|
@ -296,6 +374,7 @@ async def generate_stream(
|
||||||
|
|
||||||
# Append the token to sequence
|
# Append the token to sequence
|
||||||
accumulated_tokens.append(token)
|
accumulated_tokens.append(token)
|
||||||
|
completion_tokens += 1
|
||||||
# Decode all accumulated tokens to get proper UTF-8 handling
|
# Decode all accumulated tokens to get proper UTF-8 handling
|
||||||
# Note that decode is a quite efficient operation, basically table lookup and string concat
|
# Note that decode is a quite efficient operation, basically table lookup and string concat
|
||||||
current_text = worker.tokenizer.decode(accumulated_tokens)
|
current_text = worker.tokenizer.decode(accumulated_tokens)
|
||||||
|
|
@ -305,11 +384,9 @@ async def generate_stream(
|
||||||
# Extract only the new text since last clean decode
|
# Extract only the new text since last clean decode
|
||||||
new_text = current_text[len(last_clean_text):]
|
new_text = current_text[len(last_clean_text):]
|
||||||
if new_text: # Only yield if there's new content
|
if new_text: # Only yield if there's new content
|
||||||
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
|
yield (new_text, completion_tokens)
|
||||||
last_clean_text = current_text
|
last_clean_text = current_text
|
||||||
|
|
||||||
yield f"data: {json.dumps({'done': True})}\n\n"
|
|
||||||
|
|
||||||
@app.post("/chat/completions")
|
@app.post("/chat/completions")
|
||||||
async def chat_completions(request: ChatRequest):
|
async def chat_completions(request: ChatRequest):
|
||||||
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
|
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
|
||||||
|
|
@ -352,7 +429,7 @@ async def chat_completions(request: ChatRequest):
|
||||||
response_tokens = []
|
response_tokens = []
|
||||||
async def stream_and_release():
|
async def stream_and_release():
|
||||||
try:
|
try:
|
||||||
async for chunk in generate_stream(
|
async for text_chunk, token_count in generate_stream(
|
||||||
worker,
|
worker,
|
||||||
conversation_tokens,
|
conversation_tokens,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
|
|
@ -360,10 +437,11 @@ async def chat_completions(request: ChatRequest):
|
||||||
top_k=request.top_k
|
top_k=request.top_k
|
||||||
):
|
):
|
||||||
# Accumulate response for logging
|
# Accumulate response for logging
|
||||||
chunk_data = json.loads(chunk.replace("data: ", "").strip())
|
response_tokens.append(text_chunk)
|
||||||
if "token" in chunk_data:
|
# Format as custom JSON for web UI
|
||||||
response_tokens.append(chunk_data["token"])
|
yield f"data: {json.dumps({'token': text_chunk, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
|
||||||
yield chunk
|
# Send done message
|
||||||
|
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||||
finally:
|
finally:
|
||||||
# Log the assistant response to console
|
# Log the assistant response to console
|
||||||
full_response = "".join(response_tokens)
|
full_response = "".join(response_tokens)
|
||||||
|
|
@ -381,6 +459,160 @@ async def chat_completions(request: ChatRequest):
|
||||||
await worker_pool.release_worker(worker)
|
await worker_pool.release_worker(worker)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def _build_conversation_tokens(tokenizer, messages):
    """Render chat messages into the prompt token sequence for the model.

    The sequence is: BOS, then each user/assistant message wrapped in its
    start/end special tokens, then a trailing ``<|assistant_start|>`` so the
    model continues with the assistant's reply. System messages contribute
    no tokens; a warning is logged for each one.
    """
    bos = tokenizer.get_bos_token_id()
    user_start = tokenizer.encode_special("<|user_start|>")
    user_end = tokenizer.encode_special("<|user_end|>")
    assistant_start = tokenizer.encode_special("<|assistant_start|>")
    assistant_end = tokenizer.encode_special("<|assistant_end|>")

    tokens = [bos]
    for message in messages:
        if message.role == "user":
            tokens.append(user_start)
            tokens.extend(tokenizer.encode(message.content))
            tokens.append(user_end)
        elif message.role == "assistant":
            tokens.append(assistant_start)
            tokens.extend(tokenizer.encode(message.content))
            tokens.append(assistant_end)
        elif message.role == "system":
            # Note: system messages are ignored as nanochat doesn't have system role support
            logger.warning(f"System role message ignored (not supported by this model): {message.content[:100]}...")

    # Prime the model to answer as the assistant.
    tokens.append(assistant_start)
    return tokens


@app.post("/v1/chat/completions")
async def openai_chat_completions(request: ChatRequest):
    """OpenAI-compatible chat completion endpoint (supports streaming and non-streaming).

    Note: System role messages are accepted but ignored as nanochat doesn't support system prompts.

    Returns a ``StreamingResponse`` of ``chat.completion.chunk`` server-sent
    events terminated by ``data: [DONE]`` when ``request.stream`` is true,
    otherwise a single ``ChatCompletionResponse`` including token usage.

    Raises:
        HTTPException: 4xx from request validation, or 500 if prompt
            construction or non-streaming generation fails.
    """
    validate_chat_request(request)

    # Log incoming request
    logger.info(f"OpenAI API Request: model={request.model}, stream={request.stream}, messages={len(request.messages)}")

    worker_pool = app.state.worker_pool
    worker = await worker_pool.acquire_worker()
    # Exactly one party must release the worker. This function owns it until
    # the streaming generator takes over; the flag records that hand-off so
    # the worker is neither leaked (error before hand-off) nor released twice.
    worker_handed_off = False

    request_id = f"chatcmpl-{random.randint(1000000, 9999999)}"
    created = int(time.time())

    try:
        conversation_tokens = _build_conversation_tokens(worker.tokenizer, request.messages)
        prompt_tokens = len(conversation_tokens)

        if request.stream:
            # Streaming response
            async def stream_response():
                try:
                    completion_tokens = 0

                    async for text_chunk, token_count in generate_stream(
                        worker,
                        conversation_tokens,
                        temperature=request.temperature,
                        max_new_tokens=request.max_tokens,
                        top_k=request.top_k
                    ):
                        completion_tokens = token_count

                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            created=created,
                            model=request.model,
                            choices=[
                                ChatCompletionResponseStreamChoice(
                                    index=0,
                                    delta={"content": text_chunk},
                                    finish_reason=None
                                )
                            ]
                        )
                        yield f"data: {json.dumps(chunk.model_dump(), ensure_ascii=False)}\n\n"

                    # Final chunk with finish_reason
                    # NOTE(review): always "stop"; generate_stream only yields
                    # (text, token_count), so hitting max_tokens cannot be
                    # told apart from a natural end here.
                    final_chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        created=created,
                        model=request.model,
                        choices=[
                            ChatCompletionResponseStreamChoice(
                                index=0,
                                delta={},
                                finish_reason="stop"
                            )
                        ]
                    )
                    yield f"data: {json.dumps(final_chunk.model_dump(), ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"

                    logger.info(f"OpenAI API Response (GPU {worker.gpu_id}): {completion_tokens} tokens")
                finally:
                    # Runs on normal completion, on error, and when the
                    # generator is closed early.
                    await worker_pool.release_worker(worker)

            streaming = StreamingResponse(
                stream_response(),
                media_type="text/event-stream"
            )
            # From here on the generator's own `finally` releases the worker.
            worker_handed_off = True
            return streaming

        # Non-streaming response: drain the generator and answer in one piece.
        text_parts = []
        completion_tokens = 0

        async for text_chunk, token_count in generate_stream(
            worker,
            conversation_tokens,
            temperature=request.temperature,
            max_new_tokens=request.max_tokens,
            top_k=request.top_k
        ):
            text_parts.append(text_chunk)
            completion_tokens = token_count

        response = ChatCompletionResponse(
            id=request_id,
            created=created,
            model=request.model,
            choices=[
                ChatCompletionResponseChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content="".join(text_parts)),
                    finish_reason="stop"
                )
            ],
            usage=Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens
            )
        )

        logger.info(f"OpenAI API Response (GPU {worker.gpu_id}): {completion_tokens} tokens")
        return response

    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewriting them as 500s.
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which the plain message lost.
        logger.exception("Error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if not worker_handed_off:
            await worker_pool.release_worker(worker)
|
||||||
|
|
||||||
|
@app.get("/v1/models")
async def list_models() -> ModelList:
    """Return the model catalogue in the OpenAI-compatible list format.

    The catalogue contains a single entry, "nanochat", whose `created`
    field is the time of the call.
    """
    nanochat_entry = Model(
        id="nanochat",
        created=int(time.time()),
        owned_by="nanochat",
    )
    return ModelList(data=[nanochat_entry])
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
"""Health check endpoint."""
|
"""Health check endpoint."""
|
||||||
|
|
@ -410,6 +642,6 @@ async def stats():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
print(f"Starting NanoChat Web Server")
|
print("Starting NanoChat Web Server")
|
||||||
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
|
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
|
||||||
uvicorn.run(app, host=args.host, port=args.port)
|
uvicorn.run(app, host=args.host, port=args.port)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user