Merge pull request #17 from manmohan659/feat/chat-api-service

feat(chat-api): conversation orchestration + SSE streaming proxy (#6)
This commit is contained in:
Manmohan 2026-04-16 14:57:10 -04:00 committed by GitHub
commit 1e2fc09ca6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 2005 additions and 6 deletions

View File

@ -1,9 +1,34 @@
FROM python:3.12-slim

# Lean runtime: no .pyc files, unbuffered logs; uv installs into the system
# interpreter and copies (rather than hard-links) files into site-packages.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    UV_SYSTEM_PYTHON=1 \
    UV_LINK_MODE=copy

RUN apt-get update \
    && apt-get install -y --no-install-recommends build-essential curl \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir uv==0.4.30

WORKDIR /app

# pyproject.toml declares readme = "README.md", so both files must be present.
# (A second, redundant COPY of README.md alone was removed.)
COPY pyproject.toml README.md /app/

RUN uv pip install --system --no-cache \
    "fastapi>=0.117.1" \
    "uvicorn[standard]>=0.36.0" \
    "pydantic>=2.8.0" \
    "pydantic-settings>=2.4.0" \
    "sqlalchemy[asyncio]>=2.0.36" \
    "asyncpg>=0.29.0" \
    "httpx>=0.27.0" \
    "sse-starlette>=2.1.3" \
    "structlog>=24.4.0" \
    "cachetools>=5.5.0"

COPY src /app/src

EXPOSE 8002

# Only the last CMD in a Dockerfile takes effect; the stale scaffold
# placeholder (`python -m http.server`) was dead and has been removed.
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8002"]

View File

@ -1,7 +1,54 @@
# Chat API Service
Orchestration layer for samosaChaat conversations (implements Issue #6,
replacing the earlier scaffold placeholder). It manages conversation state
in PostgreSQL, authenticates every request via the auth service, and proxies
streaming inference requests via Server-Sent Events.
Historically, this container was intentionally minimal on the monorepo
scaffold branch: it kept the compose topology stable while the dedicated
chat API implementation was developed.
## Endpoints
| Method | Path | Description |
| --- | --- | --- |
| GET | `/api/health` | Liveness probe (unauthenticated) |
| GET | `/api/conversations` | List the authenticated user's conversations, grouped by date |
| POST | `/api/conversations` | Create a new conversation |
| GET | `/api/conversations/{id}` | Fetch a conversation + full message history |
| PUT | `/api/conversations/{id}` | Update the conversation title |
| DELETE | `/api/conversations/{id}` | Delete a conversation (cascade deletes messages) |
| POST | `/api/conversations/{id}/messages` | Append a user message and stream the assistant response |
| POST | `/api/conversations/{id}/regenerate` | Delete the last assistant message and regenerate it |
| GET | `/api/models` | Proxy to inference `GET /models` |
| POST | `/api/models/swap` | Proxy to inference `POST /models/swap` (admin only) |
All authenticated endpoints expect `Authorization: Bearer <jwt>`. The chat API
validates the token by calling the auth service `POST /auth/validate` with the
shared `X-Internal-API-Key` header and caches the result for 5 minutes.
## Environment
| Variable | Default | Purpose |
| --- | --- | --- |
| `DATABASE_URL` | `postgresql+asyncpg://localhost/samosachaat` | PostgreSQL connection string |
| `AUTH_SERVICE_URL` | `http://auth:8001` | Base URL of the auth service |
| `INFERENCE_SERVICE_URL` | `http://inference:8000` | Base URL of the inference service |
| `INTERNAL_API_KEY` | — | Shared key for internal service auth |
| `MAX_CONVERSATION_HISTORY` | `50` | Max messages included in each inference call |
| `MAX_TOKEN_BUDGET` | `6000` | Approximate character budget for the history slice (a cheap proxy for tokens, applied on top of `MAX_CONVERSATION_HISTORY`) |
| `FRONTEND_URL` | `http://localhost:3000` | Origin allowed by CORS |
| `LOG_LEVEL` | `INFO` | Python log level |
## Running locally
```
uv pip install -e ".[dev]"
uvicorn src.main:app --reload --port 8002
```
## Running tests
```
cd services/chat-api
pytest
```
Tests use SQLite + aiosqlite for a throwaway database, respx to mock the auth
service, and hand-crafted httpx mocks for the inference SSE stream.

View File

@ -0,0 +1,38 @@
# Packaging metadata (PEP 621) for the chat API service.
[project]
name = "samosachaat-chat-api"
version = "0.1.0"
description = "samosaChaat chat API orchestration service"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "fastapi>=0.117.1",
    "uvicorn[standard]>=0.36.0",
    "pydantic>=2.8.0",
    "pydantic-settings>=2.4.0",
    "sqlalchemy[asyncio]>=2.0.36",
    "asyncpg>=0.29.0",
    "httpx>=0.27.0",
    "sse-starlette>=2.1.3",
    "structlog>=24.4.0",
    "cachetools>=5.5.0",
]

# Dev tooling lives in an extra so the README's documented install command,
# `uv pip install -e ".[dev]"`, actually pulls it in. (It previously lived
# under [dependency-groups], which the `.[dev]` extras syntax does not install.)
[project.optional-dependencies]
dev = [
    "pytest>=8.0.0",
    "pytest-asyncio>=0.24.0",
    "aiosqlite>=0.20.0",
    "respx>=0.21.1",
]

