diff --git a/services/chat-api/Dockerfile b/services/chat-api/Dockerfile
index 1cbef5b5..ed11ce41 100644
--- a/services/chat-api/Dockerfile
+++ b/services/chat-api/Dockerfile
@@ -1,9 +1,34 @@
 FROM python:3.12-slim
+ENV PYTHONDONTWRITEBYTECODE=1 \
+    PYTHONUNBUFFERED=1 \
+    UV_SYSTEM_PYTHON=1 \
+    UV_LINK_MODE=copy
+
+RUN apt-get update \
+    && apt-get install -y --no-install-recommends build-essential curl \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --no-cache-dir uv==0.4.30
+
 WORKDIR /app
-COPY README.md /app/README.md
+COPY pyproject.toml README.md /app/
+
+RUN uv pip install --system --no-cache \
+    "fastapi>=0.117.1" \
+    "uvicorn[standard]>=0.36.0" \
+    "pydantic>=2.8.0" \
+    "pydantic-settings>=2.4.0" \
+    "sqlalchemy[asyncio]>=2.0.36" \
+    "asyncpg>=0.29.0" \
+    "httpx>=0.27.0" \
+    "sse-starlette>=2.1.3" \
+    "structlog>=24.4.0" \
+    "cachetools>=5.5.0"
+
+COPY src /app/src
 
 EXPOSE 8002
-CMD ["python", "-m", "http.server", "8002", "--bind", "0.0.0.0"]
+CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8002"]
diff --git a/services/chat-api/README.md b/services/chat-api/README.md
index a671d66b..d71be29f 100644
--- a/services/chat-api/README.md
+++ b/services/chat-api/README.md
@@ -1,7 +1,54 @@
 # Chat API Service
 
-Scaffold placeholder for Issue #6.
+Orchestration layer for samosaChaat conversations. Manages conversation state
+in PostgreSQL, authenticates every request via the auth service, and proxies
+streaming inference requests via Server-Sent Events.
 
-This container is intentionally minimal on the monorepo scaffold branch. It
-keeps the compose topology stable while the dedicated chat API implementation
-is developed.
+## Endpoints
+
+| Method | Path | Description |
+| --- | --- | --- |
+| GET | `/api/health` | Liveness probe (unauthenticated) |
+| GET | `/api/conversations` | List the authenticated user's conversations, grouped by date |
+| POST | `/api/conversations` | Create a new conversation |
+| GET | `/api/conversations/{id}` | Fetch a conversation + full message history |
+| PUT | `/api/conversations/{id}` | Update the conversation title |
+| DELETE | `/api/conversations/{id}` | Delete a conversation (cascade deletes messages) |
+| POST | `/api/conversations/{id}/messages` | Append a user message and stream the assistant response |
+| POST | `/api/conversations/{id}/regenerate` | Delete the last assistant message and regenerate it |
+| GET | `/api/models` | Proxy to inference `GET /models` |
+| POST | `/api/models/swap` | Proxy to inference `POST /models/swap` (admin only) |
+
+All authenticated endpoints expect `Authorization: Bearer <token>`. The chat API
+validates the token by calling the auth service `POST /auth/validate` with the
+shared `X-Internal-API-Key` header and caches the result for 5 minutes.
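+
+### Example: streaming a reply
+
+A minimal client sketch, assuming the service is running locally on port 8002
+and you already hold a valid JWT (the token value below is a placeholder):
+
+```python
+import asyncio
+import json
+
+import httpx
+
+TOKEN = "<jwt issued by the auth service>"  # placeholder, not a real token
+
+
+async def main() -> None:
+    headers = {"Authorization": f"Bearer {TOKEN}"}
+    async with httpx.AsyncClient(
+        base_url="http://localhost:8002", headers=headers
+    ) as client:
+        # Create a conversation, then stream an assistant reply into it.
+        convo = (await client.post("/api/conversations", json={})).json()
+        async with client.stream(
+            "POST",
+            f"/api/conversations/{convo['id']}/messages",
+            json={"content": "Hello!"},
+        ) as response:
+            # Each SSE event is a `data: {...}` line; the stream ends with
+            # a `{"done": true}` event.
+            async for line in response.aiter_lines():
+                if not line.startswith("data:"):
+                    continue
+                event = json.loads(line[len("data:"):])
+                if event.get("done"):
+                    break
+                print(event.get("token", ""), end="", flush=True)
+
+
+asyncio.run(main())
+```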
+
+## Environment
+
+| Variable | Default | Purpose |
+| --- | --- | --- |
+| `DATABASE_URL` | `postgresql+asyncpg://localhost/samosachaat` | PostgreSQL connection string |
+| `AUTH_SERVICE_URL` | `http://auth:8001` | Base URL of the auth service |
+| `INFERENCE_SERVICE_URL` | `http://inference:8000` | Base URL of the inference service |
+| `INTERNAL_API_KEY` | — | Shared key for internal service auth |
+| `MAX_CONVERSATION_HISTORY` | `50` | Max messages included in each inference call |
+| `MAX_TOKEN_BUDGET` | `6000` | Approximate character budget for the history sent to inference |
+| `FRONTEND_URL` | `http://localhost:3000` | Origin allowed by CORS |
+| `LOG_LEVEL` | `INFO` | Python log level |
+
+## Running locally
+
+```
+uv sync
+uv run uvicorn src.main:app --reload --port 8002
+```
+
+## Running tests
+
+```
+cd services/chat-api
+uv run pytest
+```
+
+Tests use SQLite + aiosqlite for a throwaway database, respx to mock the auth
+service, and hand-crafted httpx mocks for the inference SSE stream.
diff --git a/services/chat-api/pyproject.toml b/services/chat-api/pyproject.toml
new file mode 100644
index 00000000..07df9eb6
--- /dev/null
+++ b/services/chat-api/pyproject.toml
@@ -0,0 +1,38 @@
+[project]
+name = "samosachaat-chat-api"
+version = "0.1.0"
+description = "samosaChaat chat API orchestration service"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "fastapi>=0.117.1",
+    "uvicorn[standard]>=0.36.0",
+    "pydantic>=2.8.0",
+    "pydantic-settings>=2.4.0",
+    "sqlalchemy[asyncio]>=2.0.36",
+    "asyncpg>=0.29.0",
+    "httpx>=0.27.0",
+    "sse-starlette>=2.1.3",
+    "structlog>=24.4.0",
+    "cachetools>=5.5.0",
+]
+
+[dependency-groups]
+dev = [
+    "pytest>=8.0.0",
+    "pytest-asyncio>=0.24.0",
+    "aiosqlite>=0.20.0",
+    "respx>=0.21.1",
+]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+testpaths = ["src/tests"]
+python_files = ["test_*.py"]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src"]
diff --git a/services/chat-api/src/__init__.py b/services/chat-api/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/services/chat-api/src/config.py b/services/chat-api/src/config.py
new file mode 100644
index 00000000..bc3ae51a
--- /dev/null
+++ b/services/chat-api/src/config.py
@@ -0,0 +1,36 @@
+"""Runtime configuration for the chat API service."""
+from __future__ import annotations
+
+from functools import lru_cache
+
+from pydantic import Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+class Settings(BaseSettings):
+    model_config = SettingsConfigDict(env_file=".env", extra="ignore")
+
+    database_url: str = Field(default="postgresql+asyncpg://localhost/samosachaat")
+
+    auth_service_url: str = Field(default="http://auth:8001")
+    inference_service_url: str = Field(default="http://inference:8000")
+    internal_api_key: str = Field(default="")
+
+    max_conversation_history: int = Field(default=50)
+    max_token_budget: int = Field(default=6000)
+
+    auth_cache_ttl_seconds: int = Field(default=300)
+    auth_cache_max_size: int = Field(default=1024)
+
+    inference_default_temperature: float = Field(default=0.8)
+    inference_default_max_tokens: int = Field(default=512)
+    inference_default_top_k: int = Field(default=50)
+
+    frontend_url: str = Field(default="http://localhost:3000")
+
+    log_level: str = Field(default="INFO")
+
+
+@lru_cache(maxsize=1)
+def get_settings() -> Settings:
+    return Settings()
diff --git a/services/chat-api/src/database.py b/services/chat-api/src/database.py
new file mode 100644 index 00000000..7079c352 --- /dev/null +++ b/services/chat-api/src/database.py @@ -0,0 +1,49 @@ +"""Async SQLAlchemy engine and session factory.""" +from __future__ import annotations + +from collections.abc import AsyncIterator + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase + +from .config import get_settings + + +class Base(DeclarativeBase): + """Shared declarative base for all chat-api ORM models.""" + + +_engine = None +_session_factory: async_sessionmaker[AsyncSession] | None = None + + +def _build_engine() -> None: + global _engine, _session_factory + settings = get_settings() + _engine = create_async_engine(settings.database_url, pool_pre_ping=True) + _session_factory = async_sessionmaker(_engine, expire_on_commit=False) + + +def get_engine(): + if _engine is None: + _build_engine() + return _engine + + +def get_session_factory() -> async_sessionmaker[AsyncSession]: + if _session_factory is None: + _build_engine() + assert _session_factory is not None + return _session_factory + + +async def get_session() -> AsyncIterator[AsyncSession]: + factory = get_session_factory() + async with factory() as session: + yield session + + +def override_session_factory(factory: async_sessionmaker[AsyncSession]) -> None: + """Testing hook: swap the session factory for an in-memory engine.""" + global _session_factory + _session_factory = factory diff --git a/services/chat-api/src/logging_setup.py b/services/chat-api/src/logging_setup.py new file mode 100644 index 00000000..888ae5b8 --- /dev/null +++ b/services/chat-api/src/logging_setup.py @@ -0,0 +1,74 @@ +"""Structured JSON logging for the chat API service.""" +from __future__ import annotations + +import logging +import sys +import uuid +from contextvars import ContextVar + +import structlog + +from .config import get_settings + +_trace_id_ctx: ContextVar[str | None] = ContextVar("trace_id", default=None) +_user_id_ctx: ContextVar[str | None] = ContextVar("user_id", default=None) + + +def set_trace_id(trace_id: str | None) -> None: + _trace_id_ctx.set(trace_id) + + +def set_user_id(user_id: str | None) -> None: + _user_id_ctx.set(user_id) + + +def get_trace_id() -> str | None: + return _trace_id_ctx.get() + + +def get_user_id() -> str | None: + return _user_id_ctx.get() + + +def new_trace_id() -> str: + return uuid.uuid4().hex + + +def _inject_context(_logger, _method, event_dict): + event_dict.setdefault("service", "chat-api") + trace_id = _trace_id_ctx.get() + if trace_id is not None: + event_dict.setdefault("trace_id", trace_id) + user_id = _user_id_ctx.get() + if user_id is not None: + event_dict.setdefault("user_id", user_id) + return event_dict + + +def configure_logging() -> None: + settings = get_settings() + level = getattr(logging, settings.log_level.upper(), logging.INFO) + + logging.basicConfig( + format="%(message)s", + stream=sys.stdout, + level=level, + force=True, + ) + + structlog.configure( + processors=[ + structlog.contextvars.merge_contextvars, + structlog.processors.add_log_level, + structlog.processors.TimeStamper(fmt="iso", utc=True), + _inject_context, + structlog.processors.JSONRenderer(), + ], + wrapper_class=structlog.make_filtering_bound_logger(level), + logger_factory=structlog.stdlib.LoggerFactory(), + cache_logger_on_first_use=True, + ) + + +def get_logger(name: str | None = None): + return structlog.get_logger(name) diff --git a/services/chat-api/src/main.py b/services/chat-api/src/main.py new file mode 
100644 index 00000000..bdb2608f --- /dev/null +++ b/services/chat-api/src/main.py @@ -0,0 +1,84 @@ +"""FastAPI entrypoint for the samosaChaat chat API service.""" +from __future__ import annotations + +from contextlib import asynccontextmanager + +import httpx +from fastapi import FastAPI, Request, Response +from fastapi.middleware.cors import CORSMiddleware + +from .config import get_settings +from .logging_setup import ( + configure_logging, + get_logger, + new_trace_id, + set_trace_id, + set_user_id, +) +from .routes import conversations, messages, models + + +@asynccontextmanager +async def lifespan(app: FastAPI): + app.state.auth_http_client = httpx.AsyncClient( + timeout=httpx.Timeout(5.0, connect=2.0) + ) + app.state.inference_http_client = httpx.AsyncClient( + timeout=httpx.Timeout(60.0, connect=5.0) + ) + try: + yield + finally: + await app.state.auth_http_client.aclose() + await app.state.inference_http_client.aclose() + + +def create_app() -> FastAPI: + configure_logging() + settings = get_settings() + logger = get_logger(__name__) + + app = FastAPI(title="samosaChaat Chat API", version="0.1.0", lifespan=lifespan) + + app.add_middleware( + CORSMiddleware, + allow_origins=[settings.frontend_url], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.middleware("http") + async def request_context(request: Request, call_next) -> Response: + incoming = request.headers.get("x-trace-id") or request.headers.get("x-request-id") + trace_id = incoming or new_trace_id() + set_trace_id(trace_id) + set_user_id(None) + + logger.info( + "request_start", + method=request.method, + path=request.url.path, + ) + response = await call_next(request) + response.headers["x-trace-id"] = trace_id + logger.info( + "request_end", + method=request.method, + path=request.url.path, + status_code=response.status_code, + ) + return response + + app.include_router(conversations.router) + app.include_router(messages.router) + app.include_router(models.router) + + @app.get("/api/health") + async def health(): + return {"status": "ok", "ready": True, "service": "chat-api"} + + return app + + +app = create_app() diff --git a/services/chat-api/src/middleware/__init__.py b/services/chat-api/src/middleware/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/chat-api/src/middleware/auth_guard.py b/services/chat-api/src/middleware/auth_guard.py new file mode 100644 index 00000000..849696cb --- /dev/null +++ b/services/chat-api/src/middleware/auth_guard.py @@ -0,0 +1,154 @@ +"""Authentication guard that validates JWTs via the auth service. + +Successful validations are cached in an in-memory TTL cache keyed by the raw +JWT string so that a burst of requests from the same user does not fan out to +the auth service on every call. 
+""" +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import Annotated, Any + +import httpx +from cachetools import TTLCache +from fastapi import Depends, Header, HTTPException, Request, status + +from ..config import Settings, get_settings +from ..logging_setup import get_logger, set_user_id + +logger = get_logger(__name__) + + +@dataclass +class AuthenticatedUser: + id: str + email: str + name: str | None + raw: dict[str, Any] + + @classmethod + def from_validate_response(cls, payload: dict[str, Any]) -> "AuthenticatedUser": + user = payload.get("user") or {} + return cls( + id=str(user["id"]), + email=user.get("email", ""), + name=user.get("name"), + raw=user, + ) + + +class AuthCache: + """Thread-safe TTL cache for validated JWTs.""" + + def __init__(self, ttl_seconds: int, max_size: int) -> None: + self._cache: TTLCache[str, AuthenticatedUser] = TTLCache( + maxsize=max_size, ttl=ttl_seconds + ) + self._lock = asyncio.Lock() + + async def get(self, token: str) -> AuthenticatedUser | None: + async with self._lock: + return self._cache.get(token) + + async def set(self, token: str, user: AuthenticatedUser) -> None: + async with self._lock: + self._cache[token] = user + + async def clear(self) -> None: + async with self._lock: + self._cache.clear() + + +_auth_cache: AuthCache | None = None + + +def get_auth_cache() -> AuthCache: + global _auth_cache + if _auth_cache is None: + settings = get_settings() + _auth_cache = AuthCache( + ttl_seconds=settings.auth_cache_ttl_seconds, + max_size=settings.auth_cache_max_size, + ) + return _auth_cache + + +def reset_auth_cache() -> None: + """Testing hook: drop the cached singleton so a fresh one is built.""" + global _auth_cache + _auth_cache = None + + +async def _validate_with_auth_service( + token: str, settings: Settings, http_client: httpx.AsyncClient | None = None +) -> AuthenticatedUser: + owns_client = http_client is None + client = http_client or httpx.AsyncClient(timeout=5.0) + try: + response = await client.post( + f"{settings.auth_service_url.rstrip('/')}/auth/validate", + headers={"X-Internal-API-Key": settings.internal_api_key}, + json={"token": token}, + ) + except httpx.HTTPError as exc: + logger.error("auth_service_unreachable", error=str(exc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="auth service unreachable", + ) from exc + finally: + if owns_client: + await client.aclose() + + if response.status_code == 401: + raise HTTPException(status_code=401, detail="invalid or expired token") + if response.status_code == 403: + raise HTTPException(status_code=500, detail="internal api key rejected by auth") + if response.status_code >= 400: + logger.error( + "auth_validate_failed", + status_code=response.status_code, + body=response.text[:200], + ) + raise HTTPException(status_code=401, detail="token validation failed") + + data = response.json() + if not data.get("valid"): + raise HTTPException(status_code=401, detail=data.get("reason", "invalid token")) + + return AuthenticatedUser.from_validate_response(data) + + +def _extract_bearer(authorization: str | None) -> str: + if not authorization: + raise HTTPException(status_code=401, detail="missing authorization header") + parts = authorization.split(" ", 1) + if len(parts) != 2 or parts[0].lower() != "bearer" or not parts[1].strip(): + raise HTTPException(status_code=401, detail="invalid authorization scheme") + return parts[1].strip() + + +async def require_user( + request: Request, + 
authorization: Annotated[str | None, Header()] = None, + settings: Annotated[Settings, Depends(get_settings)] = None, # type: ignore[assignment] +) -> AuthenticatedUser: + """FastAPI dependency that yields the authenticated user for the request.""" + token = _extract_bearer(authorization) + cache = get_auth_cache() + + cached = await cache.get(token) + if cached is not None: + set_user_id(cached.id) + request.state.user = cached + return cached + + http_client: httpx.AsyncClient | None = getattr( + request.app.state, "auth_http_client", None + ) + user = await _validate_with_auth_service(token, settings, http_client=http_client) + await cache.set(token, user) + set_user_id(user.id) + request.state.user = user + return user diff --git a/services/chat-api/src/routes/__init__.py b/services/chat-api/src/routes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/chat-api/src/routes/conversations.py b/services/chat-api/src/routes/conversations.py new file mode 100644 index 00000000..b1f62a4e --- /dev/null +++ b/services/chat-api/src/routes/conversations.py @@ -0,0 +1,130 @@ +"""CRUD routes for conversations, scoped to the authenticated user.""" +from __future__ import annotations + +import uuid +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession + +from ..database import get_session +from ..logging_setup import get_logger +from ..middleware.auth_guard import AuthenticatedUser, require_user +from ..services import conversation_service + +logger = get_logger(__name__) + +router = APIRouter(prefix="/api/conversations", tags=["conversations"]) + + +class CreateConversationRequest(BaseModel): + title: str | None = Field(default=None, max_length=500) + model_tag: str | None = Field(default=None, max_length=100) + + +class UpdateConversationRequest(BaseModel): + title: str = Field(..., min_length=1, max_length=500) + + +def _parse_uuid(raw: str) -> uuid.UUID: + try: + return uuid.UUID(raw) + except (ValueError, TypeError) as exc: + raise HTTPException(status_code=404, detail="conversation not found") from exc + + +@router.get("") +async def list_conversations( + user: Annotated[AuthenticatedUser, Depends(require_user)], + session: Annotated[AsyncSession, Depends(get_session)], + limit: int = Query(default=50, ge=1, le=200), + offset: int = Query(default=0, ge=0), +): + user_uuid = uuid.UUID(user.id) + conversations = await conversation_service.list_conversations( + session, user_id=user_uuid, limit=limit, offset=offset + ) + grouped = conversation_service.group_by_date(conversations) + return { + "items": [c.to_dict() for c in conversations], + "grouped": grouped, + "limit": limit, + "offset": offset, + } + + +@router.post("", status_code=status.HTTP_201_CREATED) +async def create_conversation( + body: CreateConversationRequest, + user: Annotated[AuthenticatedUser, Depends(require_user)], + session: Annotated[AsyncSession, Depends(get_session)], +): + user_uuid = uuid.UUID(user.id) + convo = await conversation_service.create_conversation( + session, + user_id=user_uuid, + title=body.title, + model_tag=body.model_tag, + ) + logger.info("conversation_created", conversation_id=str(convo.id)) + return convo.to_dict() + + +@router.get("/{conversation_id}") +async def get_conversation( + conversation_id: str, + user: Annotated[AuthenticatedUser, Depends(require_user)], + session: Annotated[AsyncSession, Depends(get_session)], +): + conv_uuid = 
_parse_uuid(conversation_id) + user_uuid = uuid.UUID(user.id) + convo = await conversation_service.get_user_conversation( + session, user_id=user_uuid, conversation_id=conv_uuid + ) + if convo is None: + raise HTTPException(status_code=404, detail="conversation not found") + + messages = await conversation_service.get_conversation_messages( + session, conversation_id=conv_uuid + ) + payload = convo.to_dict() + payload["messages"] = [m.to_dict() for m in messages] + return payload + + +@router.put("/{conversation_id}") +async def update_conversation( + conversation_id: str, + body: UpdateConversationRequest, + user: Annotated[AuthenticatedUser, Depends(require_user)], + session: Annotated[AsyncSession, Depends(get_session)], +): + conv_uuid = _parse_uuid(conversation_id) + user_uuid = uuid.UUID(user.id) + convo = await conversation_service.update_conversation_title( + session, + user_id=user_uuid, + conversation_id=conv_uuid, + title=body.title, + ) + if convo is None: + raise HTTPException(status_code=404, detail="conversation not found") + return convo.to_dict() + + +@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_conversation( + conversation_id: str, + user: Annotated[AuthenticatedUser, Depends(require_user)], + session: Annotated[AsyncSession, Depends(get_session)], +): + conv_uuid = _parse_uuid(conversation_id) + user_uuid = uuid.UUID(user.id) + deleted = await conversation_service.delete_conversation( + session, user_id=user_uuid, conversation_id=conv_uuid + ) + if not deleted: + raise HTTPException(status_code=404, detail="conversation not found") + logger.info("conversation_deleted", conversation_id=str(conv_uuid)) + return None diff --git a/services/chat-api/src/routes/messages.py b/services/chat-api/src/routes/messages.py new file mode 100644 index 00000000..088d9d68 --- /dev/null +++ b/services/chat-api/src/routes/messages.py @@ -0,0 +1,257 @@ +"""The main chat route: send a message and stream an assistant response.""" +from __future__ import annotations + +import uuid +from typing import Annotated, AsyncIterator + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sse_starlette.sse import EventSourceResponse + +from ..config import Settings, get_settings +from ..database import get_session_factory +from ..logging_setup import get_logger +from ..middleware.auth_guard import AuthenticatedUser, require_user +from ..services import conversation_service +from ..services.inference_client import InferenceClient +from ..services.stream_proxy import StreamResult, proxy_inference_stream + +logger = get_logger(__name__) + +router = APIRouter(prefix="/api/conversations", tags=["messages"]) + + +class SendMessageRequest(BaseModel): + content: str = Field(..., min_length=1, max_length=8000) + temperature: float | None = Field(default=None, ge=0.0, le=2.0) + max_tokens: int | None = Field(default=None, ge=1, le=4096) + top_k: int | None = Field(default=None, ge=0, le=200) + + +class RegenerateRequest(BaseModel): + temperature: float | None = Field(default=None, ge=0.0, le=2.0) + max_tokens: int | None = Field(default=None, ge=1, le=4096) + top_k: int | None = Field(default=None, ge=0, le=200) + + +def _parse_uuid(raw: str) -> uuid.UUID: + try: + return uuid.UUID(raw) + except (ValueError, TypeError) as exc: + raise HTTPException(status_code=404, detail="conversation not found") from exc + + +async def 
_ensure_ownership( + session: AsyncSession, + *, + user_id: uuid.UUID, + conversation_id: uuid.UUID, +): + convo = await conversation_service.get_user_conversation( + session, user_id=user_id, conversation_id=conversation_id + ) + if convo is None: + raise HTTPException(status_code=404, detail="conversation not found") + return convo + + +async def _stream_and_persist( + *, + request: Request, + user_id: uuid.UUID, + conversation_id: uuid.UUID, + history: list[dict[str, str]], + temperature: float | None, + max_tokens: int | None, + top_k: int | None, + model_tag: str, + first_message: bool, + first_message_preview: str | None, + settings: Settings, +) -> AsyncIterator[dict]: + """Generator that streams inference SSE events to the client and, after the + stream closes, persists the full assistant message in a fresh DB session. + """ + http_client = getattr(request.app.state, "inference_http_client", None) + inference = InferenceClient(settings=settings, http_client=http_client) + + accumulated: dict[str, StreamResult] = {} + + def _capture(result: StreamResult) -> None: + accumulated["result"] = result + + try: + async with inference.stream_generate( + messages=history, + temperature=temperature, + max_tokens=max_tokens, + top_k=top_k, + ) as response: + async for event in proxy_inference_stream(response, on_complete=_capture): + yield event + except Exception as exc: # pragma: no cover - defensive path + logger.error( + "inference_stream_failed", + conversation_id=str(conversation_id), + error=str(exc), + ) + yield {"data": '{"error":"inference_stream_failed"}'} + yield {"data": '{"done":true}'} + return + + result = accumulated.get("result") + if result is None or not result.completed or not result.content: + logger.warning( + "assistant_message_not_persisted", + conversation_id=str(conversation_id), + completed=bool(result and result.completed), + content_len=len(result.content) if result else 0, + ) + return + + factory: async_sessionmaker[AsyncSession] = get_session_factory() + try: + async with factory() as persist_session: + await conversation_service.append_message( + persist_session, + conversation_id=conversation_id, + role="assistant", + content=result.content, + token_count=result.token_count, + model_tag=model_tag, + inference_time_ms=result.inference_time_ms, + ) + if first_message and first_message_preview is not None: + await conversation_service.update_conversation_title( + persist_session, + user_id=user_id, + conversation_id=conversation_id, + title=first_message_preview, + ) + logger.info( + "assistant_message_persisted", + conversation_id=str(conversation_id), + token_count=result.token_count, + inference_time_ms=result.inference_time_ms, + ) + except Exception as exc: # pragma: no cover - defensive path + logger.error( + "assistant_message_persist_failed", + conversation_id=str(conversation_id), + error=str(exc), + ) + + +@router.post("/{conversation_id}/messages") +async def send_message( + conversation_id: str, + body: SendMessageRequest, + request: Request, + user: Annotated[AuthenticatedUser, Depends(require_user)], + settings: Annotated[Settings, Depends(get_settings)] = None, # type: ignore[assignment] +): + # We own our DB sessions explicitly here so the background task that + # persists the streamed assistant response can open its own session once + # the request scope has already closed. 
+ conv_uuid = _parse_uuid(conversation_id) + user_uuid = uuid.UUID(user.id) + session_factory = get_session_factory() + + async with session_factory() as db_session: + convo = await _ensure_ownership( + db_session, user_id=user_uuid, conversation_id=conv_uuid + ) + model_tag = convo.model_tag or "default" + + existing = await conversation_service.get_conversation_messages( + db_session, conversation_id=conv_uuid, limit=1 + ) + first_message = len(existing) == 0 + first_preview = body.content[:80] if first_message else None + + await conversation_service.append_message( + db_session, + conversation_id=conv_uuid, + role="user", + content=body.content, + token_count=None, + model_tag=model_tag, + ) + history = await conversation_service.build_history_for_inference( + db_session, conversation_id=conv_uuid + ) + + logger.info( + "send_message", + conversation_id=str(conv_uuid), + history_len=len(history), + model_tag=model_tag, + ) + + generator = _stream_and_persist( + request=request, + user_id=user_uuid, + conversation_id=conv_uuid, + history=history, + temperature=body.temperature, + max_tokens=body.max_tokens, + top_k=body.top_k, + model_tag=model_tag, + first_message=first_message, + first_message_preview=first_preview, + settings=settings, + ) + return EventSourceResponse(generator, media_type="text/event-stream") + + +@router.post("/{conversation_id}/regenerate") +async def regenerate( + conversation_id: str, + body: RegenerateRequest, + request: Request, + user: Annotated[AuthenticatedUser, Depends(require_user)], + settings: Annotated[Settings, Depends(get_settings)] = None, # type: ignore[assignment] +): + conv_uuid = _parse_uuid(conversation_id) + user_uuid = uuid.UUID(user.id) + session_factory = get_session_factory() + + async with session_factory() as db_session: + convo = await _ensure_ownership( + db_session, user_id=user_uuid, conversation_id=conv_uuid + ) + model_tag = convo.model_tag or "default" + await conversation_service.delete_last_assistant_message( + db_session, conversation_id=conv_uuid + ) + history = await conversation_service.build_history_for_inference( + db_session, conversation_id=conv_uuid + ) + + if not history: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="conversation has no user messages to regenerate from", + ) + + logger.info( + "regenerate_message", + conversation_id=str(conv_uuid), + history_len=len(history), + ) + + generator = _stream_and_persist( + request=request, + user_id=user_uuid, + conversation_id=conv_uuid, + history=history, + temperature=body.temperature, + max_tokens=body.max_tokens, + top_k=body.top_k, + model_tag=model_tag, + first_message=False, + first_message_preview=None, + settings=settings, + ) + return EventSourceResponse(generator, media_type="text/event-stream") diff --git a/services/chat-api/src/routes/models.py b/services/chat-api/src/routes/models.py new file mode 100644 index 00000000..9093476e --- /dev/null +++ b/services/chat-api/src/routes/models.py @@ -0,0 +1,74 @@ +"""Proxy routes that forward model management calls to the inference service.""" +from __future__ import annotations + +from typing import Annotated + +import httpx +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic import BaseModel, Field + +from ..config import Settings, get_settings +from ..logging_setup import get_logger +from ..middleware.auth_guard import AuthenticatedUser, require_user +from ..services.inference_client import InferenceClient + +logger = get_logger(__name__) + +router = 
APIRouter(prefix="/api/models", tags=["models"]) + + +class SwapModelRequest(BaseModel): + model_tag: str = Field(..., min_length=1, max_length=100) + + +def _client_for(request: Request, settings: Settings) -> InferenceClient: + http_client = getattr(request.app.state, "inference_http_client", None) + return InferenceClient(settings=settings, http_client=http_client) + + +@router.get("") +async def list_models( + request: Request, + user: Annotated[AuthenticatedUser, Depends(require_user)], + settings: Annotated[Settings, Depends(get_settings)] = None, # type: ignore[assignment] +): + client = _client_for(request, settings) + try: + return await client.list_models() + except httpx.HTTPStatusError as exc: + logger.error("list_models_proxy_failed", status_code=exc.response.status_code) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="inference service error", + ) from exc + except httpx.HTTPError as exc: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="inference service unreachable", + ) from exc + + +@router.post("/swap") +async def swap_model( + body: SwapModelRequest, + request: Request, + user: Annotated[AuthenticatedUser, Depends(require_user)], + settings: Annotated[Settings, Depends(get_settings)] = None, # type: ignore[assignment] +): + if not user.raw.get("is_admin"): + raise HTTPException(status_code=403, detail="admin privilege required") + + client = _client_for(request, settings) + try: + return await client.swap_model(body.model_tag) + except httpx.HTTPStatusError as exc: + logger.error("swap_model_proxy_failed", status_code=exc.response.status_code) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="inference service rejected swap", + ) from exc + except httpx.HTTPError as exc: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="inference service unreachable", + ) from exc diff --git a/services/chat-api/src/services/__init__.py b/services/chat-api/src/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/chat-api/src/services/conversation_service.py b/services/chat-api/src/services/conversation_service.py new file mode 100644 index 00000000..0a19b231 --- /dev/null +++ b/services/chat-api/src/services/conversation_service.py @@ -0,0 +1,209 @@ +"""Business logic for conversations and messages, scoped to a single user. + +Every query in this module filters by `user_id` — that scoping is the only +thing preventing one user from reading or mutating another user's data. 
+""" +from __future__ import annotations + +import uuid +from collections import defaultdict +from datetime import date +from typing import Iterable + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncSession + +from ..config import get_settings +from ..models import Conversation, Message + + +async def create_conversation( + session: AsyncSession, + *, + user_id: uuid.UUID, + title: str | None = None, + model_tag: str | None = None, +) -> Conversation: + conversation = Conversation( + user_id=user_id, + title=title or "New conversation", + model_tag=model_tag or "default", + ) + session.add(conversation) + await session.commit() + await session.refresh(conversation) + return conversation + + +async def list_conversations( + session: AsyncSession, + *, + user_id: uuid.UUID, + limit: int = 50, + offset: int = 0, +) -> list[Conversation]: + stmt = ( + sa.select(Conversation) + .where(Conversation.user_id == user_id) + .order_by(Conversation.updated_at.desc()) + .limit(limit) + .offset(offset) + ) + result = await session.execute(stmt) + return list(result.scalars().all()) + + +def group_by_date(conversations: Iterable[Conversation]) -> dict[str, list[dict]]: + """Group conversations by YYYY-MM-DD (UTC) of `updated_at`.""" + buckets: dict[str, list[dict]] = defaultdict(list) + for convo in conversations: + bucket_key: str + if convo.updated_at is None: + bucket_key = date.today().isoformat() + else: + bucket_key = convo.updated_at.date().isoformat() + buckets[bucket_key].append(convo.to_dict()) + return dict(buckets) + + +async def get_user_conversation( + session: AsyncSession, + *, + user_id: uuid.UUID, + conversation_id: uuid.UUID, +) -> Conversation | None: + stmt = sa.select(Conversation).where( + Conversation.id == conversation_id, + Conversation.user_id == user_id, + ) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + +async def get_conversation_messages( + session: AsyncSession, + *, + conversation_id: uuid.UUID, + limit: int | None = None, +) -> list[Message]: + stmt = ( + sa.select(Message) + .where(Message.conversation_id == conversation_id) + .order_by(Message.created_at.asc(), Message.id.asc()) + ) + if limit is not None: + stmt = stmt.limit(limit) + result = await session.execute(stmt) + return list(result.scalars().all()) + + +async def update_conversation_title( + session: AsyncSession, + *, + user_id: uuid.UUID, + conversation_id: uuid.UUID, + title: str, +) -> Conversation | None: + convo = await get_user_conversation( + session, user_id=user_id, conversation_id=conversation_id + ) + if convo is None: + return None + convo.title = title + await session.commit() + await session.refresh(convo) + return convo + + +async def delete_conversation( + session: AsyncSession, + *, + user_id: uuid.UUID, + conversation_id: uuid.UUID, +) -> bool: + convo = await get_user_conversation( + session, user_id=user_id, conversation_id=conversation_id + ) + if convo is None: + return False + await session.delete(convo) + await session.commit() + return True + + +async def append_message( + session: AsyncSession, + *, + conversation_id: uuid.UUID, + role: str, + content: str, + token_count: int | None = None, + model_tag: str | None = None, + inference_time_ms: int | None = None, +) -> Message: + message = Message( + conversation_id=conversation_id, + role=role, + content=content, + token_count=token_count, + model_tag=model_tag, + inference_time_ms=inference_time_ms, + ) + session.add(message) + await session.commit() + await session.refresh(message) 
+ return message + + +async def build_history_for_inference( + session: AsyncSession, + *, + conversation_id: uuid.UUID, +) -> list[dict[str, str]]: + """Return the trailing slice of history that fits the configured budgets.""" + settings = get_settings() + max_history = settings.max_conversation_history + max_budget = settings.max_token_budget + + stmt = ( + sa.select(Message) + .where(Message.conversation_id == conversation_id) + .order_by(Message.created_at.desc(), Message.id.desc()) + .limit(max_history) + ) + result = await session.execute(stmt) + newest_first = list(result.scalars().all()) + + selected: list[Message] = [] + budget = 0 + for message in newest_first: + budget += len(message.content or "") + if budget > max_budget and selected: + break + selected.append(message) + + selected.reverse() + return [{"role": m.role, "content": m.content} for m in selected] + + +async def delete_last_assistant_message( + session: AsyncSession, + *, + conversation_id: uuid.UUID, +) -> bool: + stmt = ( + sa.select(Message) + .where( + Message.conversation_id == conversation_id, + Message.role == "assistant", + ) + .order_by(Message.created_at.desc(), Message.id.desc()) + .limit(1) + ) + result = await session.execute(stmt) + last = result.scalar_one_or_none() + if last is None: + return False + await session.delete(last) + await session.commit() + return True diff --git a/services/chat-api/src/services/inference_client.py b/services/chat-api/src/services/inference_client.py new file mode 100644 index 00000000..235c41c1 --- /dev/null +++ b/services/chat-api/src/services/inference_client.py @@ -0,0 +1,93 @@ +"""HTTP client wrapper for the inference service.""" +from __future__ import annotations + +from contextlib import asynccontextmanager +from typing import AsyncIterator + +import httpx + +from ..config import Settings, get_settings + + +class InferenceClient: + """Thin async wrapper around the inference service HTTP contract.""" + + def __init__( + self, + settings: Settings | None = None, + http_client: httpx.AsyncClient | None = None, + ) -> None: + self._settings = settings or get_settings() + self._client = http_client + self._owns_client = http_client is None + + @property + def base_url(self) -> str: + return self._settings.inference_service_url.rstrip("/") + + @property + def headers(self) -> dict[str, str]: + return {"X-Internal-API-Key": self._settings.internal_api_key} + + def _get_client(self) -> httpx.AsyncClient: + if self._client is None: + self._client = httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=5.0)) + return self._client + + async def aclose(self) -> None: + if self._client is not None and self._owns_client: + await self._client.aclose() + self._client = None + + async def list_models(self) -> dict: + client = self._get_client() + resp = await client.get(f"{self.base_url}/models", headers=self.headers) + resp.raise_for_status() + return resp.json() + + async def swap_model(self, model_tag: str) -> dict: + client = self._get_client() + resp = await client.post( + f"{self.base_url}/models/swap", + headers=self.headers, + json={"model_tag": model_tag}, + ) + resp.raise_for_status() + return resp.json() + + @asynccontextmanager + async def stream_generate( + self, + *, + messages: list[dict[str, str]], + temperature: float | None = None, + max_tokens: int | None = None, + top_k: int | None = None, + ) -> AsyncIterator[httpx.Response]: + temperature = ( + temperature + if temperature is not None + else self._settings.inference_default_temperature + ) + max_tokens = ( + 
max_tokens + if max_tokens is not None + else self._settings.inference_default_max_tokens + ) + top_k = top_k if top_k is not None else self._settings.inference_default_top_k + + payload = { + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_k": top_k, + } + + client = self._get_client() + async with client.stream( + "POST", + f"{self.base_url}/generate", + headers=self.headers, + json=payload, + ) as response: + yield response diff --git a/services/chat-api/src/services/stream_proxy.py b/services/chat-api/src/services/stream_proxy.py new file mode 100644 index 00000000..c979778f --- /dev/null +++ b/services/chat-api/src/services/stream_proxy.py @@ -0,0 +1,114 @@ +"""Proxy the inference SSE stream to the client while accumulating tokens. + +The inference service emits lines like `data: {"token": "...", "gpu": 0}` and +terminates with `data: {"done": true}`. We forward each event unchanged to the +client, collect assistant tokens into a buffer, and report the buffer plus +timing info once the stream closes. +""" +from __future__ import annotations + +import json +import time +from dataclasses import dataclass +from typing import AsyncIterator, Callable + +import httpx + +from ..logging_setup import get_logger + +logger = get_logger(__name__) + + +@dataclass +class StreamResult: + content: str + token_count: int + inference_time_ms: int + completed: bool + + +def _parse_sse_data(raw_line: str) -> dict | None: + line = raw_line.strip() + if not line or not line.startswith("data:"): + return None + body = line[len("data:"):].strip() + if not body: + return None + try: + return json.loads(body) + except json.JSONDecodeError: + logger.warning("stream_proxy_bad_sse_payload", payload=body[:200]) + return None + + +async def proxy_inference_stream( + response: httpx.Response, + *, + on_complete: Callable[[StreamResult], None] | None = None, +) -> AsyncIterator[dict]: + """Async generator that yields SSE events as dicts with ``data`` keys. + + Each yielded event is shaped as ``{"data": ""}`` so callers can + pass it straight into sse-starlette's ``EventSourceResponse``. 
+ """ + started = time.perf_counter() + buffer: list[str] = [] + token_count = 0 + completed = False + + if response.status_code >= 400: + body = await response.aread() + logger.error( + "inference_error_response", + status_code=response.status_code, + body=body.decode("utf-8", errors="replace")[:200], + ) + error_payload = json.dumps( + { + "error": "inference service returned an error", + "status_code": response.status_code, + } + ) + yield {"data": error_payload} + yield {"data": json.dumps({"done": True})} + if on_complete is not None: + on_complete( + StreamResult( + content="", + token_count=0, + inference_time_ms=int((time.perf_counter() - started) * 1000), + completed=False, + ) + ) + return + + try: + async for raw_line in response.aiter_lines(): + if not raw_line: + continue + parsed = _parse_sse_data(raw_line) + if parsed is None: + continue + + if parsed.get("done"): + completed = True + yield {"data": json.dumps(parsed)} + break + + token = parsed.get("token") + if isinstance(token, str) and token: + buffer.append(token) + token_count += 1 + + yield {"data": json.dumps(parsed)} + finally: + elapsed_ms = int((time.perf_counter() - started) * 1000) + if on_complete is not None: + on_complete( + StreamResult( + content="".join(buffer), + token_count=token_count, + inference_time_ms=elapsed_ms, + completed=completed, + ) + ) diff --git a/services/chat-api/src/tests/__init__.py b/services/chat-api/src/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/chat-api/src/tests/conftest.py b/services/chat-api/src/tests/conftest.py new file mode 100644 index 00000000..338073c4 --- /dev/null +++ b/services/chat-api/src/tests/conftest.py @@ -0,0 +1,160 @@ +"""Shared pytest fixtures for the chat API test suite. + +Tests run against a throwaway SQLite database (via aiosqlite) with a stub +`users` table so that the foreign key from `conversations.user_id` validates. +The auth service is mocked with respx; the inference service SSE responses +are mocked with hand-rolled httpx MockTransports. +""" +from __future__ import annotations + +import uuid +from typing import AsyncIterator + +import pytest +import pytest_asyncio +import sqlalchemy as sa +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from src import database as db_module +from src.database import Base +from src.middleware import auth_guard +from src.models import Conversation, Message # noqa: F401 (registers tables) + + +# Attach a stub `users` table to the shared metadata so the FK on +# `conversations.user_id` can resolve during SQLite-based test setup. +if "users" not in Base.metadata.tables: + sa.Table( + "users", + Base.metadata, + sa.Column("id", sa.Uuid(as_uuid=True), primary_key=True), + sa.Column("email", sa.String(255)), + sa.Column("name", sa.String(255)), + sa.Column("is_admin", sa.Integer(), default=0), + ) + + +@pytest.fixture(autouse=True) +def _reset_caches(monkeypatch): + # Settings and the auth cache are module-level singletons; wipe them + # between tests so environment overrides take effect cleanly. 
+ from src.config import get_settings + + get_settings.cache_clear() + auth_guard.reset_auth_cache() + _TOKEN_REGISTRY.clear() + yield + get_settings.cache_clear() + auth_guard.reset_auth_cache() + _TOKEN_REGISTRY.clear() + + +@pytest.fixture(autouse=True) +def _env(monkeypatch): + monkeypatch.setenv("DATABASE_URL", "sqlite+aiosqlite:///:memory:") + monkeypatch.setenv("AUTH_SERVICE_URL", "http://auth.test") + monkeypatch.setenv("INFERENCE_SERVICE_URL", "http://inference.test") + monkeypatch.setenv("INTERNAL_API_KEY", "test-internal-key") + monkeypatch.setenv("LOG_LEVEL", "WARNING") + from src.config import get_settings + + get_settings.cache_clear() + + +@pytest_asyncio.fixture +async def engine(_env) -> AsyncIterator: + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def session_factory(engine) -> async_sessionmaker[AsyncSession]: + factory = async_sessionmaker(engine, expire_on_commit=False) + db_module.override_session_factory(factory) + return factory + + +@pytest_asyncio.fixture +async def seeded_user(session_factory) -> dict: + user_id = str(uuid.uuid4()) + async with session_factory() as session: + await session.execute( + sa.text( + "INSERT INTO users (id, email, name, is_admin) " + "VALUES (:id, :email, :name, :is_admin)" + ), + { + "id": user_id, + "email": "alice@example.com", + "name": "Alice", + "is_admin": 0, + }, + ) + await session.commit() + return {"id": user_id, "email": "alice@example.com", "name": "Alice"} + + +@pytest_asyncio.fixture +async def other_user(session_factory) -> dict: + user_id = str(uuid.uuid4()) + async with session_factory() as session: + await session.execute( + sa.text( + "INSERT INTO users (id, email, name, is_admin) " + "VALUES (:id, :email, :name, :is_admin)" + ), + {"id": user_id, "email": "bob@example.com", "name": "Bob", "is_admin": 0}, + ) + await session.commit() + return {"id": user_id, "email": "bob@example.com", "name": "Bob"} + + +@pytest.fixture +def app(): + from src.main import create_app + + return create_app() + + +@pytest_asyncio.fixture +async def client(app, session_factory) -> AsyncIterator[AsyncClient]: + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + + +_TOKEN_REGISTRY: dict[str, dict] = {} + + +def _dispatch_auth_validate(request): + import json + + import httpx + + payload = json.loads(request.read()) + token = payload.get("token") + if request.headers.get("X-Internal-API-Key") != "test-internal-key": + return httpx.Response(403, json={"detail": "invalid internal api key"}) + user = _TOKEN_REGISTRY.get(token) + if user is None: + return httpx.Response(401, json={"valid": False, "reason": "invalid token"}) + return httpx.Response( + 200, + json={"valid": True, "user": user, "claims": {"sub": user["id"]}}, + ) + + +def stub_auth_validate(respx_mock, user: dict, token: str = "valid-token"): + """Register a respx mock that returns the given user for the given token. + + Multiple calls within a single test accumulate token→user mappings so + several users can authenticate in the same scenario. 
+ """ + _TOKEN_REGISTRY[token] = user + respx_mock.post("http://auth.test/auth/validate").mock( + side_effect=_dispatch_auth_validate + ) diff --git a/services/chat-api/src/tests/test_conversations.py b/services/chat-api/src/tests/test_conversations.py new file mode 100644 index 00000000..e6f175ac --- /dev/null +++ b/services/chat-api/src/tests/test_conversations.py @@ -0,0 +1,110 @@ +import pytest +import respx + +from .conftest import stub_auth_validate + + +@pytest.mark.asyncio +@respx.mock +async def test_requires_authorization_header(client): + response = await client.get("/api/conversations") + assert response.status_code == 401 + + +@pytest.mark.asyncio +@respx.mock +async def test_create_and_list_conversation(client, seeded_user): + stub_auth_validate(respx.mock, seeded_user) + headers = {"Authorization": "Bearer valid-token"} + + create = await client.post( + "/api/conversations", json={"title": "my first chat"}, headers=headers + ) + assert create.status_code == 201 + convo = create.json() + assert convo["title"] == "my first chat" + assert convo["user_id"] == seeded_user["id"] + + listed = await client.get("/api/conversations", headers=headers) + assert listed.status_code == 200 + payload = listed.json() + assert len(payload["items"]) == 1 + assert payload["items"][0]["id"] == convo["id"] + assert payload["grouped"] # at least one date bucket + + +@pytest.mark.asyncio +@respx.mock +async def test_update_conversation_title(client, seeded_user): + stub_auth_validate(respx.mock, seeded_user) + headers = {"Authorization": "Bearer valid-token"} + + create = await client.post("/api/conversations", json={}, headers=headers) + convo_id = create.json()["id"] + + updated = await client.put( + f"/api/conversations/{convo_id}", + json={"title": "Renamed"}, + headers=headers, + ) + assert updated.status_code == 200 + assert updated.json()["title"] == "Renamed" + + +@pytest.mark.asyncio +@respx.mock +async def test_delete_conversation(client, seeded_user): + stub_auth_validate(respx.mock, seeded_user) + headers = {"Authorization": "Bearer valid-token"} + + create = await client.post("/api/conversations", json={}, headers=headers) + convo_id = create.json()["id"] + + deleted = await client.delete(f"/api/conversations/{convo_id}", headers=headers) + assert deleted.status_code == 204 + + missing = await client.get(f"/api/conversations/{convo_id}", headers=headers) + assert missing.status_code == 404 + + +@pytest.mark.asyncio +@respx.mock +async def test_user_scoping_prevents_cross_user_access( + client, seeded_user, other_user +): + stub_auth_validate(respx.mock, seeded_user, token="alice-token") + stub_auth_validate(respx.mock, other_user, token="bob-token") + + alice_headers = {"Authorization": "Bearer alice-token"} + bob_headers = {"Authorization": "Bearer bob-token"} + + alice_convo = await client.post( + "/api/conversations", + json={"title": "alice only"}, + headers=alice_headers, + ) + convo_id = alice_convo.json()["id"] + + bob_view = await client.get( + f"/api/conversations/{convo_id}", headers=bob_headers + ) + assert bob_view.status_code == 404 + + bob_delete = await client.delete( + f"/api/conversations/{convo_id}", headers=bob_headers + ) + assert bob_delete.status_code == 404 + + bob_list = await client.get("/api/conversations", headers=bob_headers) + assert bob_list.json()["items"] == [] + + +@pytest.mark.asyncio +@respx.mock +async def test_invalid_token_is_rejected(client, seeded_user): + stub_auth_validate(respx.mock, seeded_user, token="valid-token") + response = await 
client.get( + "/api/conversations", + headers={"Authorization": "Bearer wrong-token"}, + ) + assert response.status_code == 401 diff --git a/services/chat-api/src/tests/test_health.py b/services/chat-api/src/tests/test_health.py new file mode 100644 index 00000000..13803b29 --- /dev/null +++ b/services/chat-api/src/tests/test_health.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.mark.asyncio +async def test_health_is_unauthenticated(client): + response = await client.get("/api/health") + assert response.status_code == 200 + body = response.json() + assert body["status"] == "ok" + assert body["ready"] is True + assert body["service"] == "chat-api" diff --git a/services/chat-api/src/tests/test_messages.py b/services/chat-api/src/tests/test_messages.py new file mode 100644 index 00000000..17fbda3d --- /dev/null +++ b/services/chat-api/src/tests/test_messages.py @@ -0,0 +1,166 @@ +"""Tests for the streaming message send + regenerate flows.""" +from __future__ import annotations + +import json +import uuid + +import httpx +import pytest +import respx + +from .conftest import stub_auth_validate + + +def _build_inference_mock(tokens: list[str]) -> httpx.MockTransport: + """Build an httpx mock transport that streams an SSE response.""" + sse_lines: list[bytes] = [] + for token in tokens: + sse_lines.append( + f"data: {json.dumps({'token': token, 'gpu': 0})}\n\n".encode("utf-8") + ) + sse_lines.append(f"data: {json.dumps({'done': True})}\n\n".encode("utf-8")) + + async def handler(request: httpx.Request) -> httpx.Response: + if request.url.path != "/generate": + return httpx.Response(404) + + async def body(): + for chunk in sse_lines: + yield chunk + + return httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, + content=body(), + ) + + return httpx.MockTransport(handler) + + +@pytest.mark.asyncio +@respx.mock +async def test_send_message_streams_and_persists(app, client, seeded_user): + stub_auth_validate(respx.mock, seeded_user) + headers = {"Authorization": "Bearer valid-token"} + + create = await client.post("/api/conversations", json={}, headers=headers) + convo_id = create.json()["id"] + + app.state.inference_http_client = httpx.AsyncClient( + transport=_build_inference_mock(["hel", "lo", " world"]) + ) + try: + resp = await client.post( + f"/api/conversations/{convo_id}/messages", + json={"content": "hi there"}, + headers=headers, + ) + assert resp.status_code == 200 + body = resp.text + assert '"token": "hel"' in body or '"token":"hel"' in body + assert '"done": true' in body or '"done":true' in body + finally: + await app.state.inference_http_client.aclose() + + fetched = await client.get( + f"/api/conversations/{convo_id}", headers=headers + ) + assert fetched.status_code == 200 + payload = fetched.json() + messages = payload["messages"] + + roles = [m["role"] for m in messages] + assert roles == ["user", "assistant"] + assert messages[0]["content"] == "hi there" + assert messages[1]["content"] == "hello world" + assert messages[1]["token_count"] == 3 + assert messages[1]["inference_time_ms"] >= 0 + + # First message should have auto-populated the title + assert payload["title"] == "hi there" + + +@pytest.mark.asyncio +@respx.mock +async def test_send_message_rejected_on_foreign_conversation( + app, client, seeded_user, other_user +): + stub_auth_validate(respx.mock, seeded_user, token="alice-token") + stub_auth_validate(respx.mock, other_user, token="bob-token") + + alice_headers = {"Authorization": "Bearer alice-token"} + bob_headers = {"Authorization": "Bearer 
bob-token"} + + create = await client.post("/api/conversations", json={}, headers=alice_headers) + convo_id = create.json()["id"] + + app.state.inference_http_client = httpx.AsyncClient( + transport=_build_inference_mock(["x"]) + ) + try: + resp = await client.post( + f"/api/conversations/{convo_id}/messages", + json={"content": "steal me"}, + headers=bob_headers, + ) + finally: + await app.state.inference_http_client.aclose() + assert resp.status_code == 404 + + +@pytest.mark.asyncio +@respx.mock +async def test_send_message_returns_404_for_missing_conversation( + app, client, seeded_user +): + stub_auth_validate(respx.mock, seeded_user) + headers = {"Authorization": "Bearer valid-token"} + + missing_id = str(uuid.uuid4()) + resp = await client.post( + f"/api/conversations/{missing_id}/messages", + json={"content": "hello"}, + headers=headers, + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +@respx.mock +async def test_regenerate_drops_last_assistant_message(app, client, seeded_user): + stub_auth_validate(respx.mock, seeded_user) + headers = {"Authorization": "Bearer valid-token"} + + create = await client.post("/api/conversations", json={}, headers=headers) + convo_id = create.json()["id"] + + app.state.inference_http_client = httpx.AsyncClient( + transport=_build_inference_mock(["first"]) + ) + try: + first = await client.post( + f"/api/conversations/{convo_id}/messages", + json={"content": "hi"}, + headers=headers, + ) + assert first.status_code == 200 + + app.state.inference_http_client = httpx.AsyncClient( + transport=_build_inference_mock(["second", " reply"]) + ) + regen = await client.post( + f"/api/conversations/{convo_id}/regenerate", + json={}, + headers=headers, + ) + assert regen.status_code == 200 + finally: + await app.state.inference_http_client.aclose() + + fetched = await client.get( + f"/api/conversations/{convo_id}", headers=headers + ) + messages = fetched.json()["messages"] + assistant_messages = [m for m in messages if m["role"] == "assistant"] + assert len(assistant_messages) == 1 + assert assistant_messages[0]["content"] == "second reply" diff --git a/services/chat-api/src/tests/test_models_proxy.py b/services/chat-api/src/tests/test_models_proxy.py new file mode 100644 index 00000000..c8ad7f1b --- /dev/null +++ b/services/chat-api/src/tests/test_models_proxy.py @@ -0,0 +1,95 @@ +"""Tests for /api/models proxy routes.""" +from __future__ import annotations + +import uuid + +import httpx +import pytest +import respx +import sqlalchemy as sa + +from .conftest import stub_auth_validate + + +def _inference_mock(models_response: dict) -> httpx.MockTransport: + async def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/models": + return httpx.Response(200, json=models_response) + if request.url.path == "/models/swap": + return httpx.Response(200, json={"status": "ok", "current_model": "new-model"}) + return httpx.Response(404) + + return httpx.MockTransport(handler) + + +@pytest.mark.asyncio +@respx.mock +async def test_list_models_proxies_to_inference(app, client, seeded_user): + stub_auth_validate(respx.mock, seeded_user) + headers = {"Authorization": "Bearer valid-token"} + + app.state.inference_http_client = httpx.AsyncClient( + transport=_inference_mock({"current_model": "m1", "models": ["m1", "m2"]}) + ) + try: + resp = await client.get("/api/models", headers=headers) + assert resp.status_code == 200 + assert resp.json() == {"current_model": "m1", "models": ["m1", "m2"]} + finally: + await 
app.state.inference_http_client.aclose() + + +@pytest.mark.asyncio +@respx.mock +async def test_swap_model_requires_admin(app, client, seeded_user): + stub_auth_validate(respx.mock, seeded_user) + headers = {"Authorization": "Bearer valid-token"} + + app.state.inference_http_client = httpx.AsyncClient(transport=_inference_mock({})) + try: + resp = await client.post( + "/api/models/swap", + json={"model_tag": "new-model"}, + headers=headers, + ) + finally: + await app.state.inference_http_client.aclose() + assert resp.status_code == 403 + + +@pytest.mark.asyncio +@respx.mock +async def test_swap_model_succeeds_for_admin(app, client, session_factory): + admin_id = str(uuid.uuid4()) + async with session_factory() as session: + await session.execute( + sa.text( + "INSERT INTO users (id, email, name, is_admin) " + "VALUES (:id, :email, :name, :is_admin)" + ), + {"id": admin_id, "email": "root@example.com", "name": "Root", "is_admin": 1}, + ) + await session.commit() + + admin_user = { + "id": admin_id, + "email": "root@example.com", + "name": "Root", + "is_admin": True, + } + stub_auth_validate(respx.mock, admin_user) + headers = {"Authorization": "Bearer valid-token"} + + app.state.inference_http_client = httpx.AsyncClient( + transport=_inference_mock({"current_model": "new-model", "models": ["new-model"]}) + ) + try: + resp = await client.post( + "/api/models/swap", + json={"model_tag": "new-model"}, + headers=headers, + ) + assert resp.status_code == 200 + assert resp.json()["current_model"] == "new-model" + finally: + await app.state.inference_http_client.aclose() diff --git a/services/chat-api/src/tests/test_stream_proxy.py b/services/chat-api/src/tests/test_stream_proxy.py new file mode 100644 index 00000000..72a1706f --- /dev/null +++ b/services/chat-api/src/tests/test_stream_proxy.py @@ -0,0 +1,73 @@ +"""Unit tests for the inference SSE stream proxy.""" +from __future__ import annotations + +import json + +import httpx +import pytest + +from src.services.stream_proxy import StreamResult, proxy_inference_stream + + +def _make_response(lines: list[str]) -> httpx.Response: + async def body(): + for line in lines: + yield f"{line}\n".encode("utf-8") + + return httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, + content=body(), + ) + + +@pytest.mark.asyncio +async def test_stream_proxy_accumulates_tokens_and_signals_done(): + resp = _make_response( + [ + "data: " + json.dumps({"token": "hel", "gpu": 0}), + "", + "data: " + json.dumps({"token": "lo", "gpu": 0}), + "", + "data: " + json.dumps({"done": True}), + "", + ] + ) + + captured: dict[str, StreamResult] = {} + + def on_complete(result: StreamResult) -> None: + captured["result"] = result + + events = [] + async for event in proxy_inference_stream(resp, on_complete=on_complete): + events.append(event) + + assert [json.loads(e["data"]) for e in events] == [ + {"token": "hel", "gpu": 0}, + {"token": "lo", "gpu": 0}, + {"done": True}, + ] + + result = captured["result"] + assert result.content == "hello" + assert result.token_count == 2 + assert result.completed is True + assert result.inference_time_ms >= 0 + + +@pytest.mark.asyncio +async def test_stream_proxy_surfaces_error_status_codes(): + resp = httpx.Response(502, content=b"upstream down") + captured: dict[str, StreamResult] = {} + + def on_complete(result: StreamResult) -> None: + captured["result"] = result + + events = [] + async for event in proxy_inference_stream(resp, on_complete=on_complete): + events.append(event) + + assert any("error" in e["data"] for 
e in events) + assert any('"done": true' in e["data"] or '"done":true' in e["data"] for e in events) + assert captured["result"].completed is False
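
A note on the trimming rule in `build_history_for_inference`
(`services/chat-api/src/services/conversation_service.py`): it walks messages
newest-to-oldest, stops once the cumulative character count exceeds
`MAX_TOKEN_BUDGET`, and always keeps at least one message. A standalone sketch
of the same rule, using plain dicts instead of ORM rows (`select_history` is an
illustrative helper, not part of this diff):

```python
def select_history(
    newest_first: list[dict[str, str]], max_budget: int
) -> list[dict[str, str]]:
    """Mirror of the budget walk: newest-to-oldest, keep at least one message."""
    selected: list[dict[str, str]] = []
    budget = 0
    for message in newest_first:
        budget += len(message.get("content") or "")
        if budget > max_budget and selected:
            break  # this message would bust the budget; stop here
        selected.append(message)
    selected.reverse()  # the inference payload expects oldest-first order
    return selected


msgs = [
    {"role": "assistant", "content": "c" * 3000},  # newest
    {"role": "user", "content": "b" * 3000},
    {"role": "assistant", "content": "a" * 3000},  # oldest
]
# With the default 6000-character budget, the oldest message is dropped:
print([len(m["content"]) for m in select_history(msgs, 6000)])  # [3000, 3000]
```

The `and selected` guard is what guarantees a single oversized message is still
sent rather than producing an empty prompt.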