mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-07 16:30:11 +00:00
Merge pull request #17 from manmohan659/feat/chat-api-service
feat(chat-api): conversation orchestration + SSE streaming proxy (#6)
This commit is contained in:
commit
1e2fc09ca6
|
|
@ -1,9 +1,34 @@
|
|||
# Chat API service image: Python 3.12 slim, deps installed with uv, served by uvicorn.
FROM python:3.12-slim

# No .pyc files in the image, unbuffered stdout for container logs, and let uv
# install straight into the system interpreter with copied (not linked) files.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    UV_SYSTEM_PYTHON=1 \
    UV_LINK_MODE=copy

# build-essential: compile any wheels without prebuilt binaries; curl: debugging/healthchecks.
RUN apt-get update \
    && apt-get install -y --no-install-recommends build-essential curl \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir uv==0.4.30

WORKDIR /app

# Copy project metadata in one layer; README.md is referenced by pyproject's
# `readme` field so it must be present alongside it. (The earlier separate
# `COPY README.md /app/README.md` line was redundant and has been removed.)
COPY pyproject.toml README.md /app/

# Pin runtime dependencies explicitly so this layer caches independently of src/.
# Keep in sync with pyproject.toml [project].dependencies.
RUN uv pip install --system --no-cache \
    "fastapi>=0.117.1" \
    "uvicorn[standard]>=0.36.0" \
    "pydantic>=2.8.0" \
    "pydantic-settings>=2.4.0" \
    "sqlalchemy[asyncio]>=2.0.36" \
    "asyncpg>=0.29.0" \
    "httpx>=0.27.0" \
    "sse-starlette>=2.1.3" \
    "structlog>=24.4.0" \
    "cachetools>=5.5.0"

COPY src /app/src

EXPOSE 8002

# Single CMD only: the scaffold's `python -m http.server` placeholder CMD is
# removed — Docker honors only the last CMD, so keeping both was dead config.
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8002"]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,54 @@
|
|||
# Chat API Service
|
||||
|
||||
Orchestration layer for samosaChaat conversations. Manages conversation state
in PostgreSQL, authenticates every request via the auth service, and proxies
streaming inference requests via Server-Sent Events.
|
||||
## Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
| --- | --- | --- |
|
||||
| GET | `/api/health` | Liveness probe (unauthenticated) |
|
||||
| GET | `/api/conversations` | List the authenticated user's conversations, grouped by date |
|
||||
| POST | `/api/conversations` | Create a new conversation |
|
||||
| GET | `/api/conversations/{id}` | Fetch a conversation + full message history |
|
||||
| PUT | `/api/conversations/{id}` | Update the conversation title |
|
||||
| DELETE | `/api/conversations/{id}` | Delete a conversation (cascade deletes messages) |
|
||||
| POST | `/api/conversations/{id}/messages` | Append a user message and stream the assistant response |
|
||||
| POST | `/api/conversations/{id}/regenerate` | Delete the last assistant message and regenerate it |
|
||||
| GET | `/api/models` | Proxy to inference `GET /models` |
|
||||
| POST | `/api/models/swap` | Proxy to inference `POST /models/swap` (admin only) |
|
||||
|
||||
All authenticated endpoints expect `Authorization: Bearer <jwt>`. The chat API
|
||||
validates the token by calling the auth service `POST /auth/validate` with the
|
||||
shared `X-Internal-API-Key` header and caches the result for 5 minutes.
|
||||
|
||||
## Environment
|
||||
|
||||
| Variable | Default | Purpose |
|
||||
| --- | --- | --- |
|
||||
| `DATABASE_URL` | `postgresql+asyncpg://localhost/samosachaat` | PostgreSQL connection string |
|
||||
| `AUTH_SERVICE_URL` | `http://auth:8001` | Base URL of the auth service |
|
||||
| `INFERENCE_SERVICE_URL` | `http://inference:8000` | Base URL of the inference service |
|
||||
| `INTERNAL_API_KEY` | — | Shared key for internal service auth |
|
||||
| `MAX_CONVERSATION_HISTORY` | `50` | Max messages included in each inference call |
|
||||
| `MAX_TOKEN_BUDGET` | `6000` | Character budget proxy for the above |
|
||||
| `FRONTEND_URL` | `http://localhost:3000` | Origin allowed by CORS |
|
||||
| `LOG_LEVEL` | `INFO` | Python log level |
|
||||
|
||||
## Running locally
|
||||
|
||||
```
|
||||
uv pip install -e ".[dev]"
|
||||
uvicorn src.main:app --reload --port 8002
|
||||
```
|
||||
|
||||
## Running tests
|
||||
|
||||
```
|
||||
cd services/chat-api
|
||||
pytest
|
||||
```
|
||||
|
||||
Tests use SQLite + aiosqlite for a throwaway database, respx to mock the auth
|
||||
service, and hand-crafted httpx mocks for the inference SSE stream.
|
||||
|
|
|
|||
38
services/chat-api/pyproject.toml
Normal file
38
services/chat-api/pyproject.toml
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
# Package metadata for the samosaChaat chat API orchestration service.
[project]
name = "samosachaat-chat-api"
version = "0.1.0"
description = "samosaChaat chat API orchestration service"
readme = "README.md"
requires-python = ">=3.12"
# Runtime dependencies — keep in sync with the Dockerfile's `uv pip install` layer.
dependencies = [
    "fastapi>=0.117.1",
    "uvicorn[standard]>=0.36.0",
    "pydantic>=2.8.0",
    "pydantic-settings>=2.4.0",
    "sqlalchemy[asyncio]>=2.0.36",
    "asyncpg>=0.29.0",
    "httpx>=0.27.0",
    "sse-starlette>=2.1.3",
    "structlog>=24.4.0",
    "cachetools>=5.5.0",
]

# Test-only tooling; aiosqlite backs the throwaway SQLite DB, respx mocks the
# auth service (see README "Running tests").
[dependency-groups]
dev = [
    "pytest>=8.0.0",
    "pytest-asyncio>=0.24.0",
    "aiosqlite>=0.20.0",
    "respx>=0.21.1",
]

# asyncio_mode=auto: `async def` tests run without explicit pytest.mark.asyncio.
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["src/tests"]
python_files = ["test_*.py"]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

# Ship the `src` package itself in the wheel (imported as `src.main` etc.).
[tool.hatch.build.targets.wheel]
packages = ["src"]
|
||||
0
services/chat-api/src/__init__.py
Normal file
0
services/chat-api/src/__init__.py
Normal file
36
services/chat-api/src/config.py
Normal file
36
services/chat-api/src/config.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""Runtime configuration for the chat API service."""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Environment-driven configuration for the chat API.

    Values come from process environment variables with a local ``.env`` file
    as fallback; unrecognised keys are ignored. See the README "Environment"
    table for the public knobs.
    """

    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    # Async SQLAlchemy DSN; must use the asyncpg driver.
    database_url: str = Field(default="postgresql+asyncpg://localhost/samosachaat")

    # Sibling services (docker-compose hostnames by default).
    auth_service_url: str = Field(default="http://auth:8001")
    inference_service_url: str = Field(default="http://inference:8000")
    # Shared secret sent as X-Internal-API-Key on service-to-service calls.
    # Empty default means it MUST be provided in real deployments.
    internal_api_key: str = Field(default="")

    # Bounds on how much history is replayed into each inference request
    # (message count and a character-based token-budget proxy).
    max_conversation_history: int = Field(default=50)
    max_token_budget: int = Field(default=6000)

    # TTL cache for validated JWTs: entry lifetime (seconds) and capacity.
    auth_cache_ttl_seconds: int = Field(default=300)
    auth_cache_max_size: int = Field(default=1024)

    # Sampling defaults forwarded to inference when the client omits overrides.
    inference_default_temperature: float = Field(default=0.8)
    inference_default_max_tokens: int = Field(default=512)
    inference_default_top_k: int = Field(default=50)

    # Sole origin allowed by the CORS middleware.
    frontend_url: str = Field(default="http://localhost:3000")

    log_level: str = Field(default="INFO")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return the process-wide Settings instance (built once, then cached)."""
    settings = Settings()
    return settings
|
||||
49
services/chat-api/src/database.py
Normal file
49
services/chat-api/src/database.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
"""Async SQLAlchemy engine and session factory."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from .config import get_settings
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Shared declarative base for all chat-api ORM models."""


# Lazily-initialised module singletons. Built on first access so that merely
# importing this module never touches the database.
_engine = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
def _build_engine() -> None:
    """Create the module-level engine and session factory from settings."""
    global _engine, _session_factory
    settings = get_settings()
    # pool_pre_ping: validate pooled connections before use so stale ones
    # (e.g. after a DB restart) are replaced instead of failing mid-request.
    _engine = create_async_engine(settings.database_url, pool_pre_ping=True)
    # expire_on_commit=False: ORM objects stay usable (e.g. .to_dict()) after
    # the session commits, which matters for objects returned from handlers.
    _session_factory = async_sessionmaker(_engine, expire_on_commit=False)
|
||||
|
||||
|
||||
def get_engine():
    """Return the shared async engine, building it on first use."""
    if _engine is None:
        _build_engine()
    return _engine
|
||||
|
||||
|
||||
def get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the shared session factory, building the engine on first use."""
    if _session_factory is None:
        _build_engine()
    # _build_engine always assigns the factory; assert narrows the Optional.
    assert _session_factory is not None
    return _session_factory
|
||||
|
||||
|
||||
async def get_session() -> AsyncIterator[AsyncSession]:
    """FastAPI dependency: yield a session that closes when the request ends."""
    session_factory = get_session_factory()
    async with session_factory() as db_session:
        yield db_session
|
||||
|
||||
|
||||
def override_session_factory(factory: async_sessionmaker[AsyncSession]) -> None:
    """Testing hook: swap the session factory for an in-memory engine.

    Note: leaves ``_engine`` untouched, so ``get_engine()`` would still build
    the real engine if called.
    """
    global _session_factory
    _session_factory = factory
|
||||
74
services/chat-api/src/logging_setup.py
Normal file
74
services/chat-api/src/logging_setup.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
"""Structured JSON logging for the chat API service."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
|
||||
import structlog
|
||||
|
||||
from .config import get_settings
|
||||
|
||||
# Request-scoped identifiers. ContextVars so concurrent requests sharing one
# event loop never observe each other's trace/user ids.
_trace_id_ctx: ContextVar[str | None] = ContextVar("trace_id", default=None)
_user_id_ctx: ContextVar[str | None] = ContextVar("user_id", default=None)
|
||||
|
||||
|
||||
def set_trace_id(trace_id: str | None) -> None:
    """Bind the current request's trace id into the logging context."""
    _trace_id_ctx.set(trace_id)


def set_user_id(user_id: str | None) -> None:
    """Bind the authenticated user's id; pass None to clear it."""
    _user_id_ctx.set(user_id)


def get_trace_id() -> str | None:
    """Return the trace id bound to the current context, if any."""
    return _trace_id_ctx.get()


def get_user_id() -> str | None:
    """Return the user id bound to the current context, if any."""
    return _user_id_ctx.get()
|
||||
|
||||
|
||||
def new_trace_id() -> str:
    """Mint a fresh 32-character lowercase-hex trace id."""
    return f"{uuid.uuid4().int:032x}"
|
||||
|
||||
|
||||
def _inject_context(_logger, _method, event_dict):
    """structlog processor: stamp the service name plus request-scoped ids.

    setdefault keeps any explicitly-passed values — we only fill gaps.
    """
    event_dict.setdefault("service", "chat-api")
    for field, value in (
        ("trace_id", _trace_id_ctx.get()),
        ("user_id", _user_id_ctx.get()),
    ):
        if value is not None:
            event_dict.setdefault(field, value)
    return event_dict
|
||||
|
||||
|
||||
def configure_logging() -> None:
    """Configure stdlib logging + structlog to emit one JSON object per line.

    Safe to call repeatedly: ``basicConfig(force=True)`` replaces handlers.
    """
    settings = get_settings()
    # Unknown LOG_LEVEL names fall back to INFO rather than raising.
    level = getattr(logging, settings.log_level.upper(), logging.INFO)

    logging.basicConfig(
        format="%(message)s",
        stream=sys.stdout,
        level=level,
        force=True,
    )

    structlog.configure(
        processors=[
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.TimeStamper(fmt="iso", utc=True),
            # Our processor: stamps service name + trace/user ids when bound.
            _inject_context,
            structlog.processors.JSONRenderer(),
        ],
        wrapper_class=structlog.make_filtering_bound_logger(level),
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )
|
||||
|
||||
|
||||
def get_logger(name: str | None = None):
    """Return a structlog logger; call configure_logging() once at startup first."""
    return structlog.get_logger(name)
|
||||
84
services/chat-api/src/main.py
Normal file
84
services/chat-api/src/main.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
"""FastAPI entrypoint for the samosaChaat chat API service."""
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from .config import get_settings
|
||||
from .logging_setup import (
|
||||
configure_logging,
|
||||
get_logger,
|
||||
new_trace_id,
|
||||
set_trace_id,
|
||||
set_user_id,
|
||||
)
|
||||
from .routes import conversations, messages, models
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: create shared HTTP clients at startup, close at shutdown.

    Two separate clients so the short auth timeout (5s total / 2s connect)
    never constrains the long-lived inference stream (60s total / 5s connect).
    """
    app.state.auth_http_client = httpx.AsyncClient(
        timeout=httpx.Timeout(5.0, connect=2.0)
    )
    app.state.inference_http_client = httpx.AsyncClient(
        timeout=httpx.Timeout(60.0, connect=5.0)
    )
    try:
        yield
    finally:
        await app.state.auth_http_client.aclose()
        await app.state.inference_http_client.aclose()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
    """Build the FastAPI app: logging, CORS, trace middleware, routers, health."""
    configure_logging()
    settings = get_settings()
    logger = get_logger(__name__)

    app = FastAPI(title="samosaChaat Chat API", version="0.1.0", lifespan=lifespan)

    # Browser calls are only accepted from the configured frontend origin.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[settings.frontend_url],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.middleware("http")
    async def request_context(request: Request, call_next) -> Response:
        # Propagate a caller-supplied trace id when present, else mint one.
        # Reset the user id so a recycled context never leaks the previous
        # request's user into log lines.
        incoming = request.headers.get("x-trace-id") or request.headers.get("x-request-id")
        trace_id = incoming or new_trace_id()
        set_trace_id(trace_id)
        set_user_id(None)

        logger.info(
            "request_start",
            method=request.method,
            path=request.url.path,
        )
        response = await call_next(request)
        # Echo the trace id so clients can correlate responses with logs.
        response.headers["x-trace-id"] = trace_id
        logger.info(
            "request_end",
            method=request.method,
            path=request.url.path,
            status_code=response.status_code,
        )
        return response

    app.include_router(conversations.router)
    app.include_router(messages.router)
    app.include_router(models.router)

    @app.get("/api/health")
    async def health():
        # Unauthenticated liveness probe (see README endpoint table).
        return {"status": "ok", "ready": True, "service": "chat-api"}

    return app


# Module-level ASGI application consumed by `uvicorn src.main:app`.
app = create_app()
|
||||
0
services/chat-api/src/middleware/__init__.py
Normal file
0
services/chat-api/src/middleware/__init__.py
Normal file
154
services/chat-api/src/middleware/auth_guard.py
Normal file
154
services/chat-api/src/middleware/auth_guard.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
"""Authentication guard that validates JWTs via the auth service.
|
||||
|
||||
Successful validations are cached in an in-memory TTL cache keyed by the raw
|
||||
JWT string so that a burst of requests from the same user does not fan out to
|
||||
the auth service on every call.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any
|
||||
|
||||
import httpx
|
||||
from cachetools import TTLCache
|
||||
from fastapi import Depends, Header, HTTPException, Request, status
|
||||
|
||||
from ..config import Settings, get_settings
|
||||
from ..logging_setup import get_logger, set_user_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthenticatedUser:
|
||||
id: str
|
||||
email: str
|
||||
name: str | None
|
||||
raw: dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_validate_response(cls, payload: dict[str, Any]) -> "AuthenticatedUser":
|
||||
user = payload.get("user") or {}
|
||||
return cls(
|
||||
id=str(user["id"]),
|
||||
email=user.get("email", ""),
|
||||
name=user.get("name"),
|
||||
raw=user,
|
||||
)
|
||||
|
||||
|
||||
class AuthCache:
    """TTL cache of validated JWTs guarded by an asyncio lock.

    cachetools.TTLCache is not coroutine-safe on its own, so every access is
    funneled through a single lock.
    """

    def __init__(self, ttl_seconds: int, max_size: int) -> None:
        self._lock = asyncio.Lock()
        self._cache: TTLCache[str, AuthenticatedUser] = TTLCache(
            maxsize=max_size, ttl=ttl_seconds
        )

    async def get(self, token: str) -> AuthenticatedUser | None:
        """Return the cached user for *token*, or None on miss or expiry."""
        async with self._lock:
            return self._cache.get(token)

    async def set(self, token: str, user: AuthenticatedUser) -> None:
        """Store a freshly validated user keyed by its raw JWT."""
        async with self._lock:
            self._cache[token] = user

    async def clear(self) -> None:
        """Drop every cached entry."""
        async with self._lock:
            self._cache.clear()
|
||||
|
||||
|
||||
# Lazily-built process-wide cache singleton.
_auth_cache: AuthCache | None = None


def get_auth_cache() -> AuthCache:
    """Return the shared AuthCache, building it from settings on first use."""
    global _auth_cache
    if _auth_cache is None:
        settings = get_settings()
        _auth_cache = AuthCache(
            ttl_seconds=settings.auth_cache_ttl_seconds,
            max_size=settings.auth_cache_max_size,
        )
    return _auth_cache


def reset_auth_cache() -> None:
    """Testing hook: drop the cached singleton so a fresh one is built."""
    global _auth_cache
    _auth_cache = None
|
||||
|
||||
|
||||
async def _validate_with_auth_service(
    token: str, settings: Settings, http_client: httpx.AsyncClient | None = None
) -> AuthenticatedUser:
    """Validate *token* against the auth service and return the user.

    Uses the provided shared client when given one; otherwise creates — and
    always closes — a short-lived client. Raises HTTPException:
      * 503 — auth service unreachable
      * 401 — token invalid/expired, or validation otherwise failed
      * 500 — our internal API key was rejected (deployment misconfiguration)
    """
    owns_client = http_client is None
    client = http_client or httpx.AsyncClient(timeout=5.0)
    try:
        response = await client.post(
            f"{settings.auth_service_url.rstrip('/')}/auth/validate",
            headers={"X-Internal-API-Key": settings.internal_api_key},
            json={"token": token},
        )
    except httpx.HTTPError as exc:
        logger.error("auth_service_unreachable", error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="auth service unreachable",
        ) from exc
    finally:
        # Only close clients we created; the shared one belongs to the app
        # lifespan. The response body is already buffered by the non-streaming
        # post(), so reading it after close is safe.
        if owns_client:
            await client.aclose()

    if response.status_code == 401:
        raise HTTPException(status_code=401, detail="invalid or expired token")
    if response.status_code == 403:
        # 403 here means *our* internal key was rejected — a server-side
        # configuration problem, not the caller's fault, hence 500 not 401.
        raise HTTPException(status_code=500, detail="internal api key rejected by auth")
    if response.status_code >= 400:
        logger.error(
            "auth_validate_failed",
            status_code=response.status_code,
            body=response.text[:200],
        )
        raise HTTPException(status_code=401, detail="token validation failed")

    data = response.json()
    # A 200 with valid=false still means the token was rejected.
    if not data.get("valid"):
        raise HTTPException(status_code=401, detail=data.get("reason", "invalid token"))

    return AuthenticatedUser.from_validate_response(data)
|
||||
|
||||
|
||||
def _extract_bearer(authorization: str | None) -> str:
    """Return the JWT from an ``Authorization: Bearer <token>`` header.

    Raises HTTPException(401) when the header is missing, uses a scheme other
    than Bearer (case-insensitive), or carries an empty token.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="missing authorization header")
    scheme, sep, token = authorization.partition(" ")
    token = token.strip()
    if not sep or scheme.lower() != "bearer" or not token:
        raise HTTPException(status_code=401, detail="invalid authorization scheme")
    return token
|
||||
|
||||
|
||||
async def require_user(
    request: Request,
    authorization: Annotated[str | None, Header()] = None,
    settings: Annotated[Settings, Depends(get_settings)] = None,  # type: ignore[assignment]
) -> AuthenticatedUser:
    """FastAPI dependency that yields the authenticated user for the request.

    Checks the in-memory TTL cache first; on a miss, validates with the auth
    service (via the shared lifespan client when available) and caches the
    result. Also binds the user id into the logging context and stashes the
    user on ``request.state`` for downstream handlers.
    """
    token = _extract_bearer(authorization)
    cache = get_auth_cache()

    cached = await cache.get(token)
    if cached is not None:
        set_user_id(cached.id)
        request.state.user = cached
        return cached

    # May be None when the app was built without the lifespan (e.g. unit
    # tests); _validate_with_auth_service then spins up a throwaway client.
    http_client: httpx.AsyncClient | None = getattr(
        request.app.state, "auth_http_client", None
    )
    user = await _validate_with_auth_service(token, settings, http_client=http_client)
    await cache.set(token, user)
    set_user_id(user.id)
    request.state.user = user
    return user
|
||||
0
services/chat-api/src/routes/__init__.py
Normal file
0
services/chat-api/src/routes/__init__.py
Normal file
130
services/chat-api/src/routes/conversations.py
Normal file
130
services/chat-api/src/routes/conversations.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
"""CRUD routes for conversations, scoped to the authenticated user."""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..database import get_session
|
||||
from ..logging_setup import get_logger
|
||||
from ..middleware.auth_guard import AuthenticatedUser, require_user
|
||||
from ..services import conversation_service
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/conversations", tags=["conversations"])
|
||||
|
||||
|
||||
class CreateConversationRequest(BaseModel):
    """Body for POST /api/conversations; both fields are optional."""

    title: str | None = Field(default=None, max_length=500)
    model_tag: str | None = Field(default=None, max_length=100)


class UpdateConversationRequest(BaseModel):
    """Body for PUT /api/conversations/{id}; title is required and non-empty."""

    title: str = Field(..., min_length=1, max_length=500)
|
||||
|
||||
|
||||
def _parse_uuid(raw: str) -> uuid.UUID:
    """Parse a conversation-id path parameter; malformed ids surface as 404."""
    try:
        parsed = uuid.UUID(raw)
    except (ValueError, TypeError) as exc:
        raise HTTPException(status_code=404, detail="conversation not found") from exc
    return parsed
|
||||
|
||||
|
||||
@router.get("")
async def list_conversations(
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
    limit: int = Query(default=50, ge=1, le=200),
    offset: int = Query(default=0, ge=0),
):
    """List the authenticated user's conversations, flat plus date-grouped."""
    user_uuid = uuid.UUID(user.id)
    conversations = await conversation_service.list_conversations(
        session, user_id=user_uuid, limit=limit, offset=offset
    )
    # `grouped` is a presentation aid over the same page of results —
    # presumably date buckets; see conversation_service.group_by_date.
    grouped = conversation_service.group_by_date(conversations)
    return {
        "items": [c.to_dict() for c in conversations],
        "grouped": grouped,
        "limit": limit,
        "offset": offset,
    }
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_conversation(
    body: CreateConversationRequest,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Create a conversation owned by the authenticated user; 201 with its dict."""
    user_uuid = uuid.UUID(user.id)
    convo = await conversation_service.create_conversation(
        session,
        user_id=user_uuid,
        title=body.title,
        model_tag=body.model_tag,
    )
    logger.info("conversation_created", conversation_id=str(convo.id))
    return convo.to_dict()
|
||||
|
||||
|
||||
@router.get("/{conversation_id}")
async def get_conversation(
    conversation_id: str,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Fetch one conversation plus its full message history.

    The lookup is user-scoped, so another user's conversation id answers 404
    — indistinguishable from a nonexistent id.
    """
    conv_uuid = _parse_uuid(conversation_id)
    user_uuid = uuid.UUID(user.id)
    convo = await conversation_service.get_user_conversation(
        session, user_id=user_uuid, conversation_id=conv_uuid
    )
    if convo is None:
        raise HTTPException(status_code=404, detail="conversation not found")

    messages = await conversation_service.get_conversation_messages(
        session, conversation_id=conv_uuid
    )
    payload = convo.to_dict()
    payload["messages"] = [m.to_dict() for m in messages]
    return payload
|
||||
|
||||
|
||||
@router.put("/{conversation_id}")
async def update_conversation(
    conversation_id: str,
    body: UpdateConversationRequest,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Rename a conversation; the service call is user-scoped, so a miss is 404."""
    conv_uuid = _parse_uuid(conversation_id)
    user_uuid = uuid.UUID(user.id)
    convo = await conversation_service.update_conversation_title(
        session,
        user_id=user_uuid,
        conversation_id=conv_uuid,
        title=body.title,
    )
    if convo is None:
        raise HTTPException(status_code=404, detail="conversation not found")
    return convo.to_dict()
|
||||
|
||||
|
||||
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_conversation(
    conversation_id: str,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Delete a conversation (messages cascade, per the README); 204 on success."""
    conv_uuid = _parse_uuid(conversation_id)
    user_uuid = uuid.UUID(user.id)
    deleted = await conversation_service.delete_conversation(
        session, user_id=user_uuid, conversation_id=conv_uuid
    )
    if not deleted:
        raise HTTPException(status_code=404, detail="conversation not found")
    logger.info("conversation_deleted", conversation_id=str(conv_uuid))
    # 204: explicit None body.
    return None
|
||||
257
services/chat-api/src/routes/messages.py
Normal file
257
services/chat-api/src/routes/messages.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
"""The main chat route: send a message and stream an assistant response."""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Annotated, AsyncIterator
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from ..config import Settings, get_settings
|
||||
from ..database import get_session_factory
|
||||
from ..logging_setup import get_logger
|
||||
from ..middleware.auth_guard import AuthenticatedUser, require_user
|
||||
from ..services import conversation_service
|
||||
from ..services.inference_client import InferenceClient
|
||||
from ..services.stream_proxy import StreamResult, proxy_inference_stream
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/conversations", tags=["messages"])
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
    """Body for POST .../messages: user content plus optional sampling overrides."""

    content: str = Field(..., min_length=1, max_length=8000)
    temperature: float | None = Field(default=None, ge=0.0, le=2.0)
    max_tokens: int | None = Field(default=None, ge=1, le=4096)
    top_k: int | None = Field(default=None, ge=0, le=200)


class RegenerateRequest(BaseModel):
    """Body for POST .../regenerate: sampling overrides only, no new content."""

    temperature: float | None = Field(default=None, ge=0.0, le=2.0)
    max_tokens: int | None = Field(default=None, ge=1, le=4096)
    top_k: int | None = Field(default=None, ge=0, le=200)
|
||||
|
||||
|
||||
def _parse_uuid(raw: str) -> uuid.UUID:
    """Parse a conversation-id path parameter; malformed ids surface as 404.

    NOTE(review): duplicated in routes/conversations.py — candidate for a
    shared helper module.
    """
    try:
        return uuid.UUID(raw)
    except (ValueError, TypeError) as exc:
        raise HTTPException(status_code=404, detail="conversation not found") from exc
|
||||
|
||||
|
||||
async def _ensure_ownership(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    conversation_id: uuid.UUID,
):
    """Return the conversation iff it exists and belongs to *user_id*.

    The lookup is user-scoped, so another user's conversation answers 404 —
    the API never reveals which ids exist.
    """
    convo = await conversation_service.get_user_conversation(
        session, user_id=user_id, conversation_id=conversation_id
    )
    if convo is None:
        raise HTTPException(status_code=404, detail="conversation not found")
    return convo
|
||||
|
||||
|
||||
async def _stream_and_persist(
    *,
    request: Request,
    user_id: uuid.UUID,
    conversation_id: uuid.UUID,
    history: list[dict[str, str]],
    temperature: float | None,
    max_tokens: int | None,
    top_k: int | None,
    model_tag: str,
    first_message: bool,
    first_message_preview: str | None,
    settings: Settings,
) -> AsyncIterator[dict]:
    """Generator that streams inference SSE events to the client and, after the
    stream closes, persists the full assistant message in a fresh DB session.
    """
    # Reuse the app-level client when available; tests may build the app
    # without the lifespan, in which case InferenceClient receives None.
    http_client = getattr(request.app.state, "inference_http_client", None)
    inference = InferenceClient(settings=settings, http_client=http_client)

    # proxy_inference_stream reports the accumulated result via callback; a
    # one-slot dict smuggles it out of the callback's scope.
    accumulated: dict[str, StreamResult] = {}

    def _capture(result: StreamResult) -> None:
        accumulated["result"] = result

    try:
        async with inference.stream_generate(
            messages=history,
            temperature=temperature,
            max_tokens=max_tokens,
            top_k=top_k,
        ) as response:
            async for event in proxy_inference_stream(response, on_complete=_capture):
                yield event
    except Exception as exc:  # pragma: no cover - defensive path
        logger.error(
            "inference_stream_failed",
            conversation_id=str(conversation_id),
            error=str(exc),
        )
        # Emit terminal error + done events so the client can stop cleanly;
        # nothing is persisted for a failed stream.
        yield {"data": '{"error":"inference_stream_failed"}'}
        yield {"data": '{"done":true}'}
        return

    # Missing, incomplete, or empty result: log and skip persistence.
    result = accumulated.get("result")
    if result is None or not result.completed or not result.content:
        logger.warning(
            "assistant_message_not_persisted",
            conversation_id=str(conversation_id),
            completed=bool(result and result.completed),
            content_len=len(result.content) if result else 0,
        )
        return

    # The request-scoped session is long gone by the time the stream ends, so
    # persistence opens its own short-lived session from the shared factory.
    factory: async_sessionmaker[AsyncSession] = get_session_factory()
    try:
        async with factory() as persist_session:
            await conversation_service.append_message(
                persist_session,
                conversation_id=conversation_id,
                role="assistant",
                content=result.content,
                token_count=result.token_count,
                model_tag=model_tag,
                inference_time_ms=result.inference_time_ms,
            )
            if first_message and first_message_preview is not None:
                # First exchange: title the conversation from the user's opener.
                await conversation_service.update_conversation_title(
                    persist_session,
                    user_id=user_id,
                    conversation_id=conversation_id,
                    title=first_message_preview,
                )
            logger.info(
                "assistant_message_persisted",
                conversation_id=str(conversation_id),
                token_count=result.token_count,
                inference_time_ms=result.inference_time_ms,
            )
    except Exception as exc:  # pragma: no cover - defensive path
        # The stream has already been delivered — never let a persistence
        # failure propagate into the finished SSE response; log it instead.
        logger.error(
            "assistant_message_persist_failed",
            conversation_id=str(conversation_id),
            error=str(exc),
        )
|
||||
|
||||
|
||||
@router.post("/{conversation_id}/messages")
async def send_message(
    conversation_id: str,
    body: SendMessageRequest,
    request: Request,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    settings: Annotated[Settings, Depends(get_settings)] = None,  # type: ignore[assignment]
):
    """Append a user message and stream the assistant response as SSE.

    Flow: verify ownership, persist the user message and build the inference
    history inside an explicitly-scoped session, then hand a generator to
    EventSourceResponse; the generator persists the assistant reply in its
    own session once the stream completes.
    """
    # We own our DB sessions explicitly here so the background task that
    # persists the streamed assistant response can open its own session once
    # the request scope has already closed.
    conv_uuid = _parse_uuid(conversation_id)
    user_uuid = uuid.UUID(user.id)
    session_factory = get_session_factory()

    async with session_factory() as db_session:
        convo = await _ensure_ownership(
            db_session, user_id=user_uuid, conversation_id=conv_uuid
        )
        model_tag = convo.model_tag or "default"

        # limit=1 suffices: we only need to know whether ANY message exists.
        existing = await conversation_service.get_conversation_messages(
            db_session, conversation_id=conv_uuid, limit=1
        )
        first_message = len(existing) == 0
        # First 80 chars of the opening message become the title later.
        first_preview = body.content[:80] if first_message else None

        await conversation_service.append_message(
            db_session,
            conversation_id=conv_uuid,
            role="user",
            content=body.content,
            token_count=None,
            model_tag=model_tag,
        )
        history = await conversation_service.build_history_for_inference(
            db_session, conversation_id=conv_uuid
        )

    logger.info(
        "send_message",
        conversation_id=str(conv_uuid),
        history_len=len(history),
        model_tag=model_tag,
    )

    generator = _stream_and_persist(
        request=request,
        user_id=user_uuid,
        conversation_id=conv_uuid,
        history=history,
        temperature=body.temperature,
        max_tokens=body.max_tokens,
        top_k=body.top_k,
        model_tag=model_tag,
        first_message=first_message,
        first_message_preview=first_preview,
        settings=settings,
    )
    return EventSourceResponse(generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/{conversation_id}/regenerate")
async def regenerate(
    conversation_id: str,
    body: RegenerateRequest,
    request: Request,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    settings: Annotated[Settings, Depends(get_settings)] = None,  # type: ignore[assignment]
):
    """Drop the last assistant message and stream a fresh reply for the same
    history. Sessions are owned explicitly for the same reason as send_message.
    """
    conv_uuid = _parse_uuid(conversation_id)
    user_uuid = uuid.UUID(user.id)
    session_factory = get_session_factory()

    async with session_factory() as db_session:
        convo = await _ensure_ownership(
            db_session, user_id=user_uuid, conversation_id=conv_uuid
        )
        model_tag = convo.model_tag or "default"
        await conversation_service.delete_last_assistant_message(
            db_session, conversation_id=conv_uuid
        )
        history = await conversation_service.build_history_for_inference(
            db_session, conversation_id=conv_uuid
        )

    # Empty history means there is no user turn left to answer.
    if not history:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="conversation has no user messages to regenerate from",
        )

    logger.info(
        "regenerate_message",
        conversation_id=str(conv_uuid),
        history_len=len(history),
    )

    generator = _stream_and_persist(
        request=request,
        user_id=user_uuid,
        conversation_id=conv_uuid,
        history=history,
        temperature=body.temperature,
        max_tokens=body.max_tokens,
        top_k=body.top_k,
        model_tag=model_tag,
        # A regenerated reply never retitles the conversation.
        first_message=False,
        first_message_preview=None,
        settings=settings,
    )
    return EventSourceResponse(generator, media_type="text/event-stream")
|
||||
74
services/chat-api/src/routes/models.py
Normal file
74
services/chat-api/src/routes/models.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
"""Proxy routes that forward model management calls to the inference service."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..config import Settings, get_settings
|
||||
from ..logging_setup import get_logger
|
||||
from ..middleware.auth_guard import AuthenticatedUser, require_user
|
||||
from ..services.inference_client import InferenceClient
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/models", tags=["models"])
|
||||
|
||||
|
||||
class SwapModelRequest(BaseModel):
    """Request body for POST /api/models/swap."""

    # Non-empty identifier of the model to activate on the inference service.
    model_tag: str = Field(..., min_length=1, max_length=100)
|
||||
|
||||
|
||||
def _client_for(request: Request, settings: Settings) -> InferenceClient:
    """Build an inference client, reusing the app-wide HTTP client if one exists."""
    shared_http = getattr(request.app.state, "inference_http_client", None)
    client = InferenceClient(settings=settings, http_client=shared_http)
    return client
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_models(
|
||||
request: Request,
|
||||
user: Annotated[AuthenticatedUser, Depends(require_user)],
|
||||
settings: Annotated[Settings, Depends(get_settings)] = None, # type: ignore[assignment]
|
||||
):
|
||||
client = _client_for(request, settings)
|
||||
try:
|
||||
return await client.list_models()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.error("list_models_proxy_failed", status_code=exc.response.status_code)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="inference service error",
|
||||
) from exc
|
||||
except httpx.HTTPError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="inference service unreachable",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post("/swap")
|
||||
async def swap_model(
|
||||
body: SwapModelRequest,
|
||||
request: Request,
|
||||
user: Annotated[AuthenticatedUser, Depends(require_user)],
|
||||
settings: Annotated[Settings, Depends(get_settings)] = None, # type: ignore[assignment]
|
||||
):
|
||||
if not user.raw.get("is_admin"):
|
||||
raise HTTPException(status_code=403, detail="admin privilege required")
|
||||
|
||||
client = _client_for(request, settings)
|
||||
try:
|
||||
return await client.swap_model(body.model_tag)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.error("swap_model_proxy_failed", status_code=exc.response.status_code)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="inference service rejected swap",
|
||||
) from exc
|
||||
except httpx.HTTPError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="inference service unreachable",
|
||||
) from exc
|
||||
0
services/chat-api/src/services/__init__.py
Normal file
0
services/chat-api/src/services/__init__.py
Normal file
209
services/chat-api/src/services/conversation_service.py
Normal file
209
services/chat-api/src/services/conversation_service.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
"""Business logic for conversations and messages, scoped to a single user.
|
||||
|
||||
Every query in this module filters by `user_id` — that scoping is the only
|
||||
thing preventing one user from reading or mutating another user's data.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import date
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..config import get_settings
|
||||
from ..models import Conversation, Message
|
||||
|
||||
|
||||
async def create_conversation(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    title: str | None = None,
    model_tag: str | None = None,
) -> Conversation:
    """Insert a new conversation row for `user_id` and return it refreshed."""
    resolved_title = title or "New conversation"
    resolved_tag = model_tag or "default"
    convo = Conversation(
        user_id=user_id,
        title=resolved_title,
        model_tag=resolved_tag,
    )
    session.add(convo)
    await session.commit()
    # Re-read server-generated columns (id, timestamps) onto the instance.
    await session.refresh(convo)
    return convo
|
||||
|
||||
|
||||
async def list_conversations(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    limit: int = 50,
    offset: int = 0,
) -> list[Conversation]:
    """Return the user's conversations, most recently updated first (paged)."""
    stmt = (
        sa.select(Conversation)
        # user_id scoping is the tenancy boundary — see module docstring.
        .where(Conversation.user_id == user_id)
        .order_by(Conversation.updated_at.desc())
        .limit(limit)
        .offset(offset)
    )
    result = await session.execute(stmt)
    return list(result.scalars().all())
|
||||
|
||||
|
||||
def group_by_date(conversations: Iterable[Conversation]) -> dict[str, list[dict]]:
    """Group conversations by YYYY-MM-DD (UTC) of `updated_at`.

    Conversations whose `updated_at` is still None (not yet flushed) fall into
    "today's" bucket. Bug fix: that fallback previously used `date.today()`,
    which is local-timezone and can disagree with the UTC-derived buckets of
    every other conversation around midnight; it now derives today in UTC.
    """
    from datetime import datetime, timezone  # local: only needed for the fallback

    buckets: dict[str, list[dict]] = defaultdict(list)
    for convo in conversations:
        bucket_key: str
        if convo.updated_at is None:
            bucket_key = datetime.now(timezone.utc).date().isoformat()
        else:
            bucket_key = convo.updated_at.date().isoformat()
        buckets[bucket_key].append(convo.to_dict())
    # Plain dict so the result JSON-serializes without defaultdict surprises.
    return dict(buckets)
|
||||
|
||||
|
||||
async def get_user_conversation(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    conversation_id: uuid.UUID,
) -> Conversation | None:
    """Fetch one conversation, or None if it doesn't exist *or* isn't the user's.

    Callers cannot distinguish "missing" from "someone else's" — that is
    deliberate, so ownership probing leaks nothing.
    """
    stmt = sa.select(Conversation).where(
        Conversation.id == conversation_id,
        Conversation.user_id == user_id,
    )
    result = await session.execute(stmt)
    return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_conversation_messages(
    session: AsyncSession,
    *,
    conversation_id: uuid.UUID,
    limit: int | None = None,
) -> list[Message]:
    """Return a conversation's messages in chronological order.

    NOTE: this does not filter by user — callers must verify ownership of the
    conversation first (see `get_user_conversation`).
    """
    stmt = (
        sa.select(Message)
        .where(Message.conversation_id == conversation_id)
        # Secondary order on id breaks ties between same-timestamp rows.
        .order_by(Message.created_at.asc(), Message.id.asc())
    )
    if limit is not None:
        stmt = stmt.limit(limit)
    result = await session.execute(stmt)
    return list(result.scalars().all())
|
||||
|
||||
|
||||
async def update_conversation_title(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    conversation_id: uuid.UUID,
    title: str,
) -> Conversation | None:
    """Rename a conversation the user owns; return None if not found/not theirs."""
    convo = await get_user_conversation(
        session, user_id=user_id, conversation_id=conversation_id
    )
    if convo is None:
        return None
    convo.title = title
    await session.commit()
    # Refresh so updated_at (and any other server-side columns) are current.
    await session.refresh(convo)
    return convo
|
||||
|
||||
|
||||
async def delete_conversation(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    conversation_id: uuid.UUID,
) -> bool:
    """Delete a conversation the user owns. Returns True if a row was removed."""
    convo = await get_user_conversation(
        session, user_id=user_id, conversation_id=conversation_id
    )
    if convo is None:
        return False
    await session.delete(convo)
    await session.commit()
    return True
|
||||
|
||||
|
||||
async def append_message(
    session: AsyncSession,
    *,
    conversation_id: uuid.UUID,
    role: str,
    content: str,
    token_count: int | None = None,
    model_tag: str | None = None,
    inference_time_ms: int | None = None,
) -> Message:
    """Append one message to a conversation and return the persisted row.

    NOTE: no ownership check here — callers must have already verified the
    conversation belongs to the acting user.
    """
    message = Message(
        conversation_id=conversation_id,
        role=role,
        content=content,
        token_count=token_count,
        model_tag=model_tag,
        inference_time_ms=inference_time_ms,
    )
    session.add(message)
    await session.commit()
    await session.refresh(message)
    return message
|
||||
|
||||
|
||||
async def build_history_for_inference(
    session: AsyncSession,
    *,
    conversation_id: uuid.UUID,
) -> list[dict[str, str]]:
    """Return the trailing slice of history that fits the configured budgets.

    Two limits from settings apply: at most `max_conversation_history`
    messages, and (approximately) at most `max_token_budget` of content.
    Output is chronological `{"role", "content"}` dicts for the model.
    """
    settings = get_settings()
    max_history = settings.max_conversation_history
    max_budget = settings.max_token_budget

    # Fetch the newest window first so the budget walk trims the oldest end.
    stmt = (
        sa.select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.desc(), Message.id.desc())
        .limit(max_history)
    )
    result = await session.execute(stmt)
    newest_first = list(result.scalars().all())

    selected: list[Message] = []
    budget = 0
    for message in newest_first:
        # "Budget" is measured in characters, not model tokens — a cheap proxy.
        budget += len(message.content or "")
        if budget > max_budget and selected:
            # `and selected` guarantees at least the newest message survives
            # even if it alone blows the budget.
            break
        selected.append(message)

    selected.reverse()  # back to chronological order for the model
    return [{"role": m.role, "content": m.content} for m in selected]
|
||||
|
||||
|
||||
async def delete_last_assistant_message(
    session: AsyncSession,
    *,
    conversation_id: uuid.UUID,
) -> bool:
    """Remove the newest assistant message (used by regenerate).

    Returns True if a message was deleted, False if none existed. No ownership
    check — callers verify the conversation belongs to the user first.
    """
    stmt = (
        sa.select(Message)
        .where(
            Message.conversation_id == conversation_id,
            Message.role == "assistant",
        )
        .order_by(Message.created_at.desc(), Message.id.desc())
        .limit(1)
    )
    result = await session.execute(stmt)
    last = result.scalar_one_or_none()
    if last is None:
        return False
    await session.delete(last)
    await session.commit()
    return True
|
||||
93
services/chat-api/src/services/inference_client.py
Normal file
93
services/chat-api/src/services/inference_client.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""HTTP client wrapper for the inference service."""
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator
|
||||
|
||||
import httpx
|
||||
|
||||
from ..config import Settings, get_settings
|
||||
|
||||
|
||||
class InferenceClient:
    """Thin async wrapper around the inference service HTTP contract.

    May either own its httpx client (lazily created, closed by `aclose`) or
    borrow a shared one injected by the app (never closed here).
    """

    def __init__(
        self,
        settings: Settings | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        self._settings = settings or get_settings()
        self._client = http_client
        # Only close clients we created ourselves; injected ones are shared.
        self._owns_client = http_client is None

    @property
    def base_url(self) -> str:
        # Strip a trailing slash so f"{base_url}/path" never doubles it.
        return self._settings.inference_service_url.rstrip("/")

    @property
    def headers(self) -> dict[str, str]:
        # Service-to-service auth: the inference service validates this key.
        return {"X-Internal-API-Key": self._settings.internal_api_key}

    def _get_client(self) -> httpx.AsyncClient:
        """Return the HTTP client, lazily creating an owned one if needed."""
        if self._client is None:
            # Generous read timeout (streaming generations are slow); fast connect.
            self._client = httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=5.0))
        return self._client

    async def aclose(self) -> None:
        """Close the underlying client iff this instance created it."""
        if self._client is not None and self._owns_client:
            await self._client.aclose()
            self._client = None

    async def list_models(self) -> dict:
        """GET /models — raises httpx.HTTPStatusError on non-2xx."""
        client = self._get_client()
        resp = await client.get(f"{self.base_url}/models", headers=self.headers)
        resp.raise_for_status()
        return resp.json()

    async def swap_model(self, model_tag: str) -> dict:
        """POST /models/swap — raises httpx.HTTPStatusError on non-2xx."""
        client = self._get_client()
        resp = await client.post(
            f"{self.base_url}/models/swap",
            headers=self.headers,
            json={"model_tag": model_tag},
        )
        resp.raise_for_status()
        return resp.json()

    @asynccontextmanager
    async def stream_generate(
        self,
        *,
        messages: list[dict[str, str]],
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_k: int | None = None,
    ) -> AsyncIterator[httpx.Response]:
        """POST /generate and yield the live streaming response.

        None parameters fall back to the configured defaults. The response is
        only valid inside the `async with` block; no status check is done here
        (callers inspect `response.status_code`).
        """
        temperature = (
            temperature
            if temperature is not None
            else self._settings.inference_default_temperature
        )
        max_tokens = (
            max_tokens
            if max_tokens is not None
            else self._settings.inference_default_max_tokens
        )
        top_k = top_k if top_k is not None else self._settings.inference_default_top_k

        payload = {
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_k": top_k,
        }

        client = self._get_client()
        async with client.stream(
            "POST",
            f"{self.base_url}/generate",
            headers=self.headers,
            json=payload,
        ) as response:
            yield response
|
||||
114
services/chat-api/src/services/stream_proxy.py
Normal file
114
services/chat-api/src/services/stream_proxy.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
"""Proxy the inference SSE stream to the client while accumulating tokens.
|
||||
|
||||
The inference service emits lines like `data: {"token": "...", "gpu": 0}` and
|
||||
terminates with `data: {"done": true}`. We forward each event unchanged to the
|
||||
client, collect assistant tokens into a buffer, and report the buffer plus
|
||||
timing info once the stream closes.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncIterator, Callable
|
||||
|
||||
import httpx
|
||||
|
||||
from ..logging_setup import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class StreamResult:
    """Summary of one proxied stream, handed to the on_complete callback."""

    # Full assistant reply: all streamed tokens joined in order.
    content: str
    # Number of token events received (not model-tokenizer token count).
    token_count: int
    # Wall-clock duration of the stream in milliseconds.
    inference_time_ms: int
    # True only if the upstream sent its `{"done": true}` terminator.
    completed: bool
||||
|
||||
|
||||
def _parse_sse_data(raw_line: str) -> dict | None:
|
||||
line = raw_line.strip()
|
||||
if not line or not line.startswith("data:"):
|
||||
return None
|
||||
body = line[len("data:"):].strip()
|
||||
if not body:
|
||||
return None
|
||||
try:
|
||||
return json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("stream_proxy_bad_sse_payload", payload=body[:200])
|
||||
return None
|
||||
|
||||
|
||||
async def proxy_inference_stream(
    response: httpx.Response,
    *,
    on_complete: Callable[[StreamResult], None] | None = None,
) -> AsyncIterator[dict]:
    """Async generator that yields SSE events as dicts with ``data`` keys.

    Each yielded event is shaped as ``{"data": "<json-string>"}`` so callers can
    pass it straight into sse-starlette's ``EventSourceResponse``. Tokens are
    accumulated and reported to ``on_complete`` once the stream ends — even if
    the client disconnects mid-stream (the ``finally`` block still runs, with
    ``completed=False``).
    """
    started = time.perf_counter()
    buffer: list[str] = []
    token_count = 0
    completed = False

    if response.status_code >= 400:
        # Upstream failed before streaming: emit a synthetic error event plus a
        # done event so the client-side reader terminates cleanly.
        body = await response.aread()
        logger.error(
            "inference_error_response",
            status_code=response.status_code,
            body=body.decode("utf-8", errors="replace")[:200],
        )
        error_payload = json.dumps(
            {
                "error": "inference service returned an error",
                "status_code": response.status_code,
            }
        )
        yield {"data": error_payload}
        yield {"data": json.dumps({"done": True})}
        if on_complete is not None:
            # Report an empty, not-completed result so nothing gets persisted
            # as a real assistant reply.
            on_complete(
                StreamResult(
                    content="",
                    token_count=0,
                    inference_time_ms=int((time.perf_counter() - started) * 1000),
                    completed=False,
                )
            )
        return

    try:
        async for raw_line in response.aiter_lines():
            if not raw_line:
                continue
            parsed = _parse_sse_data(raw_line)
            if parsed is None:
                continue

            if parsed.get("done"):
                # Forward the terminator, then stop reading.
                completed = True
                yield {"data": json.dumps(parsed)}
                break

            token = parsed.get("token")
            if isinstance(token, str) and token:
                buffer.append(token)
                token_count += 1

            # Events are re-serialized (not forwarded raw) so output is always
            # well-formed JSON regardless of upstream whitespace.
            yield {"data": json.dumps(parsed)}
    finally:
        # Runs on normal completion AND on generator close (client disconnect),
        # so partial replies are still reported with completed=False.
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        if on_complete is not None:
            on_complete(
                StreamResult(
                    content="".join(buffer),
                    token_count=token_count,
                    inference_time_ms=elapsed_ms,
                    completed=completed,
                )
            )
|
||||
0
services/chat-api/src/tests/__init__.py
Normal file
0
services/chat-api/src/tests/__init__.py
Normal file
160
services/chat-api/src/tests/conftest.py
Normal file
160
services/chat-api/src/tests/conftest.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
"""Shared pytest fixtures for the chat API test suite.
|
||||
|
||||
Tests run against a throwaway SQLite database (via aiosqlite) with a stub
|
||||
`users` table so that the foreign key from `conversations.user_id` validates.
|
||||
The auth service is mocked with respx; the inference service SSE responses
|
||||
are mocked with hand-rolled httpx MockTransports.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import AsyncIterator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import sqlalchemy as sa
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from src import database as db_module
|
||||
from src.database import Base
|
||||
from src.middleware import auth_guard
|
||||
from src.models import Conversation, Message # noqa: F401 (registers tables)
|
||||
|
||||
|
||||
# Attach a stub `users` table to the shared metadata so the FK on
# `conversations.user_id` can resolve during SQLite-based test setup.
if "users" not in Base.metadata.tables:
    sa.Table(
        "users",
        Base.metadata,
        sa.Column("id", sa.Uuid(as_uuid=True), primary_key=True),
        sa.Column("email", sa.String(255)),
        sa.Column("name", sa.String(255)),
        # Integer 0/1 stands in for the real boolean column; SQLite doesn't care.
        sa.Column("is_admin", sa.Integer(), default=0),
    )
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_caches(monkeypatch):
    """Reset module-level singletons before and after every test."""
    # Settings and the auth cache are module-level singletons; wipe them
    # between tests so environment overrides take effect cleanly.
    from src.config import get_settings

    get_settings.cache_clear()
    auth_guard.reset_auth_cache()
    _TOKEN_REGISTRY.clear()
    yield
    # Clear again on teardown so state cannot leak across test modules.
    get_settings.cache_clear()
    auth_guard.reset_auth_cache()
    _TOKEN_REGISTRY.clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _env(monkeypatch):
    """Point every test at throwaway backends via environment variables."""
    monkeypatch.setenv("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
    monkeypatch.setenv("AUTH_SERVICE_URL", "http://auth.test")
    monkeypatch.setenv("INFERENCE_SERVICE_URL", "http://inference.test")
    monkeypatch.setenv("INTERNAL_API_KEY", "test-internal-key")
    monkeypatch.setenv("LOG_LEVEL", "WARNING")
    from src.config import get_settings

    # Drop any settings cached before the env vars above were set.
    get_settings.cache_clear()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def engine(_env) -> AsyncIterator:
    """Fresh in-memory SQLite engine with all tables created."""
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def session_factory(engine) -> async_sessionmaker[AsyncSession]:
    """Session factory bound to the test engine, installed app-wide."""
    # expire_on_commit=False lets tests read ORM attributes after commit
    # without triggering an implicit refresh.
    factory = async_sessionmaker(engine, expire_on_commit=False)
    # Route the application's get_session_factory() at the test database.
    db_module.override_session_factory(factory)
    return factory
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def seeded_user(session_factory) -> dict:
    """Insert a non-admin user row ("Alice") and return its public fields."""
    user_id = str(uuid.uuid4())
    async with session_factory() as session:
        # Raw SQL: the chat service has no User ORM model, only the stub table.
        await session.execute(
            sa.text(
                "INSERT INTO users (id, email, name, is_admin) "
                "VALUES (:id, :email, :name, :is_admin)"
            ),
            {
                "id": user_id,
                "email": "alice@example.com",
                "name": "Alice",
                "is_admin": 0,
            },
        )
        await session.commit()
    return {"id": user_id, "email": "alice@example.com", "name": "Alice"}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def other_user(session_factory) -> dict:
    """Insert a second user row ("Bob") for cross-user isolation tests."""
    user_id = str(uuid.uuid4())
    async with session_factory() as session:
        await session.execute(
            sa.text(
                "INSERT INTO users (id, email, name, is_admin) "
                "VALUES (:id, :email, :name, :is_admin)"
            ),
            {"id": user_id, "email": "bob@example.com", "name": "Bob", "is_admin": 0},
        )
        await session.commit()
    return {"id": user_id, "email": "bob@example.com", "name": "Bob"}
|
||||
|
||||
|
||||
@pytest.fixture
def app():
    """Fresh FastAPI application instance per test."""
    # Imported lazily so the _env fixture's variables are set first.
    from src.main import create_app

    return create_app()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client(app, session_factory) -> AsyncIterator[AsyncClient]:
    """In-process HTTP client talking to the app over ASGI (no real sockets)."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
|
||||
|
||||
|
||||
_TOKEN_REGISTRY: dict[str, dict] = {}
|
||||
|
||||
|
||||
def _dispatch_auth_validate(request):
    """respx side-effect emulating the auth service's POST /auth/validate."""
    import json

    import httpx

    payload = json.loads(request.read())
    token = payload.get("token")
    # The real auth service gates on the shared internal key first.
    if request.headers.get("X-Internal-API-Key") != "test-internal-key":
        return httpx.Response(403, json={"detail": "invalid internal api key"})
    user = _TOKEN_REGISTRY.get(token)
    if user is None:
        return httpx.Response(401, json={"valid": False, "reason": "invalid token"})
    return httpx.Response(
        200,
        json={"valid": True, "user": user, "claims": {"sub": user["id"]}},
    )
|
||||
|
||||
|
||||
def stub_auth_validate(respx_mock, user: dict, token: str = "valid-token"):
    """Register a respx mock that returns the given user for the given token.

    Multiple calls within a single test accumulate token→user mappings so
    several users can authenticate in the same scenario.
    """
    _TOKEN_REGISTRY[token] = user
    # Re-registering the same route is harmless; the shared dispatcher reads
    # the registry on every call.
    respx_mock.post("http://auth.test/auth/validate").mock(
        side_effect=_dispatch_auth_validate
    )
|
||||
110
services/chat-api/src/tests/test_conversations.py
Normal file
110
services/chat-api/src/tests/test_conversations.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
import pytest
|
||||
import respx
|
||||
|
||||
from .conftest import stub_auth_validate
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_requires_authorization_header(client):
    """Requests without an Authorization header are rejected with 401."""
    response = await client.get("/api/conversations")
    assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_create_and_list_conversation(client, seeded_user):
    """Creating a conversation makes it appear in the owner's listing."""
    stub_auth_validate(respx.mock, seeded_user)
    headers = {"Authorization": "Bearer valid-token"}

    create = await client.post(
        "/api/conversations", json={"title": "my first chat"}, headers=headers
    )
    assert create.status_code == 201
    convo = create.json()
    assert convo["title"] == "my first chat"
    assert convo["user_id"] == seeded_user["id"]

    listed = await client.get("/api/conversations", headers=headers)
    assert listed.status_code == 200
    payload = listed.json()
    assert len(payload["items"]) == 1
    assert payload["items"][0]["id"] == convo["id"]
    assert payload["grouped"]  # at least one date bucket
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_update_conversation_title(client, seeded_user):
    """PUT /api/conversations/{id} renames a conversation the user owns."""
    stub_auth_validate(respx.mock, seeded_user)
    headers = {"Authorization": "Bearer valid-token"}

    create = await client.post("/api/conversations", json={}, headers=headers)
    convo_id = create.json()["id"]

    updated = await client.put(
        f"/api/conversations/{convo_id}",
        json={"title": "Renamed"},
        headers=headers,
    )
    assert updated.status_code == 200
    assert updated.json()["title"] == "Renamed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_delete_conversation(client, seeded_user):
    """Deleting a conversation returns 204 and makes subsequent GETs 404."""
    stub_auth_validate(respx.mock, seeded_user)
    headers = {"Authorization": "Bearer valid-token"}

    create = await client.post("/api/conversations", json={}, headers=headers)
    convo_id = create.json()["id"]

    deleted = await client.delete(f"/api/conversations/{convo_id}", headers=headers)
    assert deleted.status_code == 204

    missing = await client.get(f"/api/conversations/{convo_id}", headers=headers)
    assert missing.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_user_scoping_prevents_cross_user_access(
    client, seeded_user, other_user
):
    """One user's conversations are invisible (404) to every other user."""
    stub_auth_validate(respx.mock, seeded_user, token="alice-token")
    stub_auth_validate(respx.mock, other_user, token="bob-token")

    alice_headers = {"Authorization": "Bearer alice-token"}
    bob_headers = {"Authorization": "Bearer bob-token"}

    alice_convo = await client.post(
        "/api/conversations",
        json={"title": "alice only"},
        headers=alice_headers,
    )
    convo_id = alice_convo.json()["id"]

    # 404 (not 403) so existence of the conversation is not leaked.
    bob_view = await client.get(
        f"/api/conversations/{convo_id}", headers=bob_headers
    )
    assert bob_view.status_code == 404

    bob_delete = await client.delete(
        f"/api/conversations/{convo_id}", headers=bob_headers
    )
    assert bob_delete.status_code == 404

    bob_list = await client.get("/api/conversations", headers=bob_headers)
    assert bob_list.json()["items"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_invalid_token_is_rejected(client, seeded_user):
    """A bearer token the auth service doesn't recognize yields 401."""
    stub_auth_validate(respx.mock, seeded_user, token="valid-token")
    response = await client.get(
        "/api/conversations",
        headers={"Authorization": "Bearer wrong-token"},
    )
    assert response.status_code == 401
|
||||
11
services/chat-api/src/tests/test_health.py
Normal file
11
services/chat-api/src/tests/test_health.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_is_unauthenticated(client):
    """The liveness probe must work without any Authorization header."""
    response = await client.get("/api/health")
    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "ok"
    assert body["ready"] is True
    assert body["service"] == "chat-api"
|
||||
166
services/chat-api/src/tests/test_messages.py
Normal file
166
services/chat-api/src/tests/test_messages.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
"""Tests for the streaming message send + regenerate flows."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from .conftest import stub_auth_validate
|
||||
|
||||
|
||||
def _build_inference_mock(tokens: list[str]) -> httpx.MockTransport:
    """Build an httpx mock transport that streams an SSE response.

    Emits one `data: {"token": ..., "gpu": 0}` event per entry in *tokens*,
    then the `{"done": true}` terminator — mirroring the inference contract.
    """
    sse_lines: list[bytes] = []
    for token in tokens:
        sse_lines.append(
            f"data: {json.dumps({'token': token, 'gpu': 0})}\n\n".encode("utf-8")
        )
    sse_lines.append(f"data: {json.dumps({'done': True})}\n\n".encode("utf-8"))

    async def handler(request: httpx.Request) -> httpx.Response:
        # Only the streaming endpoint is mocked; anything else is a 404.
        if request.url.path != "/generate":
            return httpx.Response(404)

        async def body():
            for chunk in sse_lines:
                yield chunk

        return httpx.Response(
            200,
            headers={"content-type": "text/event-stream"},
            content=body(),
        )

    return httpx.MockTransport(handler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_send_message_streams_and_persists(app, client, seeded_user):
    """Streamed tokens reach the client AND the joined reply is persisted."""
    stub_auth_validate(respx.mock, seeded_user)
    headers = {"Authorization": "Bearer valid-token"}

    create = await client.post("/api/conversations", json={}, headers=headers)
    convo_id = create.json()["id"]

    # Inject a mocked inference client on app state, where _client_for looks.
    app.state.inference_http_client = httpx.AsyncClient(
        transport=_build_inference_mock(["hel", "lo", " world"])
    )
    try:
        resp = await client.post(
            f"/api/conversations/{convo_id}/messages",
            json={"content": "hi there"},
            headers=headers,
        )
        assert resp.status_code == 200
        body = resp.text
        # json.dumps spacing is an implementation detail; accept both forms.
        assert '"token": "hel"' in body or '"token":"hel"' in body
        assert '"done": true' in body or '"done":true' in body
    finally:
        await app.state.inference_http_client.aclose()

    fetched = await client.get(
        f"/api/conversations/{convo_id}", headers=headers
    )
    assert fetched.status_code == 200
    payload = fetched.json()
    messages = payload["messages"]

    roles = [m["role"] for m in messages]
    assert roles == ["user", "assistant"]
    assert messages[0]["content"] == "hi there"
    assert messages[1]["content"] == "hello world"
    assert messages[1]["token_count"] == 3
    assert messages[1]["inference_time_ms"] >= 0

    # First message should have auto-populated the title
    assert payload["title"] == "hi there"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_send_message_rejected_on_foreign_conversation(
    app, client, seeded_user, other_user
):
    """Posting into another user's conversation fails with 404 before streaming."""
    stub_auth_validate(respx.mock, seeded_user, token="alice-token")
    stub_auth_validate(respx.mock, other_user, token="bob-token")

    alice_headers = {"Authorization": "Bearer alice-token"}
    bob_headers = {"Authorization": "Bearer bob-token"}

    create = await client.post("/api/conversations", json={}, headers=alice_headers)
    convo_id = create.json()["id"]

    app.state.inference_http_client = httpx.AsyncClient(
        transport=_build_inference_mock(["x"])
    )
    try:
        resp = await client.post(
            f"/api/conversations/{convo_id}/messages",
            json={"content": "steal me"},
            headers=bob_headers,
        )
    finally:
        await app.state.inference_http_client.aclose()
    assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_send_message_returns_404_for_missing_conversation(
    app, client, seeded_user
):
    """Posting to a conversation id that does not exist yields a 404."""
    stub_auth_validate(respx.mock, seeded_user)
    auth_headers = {"Authorization": "Bearer valid-token"}

    unknown_id = str(uuid.uuid4())
    resp = await client.post(
        f"/api/conversations/{unknown_id}/messages",
        json={"content": "hello"},
        headers=auth_headers,
    )

    assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_regenerate_drops_last_assistant_message(app, client, seeded_user):
    """Regenerating replaces the previous assistant reply instead of appending one.

    Sends a first message (mock inference answers "first"), then hits the
    regenerate endpoint with a second mock ("second reply") and verifies only
    the regenerated assistant message survives.
    """
    stub_auth_validate(respx.mock, seeded_user)
    headers = {"Authorization": "Bearer valid-token"}

    create = await client.post("/api/conversations", json={}, headers=headers)
    convo_id = create.json()["id"]

    app.state.inference_http_client = httpx.AsyncClient(
        transport=_build_inference_mock(["first"])
    )
    try:
        first = await client.post(
            f"/api/conversations/{convo_id}/messages",
            json={"content": "hi"},
            headers=headers,
        )
        assert first.status_code == 200

        # Fix: close the first mock client before swapping in a new one.
        # Previously the first AsyncClient was simply overwritten, so the
        # `finally` below only ever closed the second client and the first
        # leaked its connection resources.
        await app.state.inference_http_client.aclose()
        app.state.inference_http_client = httpx.AsyncClient(
            transport=_build_inference_mock(["second", " reply"])
        )
        regen = await client.post(
            f"/api/conversations/{convo_id}/regenerate",
            json={},
            headers=headers,
        )
        assert regen.status_code == 200
    finally:
        await app.state.inference_http_client.aclose()

    fetched = await client.get(
        f"/api/conversations/{convo_id}", headers=headers
    )
    messages = fetched.json()["messages"]
    assistant_messages = [m for m in messages if m["role"] == "assistant"]
    # Only the regenerated reply should remain, carrying the new content.
    assert len(assistant_messages) == 1
    assert assistant_messages[0]["content"] == "second reply"
|
||||
95
services/chat-api/src/tests/test_models_proxy.py
Normal file
95
services/chat-api/src/tests/test_models_proxy.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""Tests for /api/models proxy routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
import sqlalchemy as sa
|
||||
|
||||
from .conftest import stub_auth_validate
|
||||
|
||||
|
||||
def _inference_mock(models_response: dict) -> httpx.MockTransport:
    """Mock transport serving the inference service's model-management routes.

    `/models` returns *models_response*; `/models/swap` acknowledges a swap to
    "new-model"; anything else is a 404.
    """

    async def handler(request: httpx.Request) -> httpx.Response:
        path = request.url.path
        if path == "/models":
            return httpx.Response(200, json=models_response)
        if path == "/models/swap":
            return httpx.Response(
                200, json={"status": "ok", "current_model": "new-model"}
            )
        return httpx.Response(404)

    return httpx.MockTransport(handler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_list_models_proxies_to_inference(app, client, seeded_user):
    """GET /api/models forwards the inference service's listing verbatim."""
    stub_auth_validate(respx.mock, seeded_user)
    headers = {"Authorization": "Bearer valid-token"}

    upstream_payload = {"current_model": "m1", "models": ["m1", "m2"]}
    app.state.inference_http_client = httpx.AsyncClient(
        transport=_inference_mock(upstream_payload)
    )
    try:
        resp = await client.get("/api/models", headers=headers)
        assert resp.status_code == 200
        assert resp.json() == upstream_payload
    finally:
        await app.state.inference_http_client.aclose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_swap_model_requires_admin(app, client, seeded_user):
    """Non-admin users are forbidden (403) from swapping the active model."""
    stub_auth_validate(respx.mock, seeded_user)
    headers = {"Authorization": "Bearer valid-token"}

    app.state.inference_http_client = httpx.AsyncClient(transport=_inference_mock({}))
    try:
        resp = await client.post(
            "/api/models/swap",
            json={"model_tag": "new-model"},
            headers=headers,
        )
    finally:
        await app.state.inference_http_client.aclose()

    # Rejected by the chat API before the proxy reaches the inference service.
    assert resp.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@respx.mock
async def test_swap_model_succeeds_for_admin(app, client, session_factory):
    """An admin may POST /api/models/swap and receives the new active model."""
    new_admin_id = str(uuid.uuid4())
    row = {
        "id": new_admin_id,
        "email": "root@example.com",
        "name": "Root",
        "is_admin": 1,
    }
    # Seed an admin row directly; the regular fixtures only create normal users.
    async with session_factory() as session:
        insert_stmt = sa.text(
            "INSERT INTO users (id, email, name, is_admin) "
            "VALUES (:id, :email, :name, :is_admin)"
        )
        await session.execute(insert_stmt, row)
        await session.commit()

    stub_auth_validate(
        respx.mock,
        {
            "id": new_admin_id,
            "email": "root@example.com",
            "name": "Root",
            "is_admin": True,
        },
    )
    headers = {"Authorization": "Bearer valid-token"}

    app.state.inference_http_client = httpx.AsyncClient(
        transport=_inference_mock({"current_model": "new-model", "models": ["new-model"]})
    )
    try:
        resp = await client.post(
            "/api/models/swap",
            json={"model_tag": "new-model"},
            headers=headers,
        )
        assert resp.status_code == 200
        assert resp.json()["current_model"] == "new-model"
    finally:
        await app.state.inference_http_client.aclose()
|
||||
73
services/chat-api/src/tests/test_stream_proxy.py
Normal file
73
services/chat-api/src/tests/test_stream_proxy.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
"""Unit tests for the inference SSE stream proxy."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from src.services.stream_proxy import StreamResult, proxy_inference_stream
|
||||
|
||||
|
||||
def _make_response(lines: list[str]) -> httpx.Response:
    """Build a 200 SSE response whose body is *lines*, each newline-terminated."""

    async def _body():
        for chunk in lines:
            yield (chunk + "\n").encode("utf-8")

    return httpx.Response(
        200,
        headers={"content-type": "text/event-stream"},
        content=_body(),
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_proxy_accumulates_tokens_and_signals_done():
    """Tokens are forwarded in order and the completion callback sees the full text."""
    payloads = (
        {"token": "hel", "gpu": 0},
        {"token": "lo", "gpu": 0},
        {"done": True},
    )
    sse_lines: list[str] = []
    for payload in payloads:
        sse_lines.append("data: " + json.dumps(payload))
        sse_lines.append("")  # blank line terminates each SSE event
    resp = _make_response(sse_lines)

    captured: dict[str, StreamResult] = {}

    def on_complete(result: StreamResult) -> None:
        captured["result"] = result

    forwarded = [
        event async for event in proxy_inference_stream(resp, on_complete=on_complete)
    ]

    # Every upstream event is passed through unchanged, in order.
    assert [json.loads(e["data"]) for e in forwarded] == [
        {"token": "hel", "gpu": 0},
        {"token": "lo", "gpu": 0},
        {"done": True},
    ]

    result = captured["result"]
    assert result.content == "hello"
    assert result.token_count == 2
    assert result.completed is True
    assert result.inference_time_ms >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_proxy_surfaces_error_status_codes():
    """A non-2xx upstream response yields an error event plus a done event."""
    resp = httpx.Response(502, content=b"upstream down")
    captured: dict[str, StreamResult] = {}

    def on_complete(result: StreamResult) -> None:
        captured["result"] = result

    seen = []
    async for event in proxy_inference_stream(resp, on_complete=on_complete):
        seen.append(event)

    data_payloads = [e["data"] for e in seen]
    assert any("error" in d for d in data_payloads)
    assert any('"done": true' in d or '"done":true' in d for d in data_payloads)
    # The accumulated result must be flagged as incomplete on upstream failure.
    assert captured["result"].completed is False
|
||||
Loading…
Reference in New Issue
Block a user