[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["src/tests"]
python_files = ["test_*.py"]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src"]

View File

View File

@ -0,0 +1,36 @@
"""Runtime configuration for the chat API service."""
from __future__ import annotations
from functools import lru_cache
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Environment-driven configuration (see the README "Environment" table)."""

    # Read overrides from a local .env file; ignore unknown keys so a shared
    # env file can also carry settings for other services.
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    # Async PostgreSQL DSN consumed by SQLAlchemy's asyncpg driver.
    database_url: str = Field(default="postgresql+asyncpg://localhost/samosachaat")
    # Base URLs of sibling services (docker-compose hostnames by default).
    auth_service_url: str = Field(default="http://auth:8001")
    inference_service_url: str = Field(default="http://inference:8000")
    # Shared secret sent as X-Internal-API-Key on service-to-service calls.
    internal_api_key: str = Field(default="")
    # History trimming: max messages per inference call, plus an approximate
    # character budget applied on top of that count.
    max_conversation_history: int = Field(default=50)
    max_token_budget: int = Field(default=6000)
    # TTL cache for validated JWTs: lifetime in seconds / max entries.
    auth_cache_ttl_seconds: int = Field(default=300)
    auth_cache_max_size: int = Field(default=1024)
    # Sampling defaults forwarded to the inference service.
    inference_default_temperature: float = Field(default=0.8)
    inference_default_max_tokens: int = Field(default=512)
    inference_default_top_k: int = Field(default=50)
    # Single origin allowed by the CORS middleware.
    frontend_url: str = Field(default="http://localhost:3000")
    log_level: str = Field(default="INFO")
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return the process-wide Settings instance (built once, then cached)."""
    return Settings()

View File

@ -0,0 +1,49 @@
"""Async SQLAlchemy engine and session factory."""
from __future__ import annotations
from collections.abc import AsyncIterator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from .config import get_settings
class Base(DeclarativeBase):
    """Shared declarative base for all chat-api ORM models."""


# Lazily-built module singletons; see get_engine() / get_session_factory().
_engine = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
def _build_engine() -> None:
    """Create the engine and session factory from settings (called lazily)."""
    global _engine, _session_factory
    settings = get_settings()
    # pool_pre_ping revalidates pooled connections before use so stale
    # connections (e.g. after a DB restart) don't surface as query errors.
    _engine = create_async_engine(settings.database_url, pool_pre_ping=True)
    # expire_on_commit=False keeps ORM objects usable after commit without
    # an implicit refresh round-trip.
    _session_factory = async_sessionmaker(_engine, expire_on_commit=False)
def get_engine():
    """Return the shared async engine, building it on first use."""
    if _engine is None:
        _build_engine()
    return _engine
def get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the shared session factory, building it on first use."""
    if _session_factory is None:
        _build_engine()
    # _build_engine always assigns the factory; narrow the Optional for typing.
    assert _session_factory is not None
    return _session_factory
async def get_session() -> AsyncIterator[AsyncSession]:
    """FastAPI dependency: open one AsyncSession for the request and yield it."""
    async with get_session_factory()() as db:
        yield db
def override_session_factory(factory: async_sessionmaker[AsyncSession]) -> None:
    """Testing hook: swap the session factory for an in-memory engine."""
    global _session_factory
    # Note: get_engine() is untouched; only session creation is redirected.
    _session_factory = factory

View File

@ -0,0 +1,74 @@
"""Structured JSON logging for the chat API service."""
from __future__ import annotations
import logging
import sys
import uuid
from contextvars import ContextVar
import structlog
from .config import get_settings
# Per-task request context, injected into every log line by _inject_context().
_trace_id_ctx: ContextVar[str | None] = ContextVar("trace_id", default=None)
_user_id_ctx: ContextVar[str | None] = ContextVar("user_id", default=None)


def set_trace_id(trace_id: str | None) -> None:
    """Bind the current request's trace id to this task's context."""
    _trace_id_ctx.set(trace_id)


def set_user_id(user_id: str | None) -> None:
    """Bind the authenticated user's id to this task's context."""
    _user_id_ctx.set(user_id)


def get_trace_id() -> str | None:
    """Return the trace id bound to the current context, if any."""
    return _trace_id_ctx.get()


def get_user_id() -> str | None:
    """Return the user id bound to the current context, if any."""
    return _user_id_ctx.get()


def new_trace_id() -> str:
    """Generate a fresh 32-character hex trace id."""
    return uuid.uuid4().hex
def _inject_context(_logger, _method, event_dict):
    """structlog processor: stamp the service name plus any request context.

    Uses setdefault throughout so values already bound on the event win.
    """
    event_dict.setdefault("service", "chat-api")
    for key, value in (
        ("trace_id", _trace_id_ctx.get()),
        ("user_id", _user_id_ctx.get()),
    ):
        if value is not None:
            event_dict.setdefault(key, value)
    return event_dict
def configure_logging() -> None:
    """Configure stdlib logging + structlog to emit JSON lines on stdout."""
    settings = get_settings()
    # Fall back to INFO when LOG_LEVEL is not a recognised level name.
    level = getattr(logging, settings.log_level.upper(), logging.INFO)
    logging.basicConfig(
        format="%(message)s",
        stream=sys.stdout,
        level=level,
        force=True,  # replace any handlers installed by an earlier import
    )
    structlog.configure(
        processors=[
            # Merge structlog-bound contextvars, add level + ISO UTC timestamp,
            # stamp our service/trace/user context, then render as JSON.
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.TimeStamper(fmt="iso", utc=True),
            _inject_context,
            structlog.processors.JSONRenderer(),
        ],
        wrapper_class=structlog.make_filtering_bound_logger(level),
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )
def get_logger(name: str | None = None):
    """Return a structlog logger, optionally bound to *name*."""
    return structlog.get_logger(name)

View File

@ -0,0 +1,84 @@
"""FastAPI entrypoint for the samosaChaat chat API service."""
from __future__ import annotations
from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from .config import get_settings
from .logging_setup import (
configure_logging,
get_logger,
new_trace_id,
set_trace_id,
set_user_id,
)
from .routes import conversations, messages, models
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the shared outbound HTTP clients for the app's lifetime.

    Auth calls are short and latency-sensitive (5s total / 2s connect);
    inference calls stream tokens and may run long (60s total / 5s connect).
    """
    app.state.auth_http_client = httpx.AsyncClient(
        timeout=httpx.Timeout(5.0, connect=2.0)
    )
    app.state.inference_http_client = httpx.AsyncClient(
        timeout=httpx.Timeout(60.0, connect=5.0)
    )
    try:
        yield
    finally:
        # Always release pooled connections, even on an error-driven shutdown.
        await app.state.auth_http_client.aclose()
        await app.state.inference_http_client.aclose()
def create_app() -> FastAPI:
    """Build the FastAPI app: logging, CORS, trace-id middleware, routers."""
    configure_logging()
    settings = get_settings()
    logger = get_logger(__name__)
    app = FastAPI(title="samosaChaat Chat API", version="0.1.0", lifespan=lifespan)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[settings.frontend_url],  # single origin from config
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.middleware("http")
    async def request_context(request: Request, call_next) -> Response:
        # Propagate a caller-supplied trace id (x-trace-id / x-request-id)
        # or mint a fresh one; echo it back on the response.
        incoming = request.headers.get("x-trace-id") or request.headers.get("x-request-id")
        trace_id = incoming or new_trace_id()
        set_trace_id(trace_id)
        # Clear any user id left over from a previous request on this task.
        set_user_id(None)
        logger.info(
            "request_start",
            method=request.method,
            path=request.url.path,
        )
        response = await call_next(request)
        response.headers["x-trace-id"] = trace_id
        logger.info(
            "request_end",
            method=request.method,
            path=request.url.path,
            status_code=response.status_code,
        )
        return response

    app.include_router(conversations.router)
    app.include_router(messages.router)
    app.include_router(models.router)

    @app.get("/api/health")
    async def health():
        # Unauthenticated liveness probe.
        return {"status": "ok", "ready": True, "service": "chat-api"}

    return app


# Module-level ASGI app for `uvicorn src.main:app`.
app = create_app()

View File

@ -0,0 +1,154 @@
"""Authentication guard that validates JWTs via the auth service.
Successful validations are cached in an in-memory TTL cache keyed by the raw
JWT string so that a burst of requests from the same user does not fan out to
the auth service on every call.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Annotated, Any
import httpx
from cachetools import TTLCache
from fastapi import Depends, Header, HTTPException, Request, status
from ..config import Settings, get_settings
from ..logging_setup import get_logger, set_user_id
logger = get_logger(__name__)
@dataclass
class AuthenticatedUser:
    """Identity returned by the auth service for a validated JWT."""

    id: str
    email: str
    name: str | None
    raw: dict[str, Any]

    @classmethod
    def from_validate_response(cls, payload: dict[str, Any]) -> "AuthenticatedUser":
        """Build a user from the `/auth/validate` response body.

        A missing/None "user" key is tolerated (treated as an empty dict);
        the user id is always coerced to a string.
        """
        user_obj = payload.get("user") or {}
        return cls(
            id=str(user_obj["id"]),
            email=user_obj.get("email", ""),
            name=user_obj.get("name"),
            raw=user_obj,
        )
class AuthCache:
    """TTL cache mapping raw JWT strings to validated users.

    All access is serialised through an asyncio.Lock, making the cache safe
    for concurrent coroutines on one event loop.
    """

    def __init__(self, ttl_seconds: int, max_size: int) -> None:
        self._store: TTLCache[str, AuthenticatedUser] = TTLCache(
            maxsize=max_size, ttl=ttl_seconds
        )
        self._guard = asyncio.Lock()

    async def get(self, token: str) -> AuthenticatedUser | None:
        """Return the cached user for *token*, or None on miss/expiry."""
        async with self._guard:
            return self._store.get(token)

    async def set(self, token: str, user: AuthenticatedUser) -> None:
        """Cache *user* under *token* until the TTL elapses."""
        async with self._guard:
            self._store[token] = user

    async def clear(self) -> None:
        """Drop every cached entry."""
        async with self._guard:
            self._store.clear()
# Process-wide cache singleton; built lazily so settings are loaded first.
_auth_cache: AuthCache | None = None


def get_auth_cache() -> AuthCache:
    """Return the shared AuthCache, creating it from settings on first use."""
    global _auth_cache
    if _auth_cache is None:
        settings = get_settings()
        _auth_cache = AuthCache(
            ttl_seconds=settings.auth_cache_ttl_seconds,
            max_size=settings.auth_cache_max_size,
        )
    return _auth_cache


def reset_auth_cache() -> None:
    """Testing hook: drop the cached singleton so a fresh one is built."""
    global _auth_cache
    _auth_cache = None
async def _validate_with_auth_service(
    token: str, settings: Settings, http_client: httpx.AsyncClient | None = None
) -> AuthenticatedUser:
    """POST the token to the auth service's /auth/validate and map the reply.

    Raises HTTPException: 503 when auth is unreachable, 401 for invalid
    tokens or unexpected auth-side errors, 500 when our internal key is
    rejected (a misconfiguration on this side, not the end user's fault).
    """
    # Reuse the caller's pooled client when given; otherwise create (and,
    # in the finally below, close) a short-lived one.
    owns_client = http_client is None
    client = http_client or httpx.AsyncClient(timeout=5.0)
    try:
        response = await client.post(
            f"{settings.auth_service_url.rstrip('/')}/auth/validate",
            headers={"X-Internal-API-Key": settings.internal_api_key},
            json={"token": token},
        )
    except httpx.HTTPError as exc:
        logger.error("auth_service_unreachable", error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="auth service unreachable",
        ) from exc
    finally:
        if owns_client:
            await client.aclose()
    if response.status_code == 401:
        raise HTTPException(status_code=401, detail="invalid or expired token")
    # 403 means auth rejected OUR internal API key: surface as a server error.
    if response.status_code == 403:
        raise HTTPException(status_code=500, detail="internal api key rejected by auth")
    if response.status_code >= 400:
        logger.error(
            "auth_validate_failed",
            status_code=response.status_code,
            body=response.text[:200],  # truncate to keep log lines bounded
        )
        raise HTTPException(status_code=401, detail="token validation failed")
    data = response.json()
    if not data.get("valid"):
        raise HTTPException(status_code=401, detail=data.get("reason", "invalid token"))
    return AuthenticatedUser.from_validate_response(data)
def _extract_bearer(authorization: str | None) -> str:
    """Return the token from a `Bearer <jwt>` Authorization header, or 401."""
    if not authorization:
        raise HTTPException(status_code=401, detail="missing authorization header")
    scheme, _, credential = authorization.partition(" ")
    credential = credential.strip()
    if scheme.lower() != "bearer" or not credential:
        raise HTTPException(status_code=401, detail="invalid authorization scheme")
    return credential
async def require_user(
    request: Request,
    authorization: Annotated[str | None, Header()] = None,
    settings: Annotated[Settings, Depends(get_settings)] = None,  # type: ignore[assignment]
) -> AuthenticatedUser:
    """FastAPI dependency that yields the authenticated user for the request.

    Checks the TTL cache first; on a miss, validates against the auth
    service (reusing the app's pooled client when available) and caches the
    result. Also binds the user id into the logging context and request state.
    """
    token = _extract_bearer(authorization)
    cache = get_auth_cache()
    cached = await cache.get(token)
    if cached is not None:
        # Cache hit: skip the auth-service round trip entirely.
        set_user_id(cached.id)
        request.state.user = cached
        return cached
    http_client: httpx.AsyncClient | None = getattr(
        request.app.state, "auth_http_client", None
    )
    user = await _validate_with_auth_service(token, settings, http_client=http_client)
    await cache.set(token, user)
    set_user_id(user.id)
    request.state.user = user
    return user

View File

View File

@ -0,0 +1,130 @@
"""CRUD routes for conversations, scoped to the authenticated user."""
from __future__ import annotations
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from ..database import get_session
from ..logging_setup import get_logger
from ..middleware.auth_guard import AuthenticatedUser, require_user
from ..services import conversation_service
logger = get_logger(__name__)
router = APIRouter(prefix="/api/conversations", tags=["conversations"])
class CreateConversationRequest(BaseModel):
    """Body for POST /api/conversations; both fields are optional."""

    # "model_tag" falls inside pydantic v2's protected "model_" namespace and
    # triggers a UserWarning at class-creation time; clearing the protected
    # namespaces keeps the field name without the warning.
    model_config = {"protected_namespaces": ()}

    title: str | None = Field(default=None, max_length=500)
    model_tag: str | None = Field(default=None, max_length=100)
class UpdateConversationRequest(BaseModel):
    """Body for PUT /api/conversations/{id}: the new, non-empty title."""

    title: str = Field(..., min_length=1, max_length=500)
def _parse_uuid(raw: str) -> uuid.UUID:
    """Parse *raw* as a UUID; malformed ids surface as 404, not 422."""
    try:
        parsed = uuid.UUID(raw)
    except (ValueError, TypeError) as exc:
        raise HTTPException(status_code=404, detail="conversation not found") from exc
    return parsed
@router.get("")
async def list_conversations(
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
    limit: int = Query(default=50, ge=1, le=200),
    offset: int = Query(default=0, ge=0),
):
    """List the user's conversations (paged) plus a by-date grouping."""
    user_uuid = uuid.UUID(user.id)
    conversations = await conversation_service.list_conversations(
        session, user_id=user_uuid, limit=limit, offset=offset
    )
    # Return the same page twice: a flat list and a date-keyed grouping so
    # the client can render sidebar sections without re-bucketing.
    grouped = conversation_service.group_by_date(conversations)
    return {
        "items": [c.to_dict() for c in conversations],
        "grouped": grouped,
        "limit": limit,
        "offset": offset,
    }
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_conversation(
    body: CreateConversationRequest,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Create a new conversation for the authenticated user; returns 201."""
    user_uuid = uuid.UUID(user.id)
    convo = await conversation_service.create_conversation(
        session,
        user_id=user_uuid,
        title=body.title,
        model_tag=body.model_tag,
    )
    logger.info("conversation_created", conversation_id=str(convo.id))
    return convo.to_dict()
@router.get("/{conversation_id}")
async def get_conversation(
    conversation_id: str,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Return one user-owned conversation plus its full message history."""
    conv_uuid = _parse_uuid(conversation_id)
    record = await conversation_service.get_user_conversation(
        session, user_id=uuid.UUID(user.id), conversation_id=conv_uuid
    )
    if record is None:
        raise HTTPException(status_code=404, detail="conversation not found")
    history = await conversation_service.get_conversation_messages(
        session, conversation_id=conv_uuid
    )
    result = record.to_dict()
    result["messages"] = [msg.to_dict() for msg in history]
    return result
@router.put("/{conversation_id}")
async def update_conversation(
    conversation_id: str,
    body: UpdateConversationRequest,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Rename a user-owned conversation; 404 if it does not exist for them."""
    conv_uuid = _parse_uuid(conversation_id)
    user_uuid = uuid.UUID(user.id)
    convo = await conversation_service.update_conversation_title(
        session,
        user_id=user_uuid,
        conversation_id=conv_uuid,
        title=body.title,
    )
    if convo is None:
        raise HTTPException(status_code=404, detail="conversation not found")
    return convo.to_dict()
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_conversation(
    conversation_id: str,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Delete a user-owned conversation; returns 204 with an empty body."""
    conv_uuid = _parse_uuid(conversation_id)
    user_uuid = uuid.UUID(user.id)
    deleted = await conversation_service.delete_conversation(
        session, user_id=user_uuid, conversation_id=conv_uuid
    )
    if not deleted:
        raise HTTPException(status_code=404, detail="conversation not found")
    logger.info("conversation_deleted", conversation_id=str(conv_uuid))
    return None

View File

@ -0,0 +1,257 @@
"""The main chat route: send a message and stream an assistant response."""
from __future__ import annotations
import uuid
from typing import Annotated, AsyncIterator
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sse_starlette.sse import EventSourceResponse
from ..config import Settings, get_settings
from ..database import get_session_factory
from ..logging_setup import get_logger
from ..middleware.auth_guard import AuthenticatedUser, require_user
from ..services import conversation_service
from ..services.inference_client import InferenceClient
from ..services.stream_proxy import StreamResult, proxy_inference_stream
logger = get_logger(__name__)
router = APIRouter(prefix="/api/conversations", tags=["messages"])
class SendMessageRequest(BaseModel):
    """Body for POST .../messages: user text plus optional sampling overrides."""

    content: str = Field(..., min_length=1, max_length=8000)
    # Per-request sampling overrides; None falls back to the service defaults.
    temperature: float | None = Field(default=None, ge=0.0, le=2.0)
    max_tokens: int | None = Field(default=None, ge=1, le=4096)
    top_k: int | None = Field(default=None, ge=0, le=200)
class RegenerateRequest(BaseModel):
    """Body for POST .../regenerate: optional sampling overrides only."""

    temperature: float | None = Field(default=None, ge=0.0, le=2.0)
    max_tokens: int | None = Field(default=None, ge=1, le=4096)
    top_k: int | None = Field(default=None, ge=0, le=200)
def _parse_uuid(raw: str) -> uuid.UUID:
    """Parse *raw* as a UUID; malformed ids surface as 404, not 422."""
    try:
        parsed = uuid.UUID(raw)
    except (ValueError, TypeError) as exc:
        raise HTTPException(status_code=404, detail="conversation not found") from exc
    return parsed
async def _ensure_ownership(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    conversation_id: uuid.UUID,
):
    """Return the conversation iff *user_id* owns it; raise 404 otherwise."""
    owned = await conversation_service.get_user_conversation(
        session, user_id=user_id, conversation_id=conversation_id
    )
    if owned is None:
        raise HTTPException(status_code=404, detail="conversation not found")
    return owned
async def _stream_and_persist(
    *,
    request: Request,
    user_id: uuid.UUID,
    conversation_id: uuid.UUID,
    history: list[dict[str, str]],
    temperature: float | None,
    max_tokens: int | None,
    top_k: int | None,
    model_tag: str,
    first_message: bool,
    first_message_preview: str | None,
    settings: Settings,
) -> AsyncIterator[dict]:
    """Generator that streams inference SSE events to the client and, after the
    stream closes, persists the full assistant message in a fresh DB session.

    The request's DB session is already closed by the time this generator is
    consumed, which is why persistence opens its own session from the factory.
    Stream errors are reported to the client as in-band SSE payloads (the HTTP
    status is already committed); persistence errors are logged, never raised.
    """
    http_client = getattr(request.app.state, "inference_http_client", None)
    inference = InferenceClient(settings=settings, http_client=http_client)
    # proxy_inference_stream hands the final StreamResult to on_complete;
    # capture it in a mutable cell so we can read it after the loop.
    accumulated: dict[str, StreamResult] = {}

    def _capture(result: StreamResult) -> None:
        accumulated["result"] = result

    try:
        async with inference.stream_generate(
            messages=history,
            temperature=temperature,
            max_tokens=max_tokens,
            top_k=top_k,
        ) as response:
            async for event in proxy_inference_stream(response, on_complete=_capture):
                yield event
    except Exception as exc:  # pragma: no cover - defensive path
        logger.error(
            "inference_stream_failed",
            conversation_id=str(conversation_id),
            error=str(exc),
        )
        # Emit an in-band error plus a terminal done event so the client
        # can stop waiting, then skip persistence entirely.
        yield {"data": '{"error":"inference_stream_failed"}'}
        yield {"data": '{"done":true}'}
        return
    result = accumulated.get("result")
    # Persist only complete, non-empty responses; partial streams are dropped.
    if result is None or not result.completed or not result.content:
        logger.warning(
            "assistant_message_not_persisted",
            conversation_id=str(conversation_id),
            completed=bool(result and result.completed),
            content_len=len(result.content) if result else 0,
        )
        return
    factory: async_sessionmaker[AsyncSession] = get_session_factory()
    try:
        async with factory() as persist_session:
            await conversation_service.append_message(
                persist_session,
                conversation_id=conversation_id,
                role="assistant",
                content=result.content,
                token_count=result.token_count,
                model_tag=model_tag,
                inference_time_ms=result.inference_time_ms,
            )
            # First exchange in a conversation: auto-title it from the
            # preview of the user's opening message.
            if first_message and first_message_preview is not None:
                await conversation_service.update_conversation_title(
                    persist_session,
                    user_id=user_id,
                    conversation_id=conversation_id,
                    title=first_message_preview,
                )
            logger.info(
                "assistant_message_persisted",
                conversation_id=str(conversation_id),
                token_count=result.token_count,
                inference_time_ms=result.inference_time_ms,
            )
    except Exception as exc:  # pragma: no cover - defensive path
        logger.error(
            "assistant_message_persist_failed",
            conversation_id=str(conversation_id),
            error=str(exc),
        )
@router.post("/{conversation_id}/messages")
async def send_message(
    conversation_id: str,
    body: SendMessageRequest,
    request: Request,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    settings: Annotated[Settings, Depends(get_settings)] = None,  # type: ignore[assignment]
):
    """Append the user's message, then stream the assistant reply over SSE."""
    # We own our DB sessions explicitly here so the background task that
    # persists the streamed assistant response can open its own session once
    # the request scope has already closed.
    conv_uuid = _parse_uuid(conversation_id)
    user_uuid = uuid.UUID(user.id)
    session_factory = get_session_factory()
    async with session_factory() as db_session:
        convo = await _ensure_ownership(
            db_session, user_id=user_uuid, conversation_id=conv_uuid
        )
        model_tag = convo.model_tag or "default"
        # Peek at a single existing message to detect a brand-new
        # conversation; its first user message becomes the auto-title
        # (80-character preview) after the stream completes.
        existing = await conversation_service.get_conversation_messages(
            db_session, conversation_id=conv_uuid, limit=1
        )
        first_message = len(existing) == 0
        first_preview = body.content[:80] if first_message else None
        await conversation_service.append_message(
            db_session,
            conversation_id=conv_uuid,
            role="user",
            content=body.content,
            token_count=None,
            model_tag=model_tag,
        )
        history = await conversation_service.build_history_for_inference(
            db_session, conversation_id=conv_uuid
        )
    logger.info(
        "send_message",
        conversation_id=str(conv_uuid),
        history_len=len(history),
        model_tag=model_tag,
    )
    generator = _stream_and_persist(
        request=request,
        user_id=user_uuid,
        conversation_id=conv_uuid,
        history=history,
        temperature=body.temperature,
        max_tokens=body.max_tokens,
        top_k=body.top_k,
        model_tag=model_tag,
        first_message=first_message,
        first_message_preview=first_preview,
        settings=settings,
    )
    return EventSourceResponse(generator, media_type="text/event-stream")
@router.post("/{conversation_id}/regenerate")
async def regenerate(
    conversation_id: str,
    body: RegenerateRequest,
    request: Request,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    settings: Annotated[Settings, Depends(get_settings)] = None,  # type: ignore[assignment]
):
    """Drop the newest assistant message and stream a fresh reply over SSE."""
    conv_uuid = _parse_uuid(conversation_id)
    user_uuid = uuid.UUID(user.id)
    session_factory = get_session_factory()
    async with session_factory() as db_session:
        convo = await _ensure_ownership(
            db_session, user_id=user_uuid, conversation_id=conv_uuid
        )
        model_tag = convo.model_tag or "default"
        # Remove the previous assistant turn (if any) before rebuilding the
        # prompt history so it is not echoed back to the model.
        await conversation_service.delete_last_assistant_message(
            db_session, conversation_id=conv_uuid
        )
        history = await conversation_service.build_history_for_inference(
            db_session, conversation_id=conv_uuid
        )
    if not history:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="conversation has no user messages to regenerate from",
        )
    logger.info(
        "regenerate_message",
        conversation_id=str(conv_uuid),
        history_len=len(history),
    )
    generator = _stream_and_persist(
        request=request,
        user_id=user_uuid,
        conversation_id=conv_uuid,
        history=history,
        temperature=body.temperature,
        max_tokens=body.max_tokens,
        top_k=body.top_k,
        model_tag=model_tag,
        first_message=False,
        first_message_preview=None,
        settings=settings,
    )
    return EventSourceResponse(generator, media_type="text/event-stream")

View File

@ -0,0 +1,74 @@
"""Proxy routes that forward model management calls to the inference service."""
from __future__ import annotations
from typing import Annotated
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field
from ..config import Settings, get_settings
from ..logging_setup import get_logger
from ..middleware.auth_guard import AuthenticatedUser, require_user
from ..services.inference_client import InferenceClient
logger = get_logger(__name__)
router = APIRouter(prefix="/api/models", tags=["models"])
class SwapModelRequest(BaseModel):
    """Body for POST /api/models/swap: the tag of the model to load."""

    # "model_tag" falls inside pydantic v2's protected "model_" namespace and
    # triggers a UserWarning at class-creation time; clearing the protected
    # namespaces keeps the wire-format field name warning-free.
    model_config = {"protected_namespaces": ()}

    model_tag: str = Field(..., min_length=1, max_length=100)
def _client_for(request: Request, settings: Settings) -> InferenceClient:
    """Build an InferenceClient, reusing the app's pooled client when present."""
    pooled = getattr(request.app.state, "inference_http_client", None)
    return InferenceClient(settings=settings, http_client=pooled)
@router.get("")
async def list_models(
    request: Request,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    settings: Annotated[Settings, Depends(get_settings)] = None,  # type: ignore[assignment]
):
    """Proxy GET /models on the inference service (any authenticated user)."""
    client = _client_for(request, settings)
    try:
        return await client.list_models()
    except httpx.HTTPStatusError as exc:
        # Upstream answered with an error status: surface as a bad gateway.
        logger.error("list_models_proxy_failed", status_code=exc.response.status_code)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="inference service error",
        ) from exc
    except httpx.HTTPError as exc:
        # Transport-level failure (DNS, refused, timeout): unavailable.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="inference service unreachable",
        ) from exc
@router.post("/swap")
async def swap_model(
    body: SwapModelRequest,
    request: Request,
    user: Annotated[AuthenticatedUser, Depends(require_user)],
    settings: Annotated[Settings, Depends(get_settings)] = None,  # type: ignore[assignment]
):
    """Proxy POST /models/swap to the inference service (admin only).

    The is_admin flag is read straight from the auth service's user payload.
    """
    if not user.raw.get("is_admin"):
        raise HTTPException(status_code=403, detail="admin privilege required")
    client = _client_for(request, settings)
    try:
        return await client.swap_model(body.model_tag)
    except httpx.HTTPStatusError as exc:
        # Upstream answered with an error status: surface as a bad gateway.
        logger.error("swap_model_proxy_failed", status_code=exc.response.status_code)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="inference service rejected swap",
        ) from exc
    except httpx.HTTPError as exc:
        # Transport-level failure (DNS, refused, timeout): unavailable.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="inference service unreachable",
        ) from exc

