improved error handling for openai sdk

This commit is contained in:
Evgeny Sorokin 2025-10-16 14:29:39 +02:00
parent 28e6b9b9c2
commit 04862cbfea
2 changed files with 340 additions and 27 deletions

View File

@ -28,6 +28,34 @@ python -m scripts.chat_web
And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda use the public IP of the node you're on, followed by the port, so for example [http://209.20.xxx.xxx:8000/](http://209.20.xxx.xxx:8000/), etc. Then talk to your LLM as you'd normally talk to ChatGPT! Get it to write stories or poems. Ask it to tell you who you are to see a hallucination. Ask it why the sky is blue. Or why it's green. The speedrun is a 4e19 FLOPs capability model so it's a bit like talking to a kindergartener :).
### CPU Inference
If you want to run inference on a CPU (e.g., on your laptop or a machine without a GPU), use the CPU web server:
```bash
python -m scripts.chat_web_cpu --model-dir /tmp/nanochat
```
This script automatically converts the model to float32 and runs inference on CPU. You can then access the web UI at `http://localhost:8000` or use it via the OpenAI-compatible API.
The CPU web server (`chat_web_cpu.py`) is compatible with the OpenAI API specification. This means you can use any OpenAI SDK, tool, or framework with your NanoChat models:
```python
from openai import OpenAI
client = OpenAI(
api_key="not_set",
base_url="http://localhost:8000/v1",
)
response = client.chat.completions.create(
model="nanochat",
messages=[{"role": "user", "content": "Hello!"}]
)
print(response.choices[0].message.content)
```
---
<img width="2672" height="1520" alt="image" src="https://github.com/user-attachments/assets/ed39ddf8-2370-437a-bedc-0f39781e76b5" />

View File