View File

@ -0,0 +1,209 @@
"""Business logic for conversations and messages, scoped to a single user.
Every query in this module filters by `user_id`; that scoping is the only
thing preventing one user from reading or mutating another user's data.
"""
from __future__ import annotations
import uuid
from collections import defaultdict
from datetime import date
from typing import Iterable
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession
from ..config import get_settings
from ..models import Conversation, Message
async def create_conversation(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    title: str | None = None,
    model_tag: str | None = None,
) -> Conversation:
    """Insert a conversation owned by *user_id* and return the refreshed row.

    Falls back to "New conversation" / "default" when title or model_tag
    are absent or empty.
    """
    record = Conversation(
        user_id=user_id,
        title=title or "New conversation",
        model_tag=model_tag or "default",
    )
    session.add(record)
    await session.commit()
    await session.refresh(record)
    return record
async def list_conversations(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    limit: int = 50,
    offset: int = 0,
) -> list[Conversation]:
    """Return one page of the user's conversations, most recently updated first."""
    query = sa.select(Conversation).where(Conversation.user_id == user_id)
    query = query.order_by(Conversation.updated_at.desc())
    query = query.limit(limit).offset(offset)
    rows = await session.execute(query)
    return list(rows.scalars().all())
def group_by_date(conversations: Iterable[Conversation]) -> dict[str, list[dict]]:
    """Group conversations by the YYYY-MM-DD (UTC) date of `updated_at`.

    Rows whose `updated_at` is unset fall into today's bucket. Returns a
    plain dict mapping ISO date strings to lists of conversation dicts.
    """
    from datetime import datetime, timezone  # local: only the fallback needs it

    # The documented contract is UTC bucketing; `date.today()` would use the
    # host's local timezone and mis-bucket rows near midnight, so take the
    # current date from an explicit UTC clock instead.
    fallback_key = datetime.now(timezone.utc).date().isoformat()
    buckets: dict[str, list[dict]] = defaultdict(list)
    for convo in conversations:
        if convo.updated_at is None:
            bucket_key = fallback_key
        else:
            bucket_key = convo.updated_at.date().isoformat()
        buckets[bucket_key].append(convo.to_dict())
    return dict(buckets)
async def get_user_conversation(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    conversation_id: uuid.UUID,
) -> Conversation | None:
    """Look up a conversation by id, returning None unless *user_id* owns it."""
    query = (
        sa.select(Conversation)
        .where(Conversation.id == conversation_id)
        .where(Conversation.user_id == user_id)
    )
    return (await session.execute(query)).scalar_one_or_none()
async def get_conversation_messages(
    session: AsyncSession,
    *,
    conversation_id: uuid.UUID,
    limit: int | None = None,
) -> list[Message]:
    """Return a conversation's messages oldest-first (id breaks timestamp ties)."""
    query = (
        sa.select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.asc(), Message.id.asc())
    )
    if limit is not None:
        query = query.limit(limit)
    rows = await session.execute(query)
    return list(rows.scalars().all())
async def update_conversation_title(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    conversation_id: uuid.UUID,
    title: str,
) -> Conversation | None:
    """Rename a user-owned conversation; None when it does not exist for them."""
    target = await get_user_conversation(
        session, user_id=user_id, conversation_id=conversation_id
    )
    if target is None:
        return None
    target.title = title
    await session.commit()
    await session.refresh(target)
    return target
async def delete_conversation(
    session: AsyncSession,
    *,
    user_id: uuid.UUID,
    conversation_id: uuid.UUID,
) -> bool:
    """Delete a user-owned conversation; True when a row was actually removed."""
    target = await get_user_conversation(
        session, user_id=user_id, conversation_id=conversation_id
    )
    if target is None:
        return False
    await session.delete(target)
    await session.commit()
    return True