@ -11,13 +11,16 @@ import os
import glob
import pickle
import math
import time
import uuid
import torch
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse, JSONResponse
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Field
from typing import List, Optional, AsyncGenerator, Literal, Union, Dict, Any
from dataclasses import dataclass
import torch.nn as nn
@ -279,16 +282,63 @@ args = parser.parse_args()
device = torch.device("cpu")
# OpenAI-compatible request/response models
class ChatMessage(BaseModel):
role: str
content: str
role: Literal["system", "user", "assistant"]
content: str # Only text content supported
name: Optional[str] = None
class ChatRequest(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str = Field(default="nanochat", description="Model to use for completion")
messages: List[ChatMessage]
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_k: Optional[int] = None
stream: Optional[bool] = True
# Supported parameters
temperature: Optional[float] = Field(default=None, ge=0, le=2)
max_tokens: Optional[int] = Field(default=None, ge=1)
top_k: Optional[int] = Field(default=None, ge=1, description="Top-k sampling (NanoChat-specific)")
stream: Optional[bool] = False
# Accepted but not supported (will be rejected if provided)
top_p: Optional[float] = Field(default=None, ge=0, le=1)
n: Optional[int] = Field(default=None, ge=1)
stop: Optional[Union[str, List[str]]] = None
presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Not supported features
tools: Optional[List[Dict[str, Any]]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
functions: Optional[List[Dict[str, Any]]] = None
function_call: Optional[Union[str, Dict[str, Any]]] = None
class ChatCompletionResponseChoice(BaseModel):
    """A single choice entry in a non-streaming chat completion response."""

    # Index of this choice in the response's `choices` list.
    index: int
    # The complete message generated for this choice.
    message: ChatMessage
    # Reason generation ended; one of the listed literals, or None if unset.
    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
class ChatCompletionResponseStreamChoice(BaseModel):
    """A single choice entry in one streamed chat completion chunk."""

    # Index of this choice in the chunk's `choices` list.
    index: int
    # Free-form incremental payload for this chunk (arbitrary string keys).
    delta: Dict[str, Any]
    # Reason generation ended; one of the listed literals, or None if unset.
    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
class UsageInfo(BaseModel):
    """Token accounting reported alongside a chat completion."""

    # Number of tokens counted for the prompt.
    prompt_tokens: int
    # Number of tokens counted for the generated completion.
    completion_tokens: int
    # Combined token count supplied by the caller building this object.
    total_tokens: int
class ChatCompletionResponse(BaseModel):
    """Top-level body of a non-streaming chat completion response."""

    # Identifier of this completion.
    id: str
    # Object-type tag; only the value "chat.completion" is accepted.
    object: Literal["chat.completion"] = "chat.completion"
    # Integer creation timestamp.
    created: int
    # Name of the model reported back to the client.
    model: str
    # Generated choices.
    choices: List[ChatCompletionResponseChoice]
    # Token usage for this completion.
    usage: UsageInfo
class ChatCompletionStreamResponse(BaseModel):
    """Top-level body of one chunk in a streamed chat completion."""

    # Identifier of the completion this chunk belongs to.
    id: str
    # Object-type tag; only the value "chat.completion.chunk" is accepted.
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Integer creation timestamp.
    created: int
    # Name of the model reported back to the client.
    model: str
    # Per-choice incremental payloads carried by this chunk.
    choices: List[ChatCompletionResponseStreamChoice]
@asynccontextmanager
async def lifespan(app: FastAPI):
@ -353,6 +403,55 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)
# Custom exception handler for OpenAI-compatible error responses
class OpenAIError(Exception):
    """Exception carrying the fields of an OpenAI-style error object.

    Attributes:
        message: Human-readable description of the problem; also used as the
            exception's string value.
        error_type: Error category string (defaults to
            ``"invalid_request_error"``).
        param: Name of the offending request parameter, or ``None`` when the
            error is not tied to a single parameter.
        code: Short machine-readable error code, or ``None`` when absent.
    """

    def __init__(
        self,
        message: str,
        error_type: str = "invalid_request_error",
        param: Optional[str] = None,
        code: Optional[str] = None,
    ):
        # `param` and `code` default to None, so they are annotated Optional
        # rather than relying on implicit-Optional (`str = None`).
        super().__init__(message)
        self.message = message
        self.error_type = error_type
        self.param = param
        self.code = code
@app.exception_handler(OpenAIError)
async def openai_error_handler(request: Request, exc: OpenAIError):
    """Convert a raised ``OpenAIError`` into an OpenAI-style JSON error.

    The response always uses HTTP status 400 and wraps the exception's
    ``message``, ``error_type``, ``param`` and ``code`` under an ``"error"``
    key.
    """
    error_body = {
        "message": exc.message,
        "type": exc.error_type,
        "param": exc.param,
        "code": exc.code,
    }
    return JSONResponse(status_code=400, content={"error": error_body})
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Report request validation failures as an OpenAI-style JSON error.

    Only the first validation error is surfaced: its ``loc`` path is joined
    with dots to form ``param`` and its ``msg`` becomes the message. When the
    exception carries no errors, a generic message with no ``param`` is used.
    The response status is always 400.
    """
    # Start from the generic fallback and refine it if details are available.
    param = None
    message = "Invalid request"

    problems = exc.errors()
    if problems:
        head = problems[0]
        param = ".".join(map(str, head.get("loc", [])))
        message = head.get("msg", "Invalid request")

    error_body = {
        "message": message,
        "type": "invalid_request_error",
        "param": param,
        "code": None,
    }
    return JSONResponse(status_code=400, content={"error": error_body})
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@ -385,39 +484,162 @@ async def generate_stream(
model,
tokenizer,
tokens,
completion_id: str,
model_name: str,
created: int,
temperature=None,
max_new_tokens=None,
top_k=None
) -> AsyncGenerator[str, None]:
"""Generate assistant response with streaming."""
"""Generate assistant response with OpenAI-compatible streaming.
Supported parameters: temperature, max_new_tokens, top_k
"""
temperature = temperature if temperature is not None else args.temperature
# Greedy decoding when temperature <= 0
if temperature is not None and temperature <= 0:
temperature = 1e-6
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
# Enforce max 1000 cap
if max_new_tokens is None:
max_new_tokens = 256
max_new_tokens = max(1, min(1000, int(max_new_tokens)))
top_k = top_k if top_k is not None else args.top_k
if top_k is None:
top_k = 50
vocab_size = getattr(app.state.model.config, 'vocab_size', 50257)
top_k = max(1, min(int(top_k), int(vocab_size)))
assistant_end = tokenizer.encode_special("<|assistant_end|>")
bos = tokenizer.get_bos_token_id()
# Send initial chunk with role
chunk = ChatCompletionStreamResponse(
id=completion_id,
created=created,
model=model_name,
choices=[ChatCompletionResponseStreamChoice(
index=0,
delta={"role": "assistant", "content": ""},
finish_reason=None
)]
)
yield f"data: {chunk.model_dump_json()}\n\n"
finish_reason = "length"
for token in generate_tokens(model, tokens, max_new_tokens, temperature, top_k, device):
if token == assistant_end or token == bos:
finish_reason = "stop"
break
token_text = tokenizer.decode([token])
yield f"data: {json.dumps({'token': token_text})}\n\n"
# Send content chunk
chunk = ChatCompletionStreamResponse(
id=completion_id,
created=created,
model=model_name,
choices=[ChatCompletionResponseStreamChoice(
index=0,
delta={"content": token_text},
finish_reason=None
)]
)
yield f"data: {chunk.model_dump_json()}\n\n"
yield f"data: {json.dumps({'done': True})}\n\n"
# Send final chunk with finish_reason
chunk = ChatCompletionStreamResponse(
id=completion_id,
created=created,
model=model_name,
choices=[ChatCompletionResponseStreamChoice(
index=0,
delta={},
finish_reason=finish_reason
)]
)
yield f"data: {chunk.model_dump_json()}\n\n"
# OpenAI sends [DONE] at the end
yield "data: [DONE]\n\n"
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
"""Chat completion endpoint with streaming."""
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
"""
OpenAI-compatible chat completion endpoint.
Supported parameters:
- messages: Array of message objects (text only)
- temperature: Sampling temperature (0-2)
- max_tokens: Maximum tokens to generate
- top_k: Top-k sampling (NanoChat-specific)
- stream: Enable streaming responses
Not supported (rejected with clear errors):
- top_p, n, stop, presence_penalty, frequency_penalty, logit_bias, user
- tools, functions (function calling not supported)
- Multi-modal content (only text messages supported)
"""
model = app.state.model
tokenizer = app.state.tokenizer
# Validate unsupported features
if request.tools or request.tool_choice or request.functions or request.function_call:
raise OpenAIError(
message="Function calling and tools are not supported by this model. Only text completion is available.",
error_type="invalid_request_error",
code="unsupported_feature"
)
# Reject any unsupported standard params if provided
unsupported_fields = []
if request.n is not None:
unsupported_fields.append("n")
if request.top_p is not None:
unsupported_fields.append("top_p")
if request.stop is not None:
unsupported_fields.append("stop")
if request.presence_penalty is not None:
unsupported_fields.append("presence_penalty")
if request.frequency_penalty is not None:
unsupported_fields.append("frequency_penalty")
if request.logit_bias is not None:
unsupported_fields.append("logit_bias")
if request.user is not None:
unsupported_fields.append("user")
if unsupported_fields:
raise OpenAIError(
message=f"Unsupported parameters for this model: {', '.join(unsupported_fields)}. Supported only: messages, temperature, max_tokens, top_k, stream.",
error_type="invalid_request_error",
param=unsupported_fields[0],
code="unsupported_parameter"
)
# Validate messages are text-only
for i, msg in enumerate(request.messages):
if not isinstance(msg.content, str):
raise OpenAIError(
message=f"Message at index {i} contains non-text content. Only text messages are supported.",
error_type="invalid_request_error",
param=f"messages[{i}].content",
code="invalid_message_content"
)
# Generate unique completion ID and timestamp
completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
model_name = request.model
# Build conversation tokens
bos = tokenizer.get_bos_token_id()
user_start = tokenizer.encode_special("<|user_start|>")
user_end = tokenizer.encode_special("<|user_end|>")
assistant_start = tokenizer.encode_special("<|assistant_start|>")
assistant_end = tokenizer.encode_special("<|assistant_end|>")
system_start = tokenizer.encode_special("<|system_start|>")
system_end = tokenizer.encode_special("<|system_end|>")
conversation_tokens = [bos]
for message in request.messages:
@ -429,15 +651,31 @@ async def chat_completions(request: ChatRequest):
conversation_tokens.append(assistant_start)
conversation_tokens.extend(tokenizer.encode(message.content))
conversation_tokens.append(assistant_end)
elif message.role == "system":
# Handle system messages if supported
if system_start != 0 and system_end != 0:
conversation_tokens.append(system_start)
conversation_tokens.extend(tokenizer.encode(message.content))
conversation_tokens.append(system_end)
else:
# Fallback: treat system message as user message
conversation_tokens.append(user_start)
conversation_tokens.extend(tokenizer.encode(message.content))
conversation_tokens.append(user_end)
conversation_tokens.append(assistant_start)
prompt_tokens = len(conversation_tokens)
# Use only supported parameters: temperature, max_tokens, top_k
if request.stream:
return StreamingResponse(
generate_stream(
model,
tokenizer,
conversation_tokens,
completion_id=completion_id,
model_name=model_name,
created=created,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
@ -447,21 +685,68 @@ async def chat_completions(request: ChatRequest):
else:
# Non-streaming response
temperature = request.temperature if request.temperature is not None else args.temperature
# Enforce max 1000 tokens cap
max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
if max_tokens is None:
max_tokens = 256
max_tokens = max(1, min(1000, int(max_tokens)))
# Validate top_k: 1..vocab_size
top_k = request.top_k if request.top_k is not None else args.top_k
if top_k is None:
top_k = 50
vocab_size = getattr(app.state.model.config, 'vocab_size', 50257)
top_k = max(1, min(int(top_k), int(vocab_size)))
generated_tokens = list(generate_tokens(model, conversation_tokens, max_tokens, temperature, top_k, device))
response_text = tokenizer.decode(generated_tokens)
generated_tokens = []
finish_reason = "length"
return {
"choices": [{
"message": {
"role": "assistant",
"content": response_text
},
"finish_reason": "stop"
}]
}
for token in generate_tokens(model, conversation_tokens, max_tokens, temperature, top_k, device):
if token == assistant_end or token == bos:
finish_reason = "stop"
break
generated_tokens.append(token)
response_text = tokenizer.decode(generated_tokens)
completion_tokens = len(generated_tokens)
return ChatCompletionResponse(
id=completion_id,
created=created,
model=model_name,
choices=[ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response_text),
finish_reason=finish_reason
)],
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
)
@app.get("/v1/models")
@app.get("/models")
async def list_models():
    """
    List available models (OpenAI-compatible endpoint).

    Served at both /v1/models and /models. Always returns a single
    "nanochat" entry whose "created" field is the current time.
    """
    nanochat_entry = {
        "id": "nanochat",
        "object": "model",
        # Evaluated on each request, so this is the time of the call.
        "created": int(time.time()),
        "owned_by": "nanochat",
        "permission": [],
        "root": "nanochat",
        "parent": None,
    }
    return {"object": "list", "data": [nanochat_entry]}
@app.get("/health")
async def health():