async def append_message(
    session: AsyncSession,
    *,
    conversation_id: uuid.UUID,
    role: str,
    content: str,
    token_count: int | None = None,
    model_tag: str | None = None,
    inference_time_ms: int | None = None,
) -> Message:
    """Insert one message row and return it fully refreshed after commit."""
    row = Message(
        conversation_id=conversation_id,
        role=role,
        content=content,
        token_count=token_count,
        model_tag=model_tag,
        inference_time_ms=inference_time_ms,
    )
    session.add(row)
    await session.commit()
    await session.refresh(row)
    return row
async def build_history_for_inference(
    session: AsyncSession,
    *,
    conversation_id: uuid.UUID,
) -> list[dict[str, str]]:
    """Return the trailing slice of history that fits the configured budgets."""
    settings = get_settings()
    # Fetch the newest `max_conversation_history` rows, newest first, so we
    # can walk backwards and keep the most recent context.
    query = (
        sa.select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.desc(), Message.id.desc())
        .limit(settings.max_conversation_history)
    )
    newest_first = list((await session.execute(query)).scalars().all())
    # NOTE(review): the "token" budget is approximated with character counts
    # here — confirm this matches how the inference service measures it.
    kept: list[Message] = []
    used = 0
    for msg in newest_first:
        used += len(msg.content or "")
        # Always keep at least the newest message, even if it alone exceeds
        # the budget; otherwise stop once the running total spills over.
        if used > settings.max_token_budget and kept:
            break
        kept.append(msg)
    kept.reverse()  # restore chronological order for the model
    return [{"role": m.role, "content": m.content} for m in kept]
async def delete_last_assistant_message(
    session: AsyncSession,
    *,
    conversation_id: uuid.UUID,
) -> bool:
    """Remove the newest assistant message (used by the regenerate flow).

    Returns True when a message was deleted, False when the conversation has
    no assistant messages.
    """
    query = (
        sa.select(Message)
        .where(Message.conversation_id == conversation_id)
        .where(Message.role == "assistant")
        .order_by(Message.created_at.desc(), Message.id.desc())
        .limit(1)
    )
    newest = (await session.execute(query)).scalar_one_or_none()
    if newest is None:
        return False
    await session.delete(newest)
    await session.commit()
    return True

View File

@ -0,0 +1,93 @@
"""HTTP client wrapper for the inference service."""
from __future__ import annotations
from contextlib import asynccontextmanager
from typing import AsyncIterator
import httpx
from ..config import Settings, get_settings
class InferenceClient:
    """Thin async wrapper around the inference service HTTP contract."""
    def __init__(
        self,
        settings: Settings | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        # An injected http_client (tests pass one with a mock transport) is
        # borrowed, not owned: aclose() will leave it open for its real owner.
        self._settings = settings or get_settings()
        self._client = http_client
        self._owns_client = http_client is None
    @property
    def base_url(self) -> str:
        """Inference service root URL with any trailing slash stripped."""
        return self._settings.inference_service_url.rstrip("/")
    @property
    def headers(self) -> dict[str, str]:
        """Service-to-service auth header sent on every inference call."""
        return {"X-Internal-API-Key": self._settings.internal_api_key}
    def _get_client(self) -> httpx.AsyncClient:
        # Lazily build an owned client: generous read timeout (token
        # generation can be slow), short connect timeout to fail fast.
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=5.0))
        return self._client
    async def aclose(self) -> None:
        """Close the underlying HTTP client only if this wrapper created it."""
        if self._client is not None and self._owns_client:
            await self._client.aclose()
            self._client = None
    async def list_models(self) -> dict:
        """GET /models — proxy the inference service's model listing.

        Raises httpx.HTTPStatusError on non-2xx responses.
        """
        client = self._get_client()
        resp = await client.get(f"{self.base_url}/models", headers=self.headers)
        resp.raise_for_status()
        return resp.json()
    async def swap_model(self, model_tag: str) -> dict:
        """POST /models/swap — ask the inference service to load *model_tag*.

        Raises httpx.HTTPStatusError on non-2xx responses.
        """
        client = self._get_client()
        resp = await client.post(
            f"{self.base_url}/models/swap",
            headers=self.headers,
            json={"model_tag": model_tag},
        )
        resp.raise_for_status()
        return resp.json()
    @asynccontextmanager
    async def stream_generate(
        self,
        *,
        messages: list[dict[str, str]],
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_k: int | None = None,
    ) -> AsyncIterator[httpx.Response]:
        """POST /generate and yield the open streaming response.

        Sampling parameters left as None fall back to the configured
        defaults. The response is yielded inside the httpx stream context so
        the connection is released when the caller's `async with` exits.
        """
        temperature = (
            temperature
            if temperature is not None
            else self._settings.inference_default_temperature
        )
        max_tokens = (
            max_tokens
            if max_tokens is not None
            else self._settings.inference_default_max_tokens
        )
        top_k = top_k if top_k is not None else self._settings.inference_default_top_k
        payload = {
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_k": top_k,
        }
        client = self._get_client()
        async with client.stream(
            "POST",
            f"{self.base_url}/generate",
            headers=self.headers,
            json=payload,
        ) as response:
            yield response

View File

@ -0,0 +1,114 @@
"""Proxy the inference SSE stream to the client while accumulating tokens.
The inference service emits lines like `data: {"token": "...", "gpu": 0}` and
terminates with `data: {"done": true}`. We forward each event unchanged to the
client, collect assistant tokens into a buffer, and report the buffer plus
timing info once the stream closes.
"""
from __future__ import annotations
import json
import time
from dataclasses import dataclass
from typing import AsyncIterator, Callable
import httpx
from ..logging_setup import get_logger
logger = get_logger(__name__)
@dataclass
class StreamResult:
    """Summary of one proxied generation stream, reported via on_complete."""
    content: str  # full assistant reply assembled from the streamed tokens
    token_count: int  # number of token events forwarded downstream
    inference_time_ms: int  # wall-clock duration of the proxied stream
    completed: bool  # True only if the upstream sent its terminal done event
def _parse_sse_data(raw_line: str) -> dict | None:
line = raw_line.strip()
if not line or not line.startswith("data:"):
return None
body = line[len("data:"):].strip()
if not body:
return None
try:
return json.loads(body)
except json.JSONDecodeError:
logger.warning("stream_proxy_bad_sse_payload", payload=body[:200])
return None
async def proxy_inference_stream(
    response: httpx.Response,
    *,
    on_complete: Callable[[StreamResult], None] | None = None,
) -> AsyncIterator[dict]:
    """Async generator that yields SSE events as dicts with ``data`` keys.
    Each yielded event is shaped as ``{"data": "<json-string>"}`` so callers can
    pass it straight into sse-starlette's ``EventSourceResponse``.
    """
    started = time.perf_counter()
    buffer: list[str] = []  # accumulated assistant tokens, in arrival order
    token_count = 0
    completed = False  # set only when the upstream sends its done event
    # Upstream failure: surface a synthetic error event followed by a done
    # event so SSE clients terminate cleanly, then report an empty result
    # with completed=False.
    if response.status_code >= 400:
        body = await response.aread()
        logger.error(
            "inference_error_response",
            status_code=response.status_code,
            body=body.decode("utf-8", errors="replace")[:200],
        )
        error_payload = json.dumps(
            {
                "error": "inference service returned an error",
                "status_code": response.status_code,
            }
        )
        yield {"data": error_payload}
        yield {"data": json.dumps({"done": True})}
        if on_complete is not None:
            on_complete(
                StreamResult(
                    content="",
                    token_count=0,
                    inference_time_ms=int((time.perf_counter() - started) * 1000),
                    completed=False,
                )
            )
        return
    try:
        async for raw_line in response.aiter_lines():
            if not raw_line:
                continue
            parsed = _parse_sse_data(raw_line)
            if parsed is None:
                continue
            if parsed.get("done"):
                completed = True
                # Forward the terminal event before stopping.
                yield {"data": json.dumps(parsed)}
                break
            token = parsed.get("token")
            if isinstance(token, str) and token:
                buffer.append(token)
                token_count += 1
            yield {"data": json.dumps(parsed)}
    finally:
        # Runs even if the client disconnects mid-stream (GeneratorExit),
        # so the partial reply is still reported for persistence.
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        if on_complete is not None:
            on_complete(
                StreamResult(
                    content="".join(buffer),
                    token_count=token_count,
                    inference_time_ms=elapsed_ms,
                    completed=completed,
                )
            )

View File

View File

@ -0,0 +1,160 @@
"""Shared pytest fixtures for the chat API test suite.
Tests run against a throwaway SQLite database (via aiosqlite) with a stub
`users` table so that the foreign key from `conversations.user_id` validates.
The auth service is mocked with respx; the inference service SSE responses
are mocked with hand-rolled httpx MockTransports.
"""
from __future__ import annotations
import uuid
from typing import AsyncIterator
import pytest
import pytest_asyncio
import sqlalchemy as sa
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from src import database as db_module
from src.database import Base
from src.middleware import auth_guard
from src.models import Conversation, Message # noqa: F401 (registers tables)
# Attach a stub `users` table to the shared metadata so the FK on
# `conversations.user_id` can resolve during SQLite-based test setup.
# Guarded so repeated imports of this conftest don't redefine the table.
if "users" not in Base.metadata.tables:
    sa.Table(
        "users",
        Base.metadata,
        sa.Column("id", sa.Uuid(as_uuid=True), primary_key=True),
        sa.Column("email", sa.String(255)),
        sa.Column("name", sa.String(255)),
        sa.Column("is_admin", sa.Integer(), default=0),
    )
@pytest.fixture(autouse=True)
def _reset_caches():
    """Wipe process-wide singletons between tests.

    Settings, the auth cache, and the token registry are module-level caches;
    without this fixture, state from one test (env overrides, validated
    tokens) would leak into the next. The previously declared `monkeypatch`
    parameter was unused and has been dropped.
    """
    from src.config import get_settings
    get_settings.cache_clear()
    auth_guard.reset_auth_cache()
    _TOKEN_REGISTRY.clear()
    yield
    get_settings.cache_clear()
    auth_guard.reset_auth_cache()
    _TOKEN_REGISTRY.clear()
@pytest.fixture(autouse=True)
def _env(monkeypatch):
    """Point the app at throwaway/mock backends for every test."""
    overrides = {
        "DATABASE_URL": "sqlite+aiosqlite:///:memory:",
        "AUTH_SERVICE_URL": "http://auth.test",
        "INFERENCE_SERVICE_URL": "http://inference.test",
        "INTERNAL_API_KEY": "test-internal-key",
        "LOG_LEVEL": "WARNING",
    }
    for key, value in overrides.items():
        monkeypatch.setenv(key, value)
    from src.config import get_settings
    get_settings.cache_clear()
@pytest_asyncio.fixture
async def engine(_env) -> AsyncIterator:
    """Fresh in-memory SQLite engine with the full schema created."""
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    # Dispose after the test so connections don't leak across tests.
    await engine.dispose()
@pytest_asyncio.fixture
async def session_factory(engine) -> async_sessionmaker[AsyncSession]:
    """Session factory bound to the test engine; also installed app-wide."""
    factory = async_sessionmaker(engine, expire_on_commit=False)
    # Make the app's request-scoped sessions use the same test database.
    db_module.override_session_factory(factory)
    return factory
@pytest_asyncio.fixture
async def seeded_user(session_factory) -> dict:
    """Insert a regular (non-admin) user row and return its public fields."""
    user = {"id": str(uuid.uuid4()), "email": "alice@example.com", "name": "Alice"}
    insert_sql = sa.text(
        "INSERT INTO users (id, email, name, is_admin) "
        "VALUES (:id, :email, :name, :is_admin)"
    )
    async with session_factory() as session:
        await session.execute(insert_sql, {**user, "is_admin": 0})
        await session.commit()
    return user
@pytest_asyncio.fixture
async def other_user(session_factory) -> dict:
    """Insert a second, unrelated user for cross-user isolation tests."""
    user = {"id": str(uuid.uuid4()), "email": "bob@example.com", "name": "Bob"}
    insert_sql = sa.text(
        "INSERT INTO users (id, email, name, is_admin) "
        "VALUES (:id, :email, :name, :is_admin)"
    )
    async with session_factory() as session:
        await session.execute(insert_sql, {**user, "is_admin": 0})
        await session.commit()
    return user
@pytest.fixture
def app():
    """Build a fresh app per test (import deferred until env vars are set)."""
    from src.main import create_app
    return create_app()
@pytest_asyncio.fixture
async def client(app, session_factory) -> AsyncIterator[AsyncClient]:
    """httpx client wired straight into the ASGI app (no real sockets)."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
# token -> user payload returned by the mocked auth service.
_TOKEN_REGISTRY: dict[str, dict] = {}
def _dispatch_auth_validate(request):
    """respx side effect emulating the auth service's /auth/validate route."""
    import json
    import httpx
    payload = json.loads(request.read())
    # Internal key check mirrors the real service's service-to-service auth.
    if request.headers.get("X-Internal-API-Key") != "test-internal-key":
        return httpx.Response(403, json={"detail": "invalid internal api key"})
    user = _TOKEN_REGISTRY.get(payload.get("token"))
    if user is None:
        return httpx.Response(401, json={"valid": False, "reason": "invalid token"})
    return httpx.Response(
        200,
        json={"valid": True, "user": user, "claims": {"sub": user["id"]}},
    )
def stub_auth_validate(respx_mock, user: dict, token: str = "valid-token"):
    """Register a respx mock that returns the given user for the given token.
    Multiple calls within a single test accumulate token -> user mappings so
    several users can authenticate in the same scenario.
    """
    _TOKEN_REGISTRY[token] = user
    respx_mock.post("http://auth.test/auth/validate").mock(
        side_effect=_dispatch_auth_validate
    )

View File

@ -0,0 +1,110 @@
import pytest
import respx
from .conftest import stub_auth_validate
@pytest.mark.asyncio
@respx.mock
async def test_requires_authorization_header(client):
    """Requests without a bearer token are rejected up front."""
    resp = await client.get("/api/conversations")
    assert resp.status_code == 401
@pytest.mark.asyncio
@respx.mock
async def test_create_and_list_conversation(client, seeded_user):
    """Create a conversation, then confirm it appears in the list payload."""
    stub_auth_validate(respx.mock, seeded_user)
    auth = {"Authorization": "Bearer valid-token"}
    created = await client.post(
        "/api/conversations", json={"title": "my first chat"}, headers=auth
    )
    assert created.status_code == 201
    body = created.json()
    assert body["title"] == "my first chat"
    assert body["user_id"] == seeded_user["id"]
    listing = await client.get("/api/conversations", headers=auth)
    assert listing.status_code == 200
    data = listing.json()
    assert len(data["items"]) == 1
    assert data["items"][0]["id"] == body["id"]
    assert data["grouped"]  # at least one date bucket
@pytest.mark.asyncio
@respx.mock
async def test_update_conversation_title(client, seeded_user):
    """PUT on a conversation renames it and echoes the new title."""
    stub_auth_validate(respx.mock, seeded_user)
    auth = {"Authorization": "Bearer valid-token"}
    created = await client.post("/api/conversations", json={}, headers=auth)
    convo_id = created.json()["id"]
    renamed = await client.put(
        f"/api/conversations/{convo_id}",
        json={"title": "Renamed"},
        headers=auth,
    )
    assert renamed.status_code == 200
    assert renamed.json()["title"] == "Renamed"
@pytest.mark.asyncio
@respx.mock
async def test_delete_conversation(client, seeded_user):
    """Deleting a conversation returns 204 and later reads return 404."""
    stub_auth_validate(respx.mock, seeded_user)
    auth = {"Authorization": "Bearer valid-token"}
    created = await client.post("/api/conversations", json={}, headers=auth)
    convo_id = created.json()["id"]
    removed = await client.delete(f"/api/conversations/{convo_id}", headers=auth)
    assert removed.status_code == 204
    refetched = await client.get(f"/api/conversations/{convo_id}", headers=auth)
    assert refetched.status_code == 404
@pytest.mark.asyncio
@respx.mock
async def test_user_scoping_prevents_cross_user_access(
    client, seeded_user, other_user
):
    """One user's conversation is invisible (404) to every other user."""
    stub_auth_validate(respx.mock, seeded_user, token="alice-token")
    stub_auth_validate(respx.mock, other_user, token="bob-token")
    alice = {"Authorization": "Bearer alice-token"}
    bob = {"Authorization": "Bearer bob-token"}
    created = await client.post(
        "/api/conversations",
        json={"title": "alice only"},
        headers=alice,
    )
    convo_id = created.json()["id"]
    # Bob can neither read nor delete Alice's conversation...
    bob_view = await client.get(f"/api/conversations/{convo_id}", headers=bob)
    assert bob_view.status_code == 404
    bob_delete = await client.delete(f"/api/conversations/{convo_id}", headers=bob)
    assert bob_delete.status_code == 404
    # ...and his own listing stays empty.
    bob_list = await client.get("/api/conversations", headers=bob)
    assert bob_list.json()["items"] == []
@pytest.mark.asyncio
@respx.mock
async def test_invalid_token_is_rejected(client, seeded_user):
    """A token the auth service does not recognize yields 401."""
    stub_auth_validate(respx.mock, seeded_user, token="valid-token")
    resp = await client.get(
        "/api/conversations",
        headers={"Authorization": "Bearer wrong-token"},
    )
    assert resp.status_code == 401

View File

@ -0,0 +1,11 @@
import pytest
@pytest.mark.asyncio
async def test_health_is_unauthenticated(client):
    """The liveness probe answers without any Authorization header."""
    resp = await client.get("/api/health")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["status"] == "ok"
    assert payload["ready"] is True
    assert payload["service"] == "chat-api"

View File

@ -0,0 +1,166 @@
"""Tests for the streaming message send + regenerate flows."""
from __future__ import annotations
import json
import uuid
import httpx
import pytest
import respx
from .conftest import stub_auth_validate
def _build_inference_mock(tokens: list[str]) -> httpx.MockTransport:
    """Build an httpx mock transport that streams an SSE response."""
    events = [{"token": tok, "gpu": 0} for tok in tokens] + [{"done": True}]
    sse_lines = [f"data: {json.dumps(ev)}\n\n".encode("utf-8") for ev in events]
    async def handler(request: httpx.Request) -> httpx.Response:
        # Only the generate endpoint is mocked; anything else is a 404.
        if request.url.path != "/generate":
            return httpx.Response(404)
        async def body():
            for chunk in sse_lines:
                yield chunk
        return httpx.Response(
            200,
            headers={"content-type": "text/event-stream"},
            content=body(),
        )
    return httpx.MockTransport(handler)
@pytest.mark.asyncio
@respx.mock
async def test_send_message_streams_and_persists(app, client, seeded_user):
    """Happy path: tokens stream back via SSE and both turns are persisted."""
    stub_auth_validate(respx.mock, seeded_user)
    headers = {"Authorization": "Bearer valid-token"}
    create = await client.post("/api/conversations", json={}, headers=headers)
    convo_id = create.json()["id"]
    # Swap in a mock inference backend that streams three tokens then done.
    app.state.inference_http_client = httpx.AsyncClient(
        transport=_build_inference_mock(["hel", "lo", " world"])
    )
    try:
        resp = await client.post(
            f"/api/conversations/{convo_id}/messages",
            json={"content": "hi there"},
            headers=headers,
        )
        assert resp.status_code == 200
        body = resp.text
        # json.dumps spacing is an implementation detail; accept either form.
        assert '"token": "hel"' in body or '"token":"hel"' in body
        assert '"done": true' in body or '"done":true' in body
    finally:
        await app.state.inference_http_client.aclose()
    fetched = await client.get(
        f"/api/conversations/{convo_id}", headers=headers
    )
    assert fetched.status_code == 200
    payload = fetched.json()
    messages = payload["messages"]
    roles = [m["role"] for m in messages]
    assert roles == ["user", "assistant"]
    assert messages[0]["content"] == "hi there"
    assert messages[1]["content"] == "hello world"
    assert messages[1]["token_count"] == 3
    assert messages[1]["inference_time_ms"] >= 0
    # First message should have auto-populated the title
    assert payload["title"] == "hi there"
@pytest.mark.asyncio
@respx.mock
async def test_send_message_rejected_on_foreign_conversation(
    app, client, seeded_user, other_user
):
    """Posting into someone else's conversation 404s (existence not leaked)."""
    stub_auth_validate(respx.mock, seeded_user, token="alice-token")
    stub_auth_validate(respx.mock, other_user, token="bob-token")
    alice_headers = {"Authorization": "Bearer alice-token"}
    bob_headers = {"Authorization": "Bearer bob-token"}
    create = await client.post("/api/conversations", json={}, headers=alice_headers)
    convo_id = create.json()["id"]
    app.state.inference_http_client = httpx.AsyncClient(
        transport=_build_inference_mock(["x"])
    )
    try:
        resp = await client.post(
            f"/api/conversations/{convo_id}/messages",
            json={"content": "steal me"},
            headers=bob_headers,
        )
    finally:
        await app.state.inference_http_client.aclose()
    assert resp.status_code == 404
@pytest.mark.asyncio
@respx.mock
async def test_send_message_returns_404_for_missing_conversation(
    app, client, seeded_user
):
    """Posting to a random conversation id fails before calling inference."""
    stub_auth_validate(respx.mock, seeded_user)
    auth = {"Authorization": "Bearer valid-token"}
    missing_id = str(uuid.uuid4())
    resp = await client.post(
        f"/api/conversations/{missing_id}/messages",
        json={"content": "hello"},
        headers=auth,
    )
    assert resp.status_code == 404
@pytest.mark.asyncio
@respx.mock
async def test_regenerate_drops_last_assistant_message(app, client, seeded_user):
    """Regenerate replaces (not appends to) the last assistant message.

    Fix over the original: the first mock AsyncClient was leaked when
    `app.state.inference_http_client` was overwritten — it is now closed
    before the second client is swapped in.
    """
    stub_auth_validate(respx.mock, seeded_user)
    headers = {"Authorization": "Bearer valid-token"}
    create = await client.post("/api/conversations", json={}, headers=headers)
    convo_id = create.json()["id"]
    app.state.inference_http_client = httpx.AsyncClient(
        transport=_build_inference_mock(["first"])
    )
    try:
        first = await client.post(
            f"/api/conversations/{convo_id}/messages",
            json={"content": "hi"},
            headers=headers,
        )
        assert first.status_code == 200
        # Close the first mock client before replacing it so it doesn't leak.
        await app.state.inference_http_client.aclose()
        app.state.inference_http_client = httpx.AsyncClient(
            transport=_build_inference_mock(["second", " reply"])
        )
        regen = await client.post(
            f"/api/conversations/{convo_id}/regenerate",
            json={},
            headers=headers,
        )
        assert regen.status_code == 200
    finally:
        await app.state.inference_http_client.aclose()
    fetched = await client.get(
        f"/api/conversations/{convo_id}", headers=headers
    )
    messages = fetched.json()["messages"]
    assistant_messages = [m for m in messages if m["role"] == "assistant"]
    assert len(assistant_messages) == 1
    assert assistant_messages[0]["content"] == "second reply"

View File

@ -0,0 +1,95 @@
"""Tests for /api/models proxy routes."""
from __future__ import annotations
import uuid
import httpx
import pytest
import respx
import sqlalchemy as sa
from .conftest import stub_auth_validate
def _inference_mock(models_response: dict) -> httpx.MockTransport:
    """Mock transport answering the inference service's model endpoints."""
    async def handler(request: httpx.Request) -> httpx.Response:
        path = request.url.path
        if path == "/models":
            return httpx.Response(200, json=models_response)
        if path == "/models/swap":
            return httpx.Response(200, json={"status": "ok", "current_model": "new-model"})
        return httpx.Response(404)
    return httpx.MockTransport(handler)
@pytest.mark.asyncio
@respx.mock
async def test_list_models_proxies_to_inference(app, client, seeded_user):
    """GET /api/models forwards the inference service payload unchanged."""
    stub_auth_validate(respx.mock, seeded_user)
    headers = {"Authorization": "Bearer valid-token"}
    app.state.inference_http_client = httpx.AsyncClient(
        transport=_inference_mock({"current_model": "m1", "models": ["m1", "m2"]})
    )
    try:
        resp = await client.get("/api/models", headers=headers)
        assert resp.status_code == 200
        assert resp.json() == {"current_model": "m1", "models": ["m1", "m2"]}
    finally:
        await app.state.inference_http_client.aclose()
@pytest.mark.asyncio
@respx.mock
async def test_swap_model_requires_admin(app, client, seeded_user):
    """Non-admin users get 403 from the model-swap proxy."""
    stub_auth_validate(respx.mock, seeded_user)
    auth = {"Authorization": "Bearer valid-token"}
    app.state.inference_http_client = httpx.AsyncClient(transport=_inference_mock({}))
    try:
        resp = await client.post(
            "/api/models/swap",
            json={"model_tag": "new-model"},
            headers=auth,
        )
    finally:
        await app.state.inference_http_client.aclose()
    assert resp.status_code == 403
@pytest.mark.asyncio
@respx.mock
async def test_swap_model_succeeds_for_admin(app, client, session_factory):
    """Admin users can trigger a model swap through the proxy."""
    # Seed an admin row directly — the seeded_user fixture is non-admin.
    admin_id = str(uuid.uuid4())
    async with session_factory() as session:
        await session.execute(
            sa.text(
                "INSERT INTO users (id, email, name, is_admin) "
                "VALUES (:id, :email, :name, :is_admin)"
            ),
            {"id": admin_id, "email": "root@example.com", "name": "Root", "is_admin": 1},
        )
        await session.commit()
    admin_user = {
        "id": admin_id,
        "email": "root@example.com",
        "name": "Root",
        "is_admin": True,
    }
    stub_auth_validate(respx.mock, admin_user)
    headers = {"Authorization": "Bearer valid-token"}
    app.state.inference_http_client = httpx.AsyncClient(
        transport=_inference_mock({"current_model": "new-model", "models": ["new-model"]})
    )
    try:
        resp = await client.post(
            "/api/models/swap",
            json={"model_tag": "new-model"},
            headers=headers,
        )
        assert resp.status_code == 200
        assert resp.json()["current_model"] == "new-model"
    finally:
        await app.state.inference_http_client.aclose()

View File

@ -0,0 +1,73 @@
"""Unit tests for the inference SSE stream proxy."""
from __future__ import annotations
import json
import httpx
import pytest
from src.services.stream_proxy import StreamResult, proxy_inference_stream
def _make_response(lines: list[str]) -> httpx.Response:
    """Wrap raw SSE lines in a streaming httpx response."""
    encoded = [(line + "\n").encode("utf-8") for line in lines]
    async def stream():
        for chunk in encoded:
            yield chunk
    return httpx.Response(
        200,
        headers={"content-type": "text/event-stream"},
        content=stream(),
    )
@pytest.mark.asyncio
async def test_stream_proxy_accumulates_tokens_and_signals_done():
    """Events pass through unchanged and the completion summary is accurate."""
    resp = _make_response(
        [
            "data: " + json.dumps({"token": "hel", "gpu": 0}),
            "",
            "data: " + json.dumps({"token": "lo", "gpu": 0}),
            "",
            "data: " + json.dumps({"done": True}),
            "",
        ]
    )
    captured: dict[str, StreamResult] = {}
    def on_complete(result: StreamResult) -> None:
        captured["result"] = result
    events = []
    async for event in proxy_inference_stream(resp, on_complete=on_complete):
        events.append(event)
    # Forwarded events are re-serializations of the upstream payloads.
    assert [json.loads(e["data"]) for e in events] == [
        {"token": "hel", "gpu": 0},
        {"token": "lo", "gpu": 0},
        {"done": True},
    ]
    result = captured["result"]
    assert result.content == "hello"
    assert result.token_count == 2
    assert result.completed is True
    assert result.inference_time_ms >= 0
@pytest.mark.asyncio
async def test_stream_proxy_surfaces_error_status_codes():
    """Upstream HTTP errors become an error event plus a terminal done event."""
    resp = httpx.Response(502, content=b"upstream down")
    captured: dict[str, StreamResult] = {}
    def on_complete(result: StreamResult) -> None:
        captured["result"] = result
    events = []
    async for event in proxy_inference_stream(resp, on_complete=on_complete):
        events.append(event)
    assert any("error" in e["data"] for e in events)
    assert any('"done": true' in e["data"] or '"done":true' in e["data"] for e in events)
    # completed=False marks the result as not a finished assistant reply.
    assert captured["result"].completed is False