diff --git a/db/alembic.ini b/db/alembic.ini new file mode 100644 index 00000000..d8d1bbc1 --- /dev/null +++ b/db/alembic.ini @@ -0,0 +1,41 @@ +[alembic] +script_location = %(here)s/migrations +prepend_sys_path = . +version_path_separator = os +sqlalchemy.url = driver://user:pass@localhost/dbname + +[post_write_hooks] + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/db/migrations/0001_initial_schema.sql b/db/migrations/0001_initial_schema.sql deleted file mode 100644 index bb40f32e..00000000 --- a/db/migrations/0001_initial_schema.sql +++ /dev/null @@ -1,42 +0,0 @@ -CREATE EXTENSION IF NOT EXISTS pgcrypto; - -CREATE TABLE IF NOT EXISTS users ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - email TEXT NOT NULL UNIQUE, - name TEXT NOT NULL, - avatar_url TEXT, - provider TEXT NOT NULL, - provider_id TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - last_login_at TIMESTAMPTZ -); - -CREATE UNIQUE INDEX IF NOT EXISTS users_provider_lookup_idx - ON users (provider, provider_id); - -CREATE TABLE IF NOT EXISTS conversations ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, - title TEXT NOT NULL, - model_tag TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE INDEX IF NOT EXISTS conversations_user_id_idx - ON conversations (user_id); - -CREATE TABLE IF NOT EXISTS messages ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE, - role TEXT NOT NULL, - content TEXT NOT NULL, - token_count INTEGER NOT NULL DEFAULT 0, - model_tag TEXT NOT NULL, - inference_time_ms INTEGER NOT NULL DEFAULT 0, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE INDEX IF NOT EXISTS messages_conversation_id_idx - ON messages (conversation_id, created_at); diff --git a/db/migrations/env.py b/db/migrations/env.py new file mode 100644 index 00000000..b035b347 --- /dev/null +++ b/db/migrations/env.py @@ -0,0 +1,67 @@ +"""Alembic environment configuration for async PostgreSQL.""" +from __future__ import annotations + +import asyncio +import os +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import pool +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import async_engine_from_config + +config = context.config + +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +target_metadata = None + + +def _database_url() -> str: + url = os.environ.get("DATABASE_URL") + if not url: + raise RuntimeError("DATABASE_URL environment variable is required for Alembic") + if url.startswith("postgresql://"): + url = url.replace("postgresql://", "postgresql+asyncpg://", 1) + return url + + +def run_migrations_offline() -> None: + context.configure( + url=_database_url(), + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection: Connection) -> None: + context.configure(connection=connection, target_metadata=target_metadata) + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + configuration = config.get_section(config.config_ini_section) or {} + configuration["sqlalchemy.url"] = _database_url() + connectable = async_engine_from_config( + configuration, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + await connectable.dispose() + + +def run_migrations_online() -> None: + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/db/migrations/script.py.mako b/db/migrations/script.py.mako new file mode 100644 index 00000000..b1f8b89a --- /dev/null +++ b/db/migrations/script.py.mako @@ -0,0 +1,27 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from __future__ import annotations + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/db/migrations/versions/001_create_users.py b/db/migrations/versions/001_create_users.py new file mode 100644 index 00000000..d280b4c7 --- /dev/null +++ b/db/migrations/versions/001_create_users.py @@ -0,0 +1,57 @@ +"""create users table + +Revision ID: 001_create_users +Revises: +Create Date: 2026-04-16 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "001_create_users" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto") + op.create_table( + "users", + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column("email", sa.String(length=255), nullable=False, unique=True), + sa.Column("name", sa.String(length=255), nullable=True), + sa.Column("avatar_url", sa.Text(), nullable=True), + sa.Column("provider", sa.String(length=50), nullable=False), + sa.Column("provider_id", sa.String(length=255), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("NOW()"), + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("NOW()"), + ), + sa.Column( + "last_login_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("NOW()"), + ), + sa.UniqueConstraint("provider", "provider_id", name="uq_users_provider_identity"), + ) + + +def downgrade() -> None: + op.drop_table("users") diff --git a/db/migrations/versions/002_create_conversations.py b/db/migrations/versions/002_create_conversations.py new file mode 100644 index 00000000..79bc5c14 --- /dev/null +++ b/db/migrations/versions/002_create_conversations.py @@ -0,0 +1,64 @@ +"""create conversations table + +Revision ID: 002_create_conversations +Revises: 001_create_users +Create Date: 2026-04-16 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "002_create_conversations" +down_revision: Union[str, None] = "001_create_users" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "conversations", + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("title", sa.String(length=500), nullable=True), + sa.Column( + "model_tag", + sa.String(length=100), + nullable=True, + server_default=sa.text("'default'"), + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("NOW()"), + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("NOW()"), + ), + ) + op.create_index( + "idx_conversations_user", + "conversations", + ["user_id", sa.text("updated_at DESC")], + ) + + +def downgrade() -> None: + op.drop_index("idx_conversations_user", table_name="conversations") + op.drop_table("conversations") diff --git a/db/migrations/versions/003_create_messages.py b/db/migrations/versions/003_create_messages.py new file mode 100644 index 00000000..8e1cfbef --- /dev/null +++ b/db/migrations/versions/003_create_messages.py @@ -0,0 +1,61 @@ +"""create messages table + +Revision ID: 003_create_messages +Revises: 002_create_conversations +Create Date: 2026-04-16 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "003_create_messages" +down_revision: Union[str, None] = "002_create_conversations" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "messages", + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column( + "conversation_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("conversations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("role", sa.String(length=20), nullable=False), + sa.Column("content", sa.Text(), nullable=False), + sa.Column("token_count", sa.Integer(), nullable=True), + sa.Column("model_tag", sa.String(length=100), nullable=True), + sa.Column("inference_time_ms", sa.Integer(), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("NOW()"), + ), + sa.CheckConstraint( + "role IN ('user','assistant','system')", + name="ck_messages_role", + ), + ) + op.create_index( + "idx_messages_conversation", + "messages", + ["conversation_id", sa.text("created_at ASC")], + ) + + +def downgrade() -> None: + op.drop_index("idx_messages_conversation", table_name="messages") + op.drop_table("messages") diff --git a/db/migrations/versions/004_add_favorited.py b/db/migrations/versions/004_add_favorited.py new file mode 100644 index 00000000..4c0c3dcd --- /dev/null +++ b/db/migrations/versions/004_add_favorited.py @@ -0,0 +1,34 @@ +"""add is_favorited column to conversations (Day 2 demo) + +Revision ID: 004_add_favorited +Revises: 003_create_messages +Create Date: 2026-04-16 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +revision: str = "004_add_favorited" +down_revision: Union[str, None] = "003_create_messages" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "conversations", + sa.Column( + "is_favorited", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + ) + + +def downgrade() -> None: + op.drop_column("conversations", "is_favorited") diff --git a/services/auth/Dockerfile b/services/auth/Dockerfile index bce37ff9..0d968753 100644 --- a/services/auth/Dockerfile +++ b/services/auth/Dockerfile @@ -1,9 +1,38 @@ FROM python:3.12-slim +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + UV_SYSTEM_PYTHON=1 \ + UV_LINK_MODE=copy + +RUN apt-get update \ + && apt-get install -y --no-install-recommends build-essential curl \ + && rm -rf /var/lib/apt/lists/* + +RUN pip install --no-cache-dir uv==0.4.30 + WORKDIR /app -COPY README.md /app/README.md +COPY pyproject.toml README.md /app/ + +RUN uv pip install --system --no-cache \ + "fastapi>=0.117.1" \ + "uvicorn[standard]>=0.36.0" \ + "pydantic>=2.8.0" \ + "pydantic-settings>=2.4.0" \ + "sqlalchemy[asyncio]>=2.0.36" \ + "asyncpg>=0.29.0" \ + "alembic>=1.13.0" \ + "authlib>=1.3.2" \ + "httpx>=0.27.0" \ + "itsdangerous>=2.2.0" \ + "pyjwt>=2.9.0" \ + "cryptography>=43.0.0" \ + "slowapi>=0.1.9" \ + "python-multipart>=0.0.9" + +COPY src /app/src EXPOSE 8001 -CMD ["python", "-m", "http.server", "8001", "--bind", "0.0.0.0"] +CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8001"] diff --git a/services/auth/README.md b/services/auth/README.md index 97433de7..c34e2249 100644 --- a/services/auth/README.md +++ b/services/auth/README.md @@ -1,7 +1,46 @@ -# Auth Service +# samosaChaat Auth Service -Scaffold placeholder for Issue #5. +FastAPI microservice providing OAuth2 login (Google + GitHub) and JWT session +management for samosaChaat (Issue #5). -The monorepo branch provisions this directory and a minimal Docker image so -local `docker compose up` remains viable before the real auth service -implementation lands. +## Endpoints + +| Method | Path | Purpose | +| ------ | ---- | ------- | +| GET | `/auth/google` | Redirect to Google consent | +| GET | `/auth/google/callback` | Complete Google flow, upsert user, issue tokens | +| GET | `/auth/github` | Redirect to GitHub consent | +| GET | `/auth/github/callback` | Complete GitHub flow, upsert user, issue tokens | +| POST | `/auth/refresh` | Exchange refresh cookie for new access token | +| GET | `/auth/me` | Current user profile (Bearer JWT) | +| PUT | `/auth/me` | Update name / avatar (Bearer JWT) | +| POST | `/auth/validate` | Internal JWT validation (service-to-service) | +| GET | `/auth/health` | Liveness probe | + +## Environment + +``` +DATABASE_URL=postgresql+asyncpg://user:pass@host/db +GOOGLE_CLIENT_ID=... +GOOGLE_CLIENT_SECRET=... +GITHUB_CLIENT_ID=... +GITHUB_CLIENT_SECRET=... +JWT_PRIVATE_KEY= +JWT_PUBLIC_KEY= +FRONTEND_URL=http://localhost:3000 +INTERNAL_API_KEY= +``` + +## Local development + +``` +uv sync +uv run uvicorn src.main:app --reload --port 8001 +uv run pytest +``` + +Database schema is managed by Alembic at `db/migrations`: + +``` +DATABASE_URL=... uv run alembic -c db/alembic.ini upgrade head +``` diff --git a/services/auth/pyproject.toml b/services/auth/pyproject.toml new file mode 100644 index 00000000..c42bcc68 --- /dev/null +++ b/services/auth/pyproject.toml @@ -0,0 +1,40 @@ +[project] +name = "samosachaat-auth" +version = "0.1.0" +description = "samosaChaat authentication service (OAuth2 + JWT)" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "fastapi>=0.117.1", + "uvicorn[standard]>=0.36.0", + "pydantic>=2.8.0", + "pydantic-settings>=2.4.0", + "sqlalchemy[asyncio]>=2.0.36", + "asyncpg>=0.29.0", + "alembic>=1.13.0", + "authlib>=1.3.2", + "httpx>=0.27.0", + "itsdangerous>=2.2.0", + "pyjwt>=2.9.0", + "cryptography>=43.0.0", + "slowapi>=0.1.9", + "python-multipart>=0.0.9", +] + +[dependency-groups] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.24.0", + "httpx>=0.27.0", + "aiosqlite>=0.20.0", + "respx>=0.21.1", +] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["src/tests"] +python_files = ["test_*.py"] +pythonpath = ["."] + +[tool.uv] +package = false diff --git a/services/auth/src/__init__.py b/services/auth/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/auth/src/config.py b/services/auth/src/config.py new file mode 100644 index 00000000..636b7f39 --- /dev/null +++ b/services/auth/src/config.py @@ -0,0 +1,47 @@ +"""Runtime configuration for the auth service. + +All configuration is loaded from environment variables using pydantic-settings. +Private/public keys are PEM-encoded RSA material used for RS256 JWTs. +""" +from __future__ import annotations + +from functools import lru_cache + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + model_config = SettingsConfigDict(env_file=".env", extra="ignore") + + database_url: str = Field(default="postgresql+asyncpg://localhost/samosachaat") + + google_client_id: str = Field(default="") + google_client_secret: str = Field(default="") + + github_client_id: str = Field(default="") + github_client_secret: str = Field(default="") + + jwt_private_key: str = Field(default="") + jwt_public_key: str = Field(default="") + jwt_issuer: str = Field(default="samosachaat-auth") + jwt_access_ttl_seconds: int = Field(default=3600) + jwt_refresh_ttl_seconds: int = Field(default=7 * 24 * 3600) + + frontend_url: str = Field(default="http://localhost:3000") + internal_api_key: str = Field(default="") + + auth_base_url: str = Field(default="http://localhost:8001") + session_secret: str = Field(default="dev-session-secret-change-me") + + cookie_secure: bool = Field(default=False) + cookie_domain: str | None = Field(default=None) + + @property + def refresh_cookie_name(self) -> str: + return "samosachaat_refresh" + + +@lru_cache(maxsize=1) +def get_settings() -> Settings: + return Settings() diff --git a/services/auth/src/database.py b/services/auth/src/database.py new file mode 100644 index 00000000..42352331 --- /dev/null +++ b/services/auth/src/database.py @@ -0,0 +1,49 @@ +"""Async SQLAlchemy engine + session factory.""" +from __future__ import annotations + +from collections.abc import AsyncIterator + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase + +from .config import get_settings + + +class Base(DeclarativeBase): + """Shared declarative base for all ORM models.""" + + +_engine = None +_session_factory: async_sessionmaker[AsyncSession] | None = None + + +def _build_engine(): + global _engine, _session_factory + settings = get_settings() + _engine = create_async_engine(settings.database_url, pool_pre_ping=True) + _session_factory = async_sessionmaker(_engine, expire_on_commit=False) + + +def get_engine(): + if _engine is None: + _build_engine() + return _engine + + +def get_session_factory() -> async_sessionmaker[AsyncSession]: + if _session_factory is None: + _build_engine() + assert _session_factory is not None + return _session_factory + + +async def get_session() -> AsyncIterator[AsyncSession]: + factory = get_session_factory() + async with factory() as session: + yield session + + +def override_session_factory(factory: async_sessionmaker[AsyncSession]) -> None: + """Testing hook: swap the session factory for an in-memory engine.""" + global _session_factory + _session_factory = factory diff --git a/services/auth/src/main.py b/services/auth/src/main.py new file mode 100644 index 00000000..0d3d86c4 --- /dev/null +++ b/services/auth/src/main.py @@ -0,0 +1,50 @@ +"""FastAPI entrypoint for the samosaChaat auth service.""" +from __future__ import annotations + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from slowapi.errors import RateLimitExceeded +from slowapi.middleware import SlowAPIMiddleware +from starlette.middleware.sessions import SessionMiddleware + +from .config import get_settings +from .rate_limit import limiter +from .routes import oauth, session, users + + +def _rate_limit_handler(request, exc: RateLimitExceeded): + return JSONResponse(status_code=429, content={"detail": "rate limit exceeded"}) + + +def create_app() -> FastAPI: + settings = get_settings() + app = FastAPI(title="samosaChaat Auth", version="0.1.0") + + app.state.limiter = limiter + app.add_exception_handler(RateLimitExceeded, _rate_limit_handler) + app.add_middleware(SlowAPIMiddleware) + + # SessionMiddleware is required by authlib for the OAuth state cookie. + app.add_middleware(SessionMiddleware, secret_key=settings.session_secret) + + app.add_middleware( + CORSMiddleware, + allow_origins=[settings.frontend_url], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + app.include_router(oauth.router) + app.include_router(session.router) + app.include_router(users.router) + + @app.get("/auth/health") + async def health(): + return {"status": "ok"} + + return app + + +app = create_app() diff --git a/services/auth/src/middleware/__init__.py b/services/auth/src/middleware/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/auth/src/middleware/auth_middleware.py b/services/auth/src/middleware/auth_middleware.py new file mode 100644 index 00000000..9dbe0994 --- /dev/null +++ b/services/auth/src/middleware/auth_middleware.py @@ -0,0 +1,44 @@ +"""Bearer-token auth dependency.""" +from __future__ import annotations + +from dataclasses import dataclass + +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy.ext.asyncio import AsyncSession + +from ..database import get_session +from ..models.user import User +from ..services import user_service +from ..services.jwt_service import JWTError, JWTService + +bearer_scheme = HTTPBearer(auto_error=False) + + +@dataclass +class AuthContext: + user: User + payload: dict + + +async def require_user( + request: Request, + credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme), + session: AsyncSession = Depends(get_session), +) -> AuthContext: + if credentials is None or credentials.scheme.lower() != "bearer": + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing bearer token") + + jwt_service = JWTService() + try: + payload = jwt_service.decode(credentials.credentials, expected_type="access") + except JWTError as exc: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, str(exc)) from exc + + user = await user_service.get_by_id(session, payload["sub"]) + if user is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "user not found") + + ctx = AuthContext(user=user, payload=payload) + request.state.auth = ctx + return ctx diff --git a/services/auth/src/rate_limit.py b/services/auth/src/rate_limit.py new file mode 100644 index 00000000..3957854c --- /dev/null +++ b/services/auth/src/rate_limit.py @@ -0,0 +1,7 @@ +"""Shared slowapi limiter for login-facing routes.""" +from __future__ import annotations + +from slowapi import Limiter +from slowapi.util import get_remote_address + +limiter = Limiter(key_func=get_remote_address) diff --git a/services/auth/src/routes/__init__.py b/services/auth/src/routes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/auth/src/routes/oauth.py b/services/auth/src/routes/oauth.py new file mode 100644 index 00000000..3295fed3 --- /dev/null +++ b/services/auth/src/routes/oauth.py @@ -0,0 +1,130 @@ +"""OAuth start/callback routes for Google and GitHub.""" +from __future__ import annotations + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import RedirectResponse +from sqlalchemy.ext.asyncio import AsyncSession + +from ..config import get_settings +from ..database import get_session +from ..rate_limit import limiter +from ..services import github_oauth as github_provider +from ..services import google_oauth as google_provider +from ..services import user_service +from ..services.jwt_service import JWTService + +router = APIRouter(prefix="/auth", tags=["oauth"]) + + +def _set_refresh_cookie(response: RedirectResponse, token: str, max_age: int) -> None: + settings = get_settings() + response.set_cookie( + key=settings.refresh_cookie_name, + value=token, + max_age=max_age, + httponly=True, + secure=settings.cookie_secure, + samesite="lax", + domain=settings.cookie_domain, + path="/", + ) + + +def _google_oauth(request: Request): + client = getattr(request.app.state, "google_oauth", None) + if client is None: + client = google_provider.build_google_client() + request.app.state.google_oauth = client + return client.google + + +def _github_oauth(request: Request): + client = getattr(request.app.state, "github_oauth", None) + if client is None: + client = github_provider.build_github_client() + request.app.state.github_oauth = client + return client.github + + +@router.get("/google") +@limiter.limit("10/minute") +async def google_start(request: Request): + settings = get_settings() + redirect_uri = f"{settings.auth_base_url.rstrip('/')}/auth/google/callback" + return await _google_oauth(request).authorize_redirect(request, redirect_uri) + + +@router.get("/google/callback") +@limiter.limit("10/minute") +async def google_callback( + request: Request, + session: AsyncSession = Depends(get_session), +): + oauth = _google_oauth(request) + try: + token = await oauth.authorize_access_token(request) + except Exception as exc: + raise HTTPException(status.HTTP_400_BAD_REQUEST, f"google oauth failed: {exc}") from exc + + userinfo = token.get("userinfo") + if not userinfo: + userinfo = await oauth.userinfo(token=token) + + profile = google_provider.profile_from_userinfo(dict(userinfo)) + user = await user_service.upsert_from_oauth(session, profile) + + jwt_service = JWTService() + pair = jwt_service.issue_pair(user_id=str(user.id), email=user.email, name=user.name) + + settings = get_settings() + redirect = RedirectResponse( + url=f"{settings.frontend_url.rstrip('/')}/chat?access_token={pair.access_token}", + status_code=status.HTTP_302_FOUND, + ) + _set_refresh_cookie(redirect, pair.refresh_token, pair.refresh_expires_in) + return redirect + + +@router.get("/github") +@limiter.limit("10/minute") +async def github_start(request: Request): + settings = get_settings() + redirect_uri = f"{settings.auth_base_url.rstrip('/')}/auth/github/callback" + return await _github_oauth(request).authorize_redirect(request, redirect_uri) + + +@router.get("/github/callback") +@limiter.limit("10/minute") +async def github_callback( + request: Request, + session: AsyncSession = Depends(get_session), +): + oauth = _github_oauth(request) + try: + token = await oauth.authorize_access_token(request) + except Exception as exc: + raise HTTPException(status.HTTP_400_BAD_REQUEST, f"github oauth failed: {exc}") from exc + + user_resp = await oauth.get("user", token=token) + user_resp.raise_for_status() + userinfo = user_resp.json() + + emails: list[dict] | None = None + if not userinfo.get("email"): + emails_resp = await oauth.get("user/emails", token=token) + if emails_resp.status_code == 200: + emails = emails_resp.json() + + profile = github_provider.profile_from_userinfo(dict(userinfo), emails) + user = await user_service.upsert_from_oauth(session, profile) + + jwt_service = JWTService() + pair = jwt_service.issue_pair(user_id=str(user.id), email=user.email, name=user.name) + + settings = get_settings() + redirect = RedirectResponse( + url=f"{settings.frontend_url.rstrip('/')}/chat?access_token={pair.access_token}", + status_code=status.HTTP_302_FOUND, + ) + _set_refresh_cookie(redirect, pair.refresh_token, pair.refresh_expires_in) + return redirect diff --git a/services/auth/src/routes/session.py b/services/auth/src/routes/session.py new file mode 100644 index 00000000..bb7eed11 --- /dev/null +++ b/services/auth/src/routes/session.py @@ -0,0 +1,80 @@ +"""Session/token refresh routes + internal JWT validation.""" +from __future__ import annotations + +from fastapi import APIRouter, Depends, Header, HTTPException, Request, status +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession + +from ..config import get_settings +from ..database import get_session +from ..rate_limit import limiter +from ..services import user_service +from ..services.jwt_service import JWTError, JWTService + +router = APIRouter(prefix="/auth", tags=["session"]) + + +class RefreshResponse(BaseModel): + access_token: str + expires_in: int + + +class ValidateRequest(BaseModel): + token: str + + +@router.post("/refresh", response_model=RefreshResponse) +@limiter.limit("30/minute") +async def refresh( + request: Request, + session: AsyncSession = Depends(get_session), +): + settings = get_settings() + cookie = request.cookies.get(settings.refresh_cookie_name) + if not cookie: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing refresh cookie") + + jwt_service = JWTService() + try: + payload = jwt_service.decode(cookie, expected_type="refresh") + except JWTError as exc: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, str(exc)) from exc + + user = await user_service.get_by_id(session, payload["sub"]) + if user is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "user not found") + + access, ttl = jwt_service.issue_access_token( + user_id=str(user.id), email=user.email, name=user.name + ) + return RefreshResponse(access_token=access, expires_in=ttl) + + +@router.post("/validate") +async def validate( + payload: ValidateRequest, + session: AsyncSession = Depends(get_session), + x_internal_api_key: str | None = Header(default=None, alias="X-Internal-API-Key"), +): + settings = get_settings() + if not settings.internal_api_key or x_internal_api_key != settings.internal_api_key: + raise HTTPException(status.HTTP_403_FORBIDDEN, "invalid internal api key") + + jwt_service = JWTService() + try: + claims = jwt_service.decode(payload.token, expected_type="access") + except JWTError as exc: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"valid": False, "reason": str(exc)}, + ) + + user = await user_service.get_by_id(session, claims["sub"]) + if user is None: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"valid": False, "reason": "user not found"}, + ) + + return {"valid": True, "user": user.to_dict(), "claims": claims} diff --git a/services/auth/src/routes/users.py b/services/auth/src/routes/users.py new file mode 100644 index 00000000..d8a35795 --- /dev/null +++ b/services/auth/src/routes/users.py @@ -0,0 +1,46 @@ +"""User profile routes (GET /auth/me, PUT /auth/me).""" +from __future__ import annotations + +from fastapi import APIRouter, Depends +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession + +from ..database import get_session +from ..middleware.auth_middleware import AuthContext, require_user +from ..services import user_service + +router = APIRouter(prefix="/auth", tags=["users"]) + + +class UserProfile(BaseModel): + id: str + email: str + name: str | None + avatar_url: str | None + provider: str + provider_id: str + created_at: str | None + updated_at: str | None + last_login_at: str | None + + +class ProfileUpdate(BaseModel): + name: str | None = Field(default=None, max_length=255) + avatar_url: str | None = Field(default=None, max_length=2048) + + +@router.get("/me", response_model=UserProfile) +async def me(ctx: AuthContext = Depends(require_user)) -> UserProfile: + return UserProfile(**ctx.user.to_dict()) + + +@router.put("/me", response_model=UserProfile) +async def update_me( + payload: ProfileUpdate, + ctx: AuthContext = Depends(require_user), + session: AsyncSession = Depends(get_session), +) -> UserProfile: + user = await user_service.update_profile( + session, ctx.user, name=payload.name, avatar_url=payload.avatar_url + ) + return UserProfile(**user.to_dict()) diff --git a/services/auth/src/services/__init__.py b/services/auth/src/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/auth/src/services/github_oauth.py b/services/auth/src/services/github_oauth.py new file mode 100644 index 00000000..595f2453 --- /dev/null +++ b/services/auth/src/services/github_oauth.py @@ -0,0 +1,56 @@ +"""GitHub OAuth provider via authlib. + +Authorization URL: https://github.com/login/oauth/authorize +Token URL: https://github.com/login/oauth/access_token +User API: https://api.github.com/user +""" +from __future__ import annotations + +from typing import Any + +from authlib.integrations.starlette_client import OAuth + +from ..config import Settings, get_settings +from .google_oauth import OAuthProfile + +AUTHORIZE_URL = "https://github.com/login/oauth/authorize" +TOKEN_URL = "https://github.com/login/oauth/access_token" +USERINFO_URL = "https://api.github.com/user" +EMAILS_URL = "https://api.github.com/user/emails" + + +def build_github_client(settings: Settings | None = None) -> OAuth: + settings = settings or get_settings() + oauth = OAuth() + oauth.register( + name="github", + client_id=settings.github_client_id, + client_secret=settings.github_client_secret, + access_token_url=TOKEN_URL, + authorize_url=AUTHORIZE_URL, + api_base_url="https://api.github.com/", + client_kwargs={"scope": "read:user user:email"}, + ) + return oauth + + +def profile_from_userinfo(userinfo: dict[str, Any], emails: list[dict[str, Any]] | None) -> OAuthProfile: + provider_id = userinfo.get("id") + if provider_id is None: + raise ValueError("github userinfo missing id") + + email = userinfo.get("email") + if not email and emails: + primary = next((e for e in emails if e.get("primary") and e.get("verified")), None) + if primary: + email = primary.get("email") + if not email: + raise ValueError("github userinfo missing verified email") + + return OAuthProfile( + provider="github", + provider_id=str(provider_id), + email=str(email), + name=userinfo.get("name") or userinfo.get("login"), + avatar_url=userinfo.get("avatar_url"), + ) diff --git a/services/auth/src/services/google_oauth.py b/services/auth/src/services/google_oauth.py new file mode 100644 index 00000000..3feccffb --- /dev/null +++ b/services/auth/src/services/google_oauth.py @@ -0,0 +1,50 @@ +"""Google OAuth provider via authlib. + +Discovery URL: https://accounts.google.com/.well-known/openid-configuration +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from authlib.integrations.starlette_client import OAuth + +from ..config import Settings, get_settings + +DISCOVERY_URL = "https://accounts.google.com/.well-known/openid-configuration" + + +@dataclass +class OAuthProfile: + provider: str + provider_id: str + email: str + name: str | None + avatar_url: str | None + + +def build_google_client(settings: Settings | None = None) -> OAuth: + settings = settings or get_settings() + oauth = OAuth() + oauth.register( + name="google", + client_id=settings.google_client_id, + client_secret=settings.google_client_secret, + server_metadata_url=DISCOVERY_URL, + client_kwargs={"scope": "openid email profile"}, + ) + return oauth + + +def profile_from_userinfo(userinfo: dict[str, Any]) -> OAuthProfile: + provider_id = userinfo.get("sub") or userinfo.get("id") + email = userinfo.get("email") + if not provider_id or not email: + raise ValueError("google userinfo missing sub/email") + return OAuthProfile( + provider="google", + provider_id=str(provider_id), + email=str(email), + name=userinfo.get("name"), + avatar_url=userinfo.get("picture"), + ) diff --git a/services/auth/src/services/jwt_service.py b/services/auth/src/services/jwt_service.py new file mode 100644 index 00000000..0d6484f1 --- /dev/null +++ b/services/auth/src/services/jwt_service.py @@ -0,0 +1,87 @@ +"""RS256 JWT issuance + validation. + +Access tokens (1h) are returned to the client for Bearer auth. +Refresh tokens (7d) are stored in an httpOnly cookie. +""" +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any + +import jwt + +from ..config import Settings, get_settings + + +class JWTError(Exception): + """Raised when a token fails validation.""" + + +@dataclass +class TokenPair: + access_token: str + refresh_token: str + access_expires_in: int + refresh_expires_in: int + + +class JWTService: + def __init__(self, settings: Settings | None = None) -> None: + self._settings = settings or get_settings() + + def _now(self) -> datetime: + return datetime.now(timezone.utc) + + def issue_access_token(self, *, user_id: str, email: str, name: str | None) -> tuple[str, int]: + now = self._now() + exp = now + timedelta(seconds=self._settings.jwt_access_ttl_seconds) + payload = { + "sub": user_id, + "email": email, + "name": name, + "iat": int(now.timestamp()), + "exp": int(exp.timestamp()), + "iss": self._settings.jwt_issuer, + "type": "access", + } + token = jwt.encode(payload, self._settings.jwt_private_key, algorithm="RS256") + return token, self._settings.jwt_access_ttl_seconds + + def issue_refresh_token(self, *, user_id: str) -> tuple[str, int]: + now = self._now() + exp = now + timedelta(seconds=self._settings.jwt_refresh_ttl_seconds) + payload = { + "sub": user_id, + "iat": int(now.timestamp()), + "exp": int(exp.timestamp()), + "iss": self._settings.jwt_issuer, + "type": "refresh", + "jti": uuid.uuid4().hex, + } + token = jwt.encode(payload, self._settings.jwt_private_key, algorithm="RS256") + return token, self._settings.jwt_refresh_ttl_seconds + + def issue_pair(self, *, user_id: str, email: str, name: str | None) -> TokenPair: + access, access_ttl = self.issue_access_token(user_id=user_id, email=email, name=name) + refresh, refresh_ttl = self.issue_refresh_token(user_id=user_id) + return TokenPair(access, refresh, access_ttl, refresh_ttl) + + def decode(self, token: str, *, expected_type: str | None = None) -> dict[str, Any]: + try: + payload = jwt.decode( + token, + self._settings.jwt_public_key, + algorithms=["RS256"], + issuer=self._settings.jwt_issuer, + options={"require": ["exp", "iat", "sub", "iss"]}, + ) + except jwt.ExpiredSignatureError as exc: + raise JWTError("token expired") from exc + except jwt.InvalidTokenError as exc: + raise JWTError(f"invalid token: {exc}") from exc + + if expected_type and payload.get("type") != expected_type: + raise JWTError(f"expected {expected_type} token, got {payload.get('type')!r}") + return payload diff --git a/services/auth/src/services/user_service.py b/services/auth/src/services/user_service.py new file mode 100644 index 00000000..f07114de --- /dev/null +++ b/services/auth/src/services/user_service.py @@ -0,0 +1,75 @@ +"""User persistence helpers (upsert on OAuth callback, fetch by id).""" +from __future__ import annotations + +import uuid +from datetime import datetime, timezone + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from ..models.user import User +from .google_oauth import OAuthProfile + + +async def upsert_from_oauth(session: AsyncSession, profile: OAuthProfile) -> User: + """Insert a new user, or update last_login_at on an existing one.""" + stmt = select(User).where( + User.provider == profile.provider, + User.provider_id == profile.provider_id, + ) + existing = (await session.execute(stmt)).scalar_one_or_none() + + now = datetime.now(timezone.utc) + if existing is None: + user = User( + email=profile.email, + name=profile.name, + avatar_url=profile.avatar_url, + provider=profile.provider, + provider_id=profile.provider_id, + created_at=now, + updated_at=now, + last_login_at=now, + ) + session.add(user) + await session.commit() + await session.refresh(user) + return user + + existing.last_login_at = now + existing.updated_at = now + if profile.email: + existing.email = profile.email + if profile.name is not None: + existing.name = profile.name + if profile.avatar_url is not None: + existing.avatar_url = profile.avatar_url + await session.commit() + await session.refresh(existing) + return existing + + +async def get_by_id(session: AsyncSession, user_id: str | uuid.UUID) -> User | None: + if isinstance(user_id, str): + try: + user_id = uuid.UUID(user_id) + except ValueError: + return None + return await session.get(User, user_id) + + +async def update_profile( + session: AsyncSession, + user: User, + *, + name: str | None, + avatar_url: str | None, +) -> User: + if name is not None: + user.name = name + if avatar_url is not None: + user.avatar_url = avatar_url + user.updated_at = datetime.now(timezone.utc) + await session.commit() + await session.refresh(user) + return user diff --git a/services/auth/src/tests/__init__.py b/services/auth/src/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/auth/src/tests/conftest.py b/services/auth/src/tests/conftest.py new file mode 100644 index 00000000..553976ef --- /dev/null +++ b/services/auth/src/tests/conftest.py @@ -0,0 +1,84 @@ +"""Shared pytest fixtures: RSA keys, in-memory DB, FastAPI test client.""" +from __future__ import annotations + +import os +import uuid +from collections.abc import AsyncIterator + +import pytest +import pytest_asyncio +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + + +def _make_rsa_pair() -> tuple[str, str]: + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + private = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + public = key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + return private, public + + +@pytest.fixture(scope="session", autouse=True) +def _test_environment(): + private, public = _make_rsa_pair() + os.environ["JWT_PRIVATE_KEY"] = private + os.environ["JWT_PUBLIC_KEY"] = public + os.environ["INTERNAL_API_KEY"] = "test-internal-key" + os.environ["FRONTEND_URL"] = "http://localhost:3000" + os.environ["AUTH_BASE_URL"] = "http://localhost:8001" + os.environ["SESSION_SECRET"] = "test-session-secret" + os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:" + os.environ["GOOGLE_CLIENT_ID"] = "g-id" + os.environ["GOOGLE_CLIENT_SECRET"] = "g-secret" + os.environ["GITHUB_CLIENT_ID"] = "gh-id" + os.environ["GITHUB_CLIENT_SECRET"] = "gh-secret" + + from src.config import get_settings + + get_settings.cache_clear() + yield + + +@pytest_asyncio.fixture +async def session_factory() -> AsyncIterator[async_sessionmaker[AsyncSession]]: + from src.database import Base, override_session_factory + from src.models.user import User # noqa: F401 — register on Base + + engine = create_async_engine( + f"sqlite+aiosqlite:///file:test_{uuid.uuid4().hex}?mode=memory&cache=shared&uri=true", + connect_args={"uri": True}, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + factory = async_sessionmaker(engine, expire_on_commit=False) + override_session_factory(factory) + try: + yield factory + finally: + await engine.dispose() + + +@pytest_asyncio.fixture +async def db_session(session_factory) -> AsyncIterator[AsyncSession]: + async with session_factory() as session: + yield session + + +@pytest_asyncio.fixture +async def client(session_factory) -> AsyncIterator[AsyncClient]: + from src.main import create_app + + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://testserver") as ac: + yield ac diff --git a/services/auth/src/tests/test_me_endpoint.py b/services/auth/src/tests/test_me_endpoint.py new file mode 100644 index 00000000..c45aabc8 --- /dev/null +++ b/services/auth/src/tests/test_me_endpoint.py @@ -0,0 +1,59 @@ +"""GET /auth/me and PUT /auth/me via Bearer JWT.""" +from __future__ import annotations + +import pytest + +from src.services import user_service +from src.services.google_oauth import OAuthProfile +from src.services.jwt_service import JWTService + + +@pytest.mark.asyncio +async def test_me_returns_profile(client, db_session): + profile = OAuthProfile( + provider="github", provider_id="gh-1", email="me@x.co", name="Me", avatar_url=None + ) + user = await user_service.upsert_from_oauth(db_session, profile) + token, _ = JWTService().issue_access_token( + user_id=str(user.id), email=user.email, name=user.name + ) + + resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"}) + assert resp.status_code == 200 + body = resp.json() + assert body["email"] == "me@x.co" + assert body["provider"] == "github" + + +@pytest.mark.asyncio +async def test_me_requires_bearer(client): + resp = await client.get("/auth/me") + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_update_me(client, db_session): + profile = OAuthProfile( + provider="google", provider_id="g-9", email="x@y.co", name="Old", avatar_url=None + ) + user = await user_service.upsert_from_oauth(db_session, profile) + token, _ = JWTService().issue_access_token( + user_id=str(user.id), email=user.email, name=user.name + ) + + resp = await client.put( + "/auth/me", + json={"name": "New Name", "avatar_url": "https://img/x.png"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["name"] == "New Name" + assert body["avatar_url"] == "https://img/x.png" + + +@pytest.mark.asyncio +async def test_health(client): + resp = await client.get("/auth/health") + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} diff --git a/services/auth/src/tests/test_oauth_flow.py b/services/auth/src/tests/test_oauth_flow.py new file mode 100644 index 00000000..c46caa72 --- /dev/null +++ b/services/auth/src/tests/test_oauth_flow.py @@ -0,0 +1,97 @@ +"""OAuth callback end-to-end with mocked authlib providers.""" +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from src.services import user_service +from src.services.google_oauth import OAuthProfile + + +class _MockGoogleClient: + """Stands in for authlib's StarletteOAuth2App.""" + + def __init__(self, userinfo): + self._userinfo = userinfo + + async def authorize_access_token(self, request): + return {"access_token": "fake", "userinfo": self._userinfo} + + async def userinfo(self, token=None): + return self._userinfo + + +@pytest.mark.asyncio +async def test_google_callback_creates_user_and_sets_refresh_cookie(client, db_session): + userinfo = { + "sub": "google-42", + "email": "new@user.co", + "name": "New User", + "picture": "https://img/new.png", + } + app = client._transport.app # type: ignore[attr-defined] + app.state.google_oauth = SimpleNamespace(google=_MockGoogleClient(userinfo)) + + resp = await client.get("/auth/google/callback", follow_redirects=False) + assert resp.status_code == 302 + assert "access_token=" in resp.headers["location"] + + from src.config import get_settings + + cookie_name = get_settings().refresh_cookie_name + assert cookie_name in resp.cookies + + # Verify user was persisted. + from sqlalchemy import select + + from src.models.user import User + + async with db_session.bind._async_engine.connect() if False else db_session as s: # type: ignore[attr-defined] + pass + # Simpler: just query through a fresh session. + from src.database import get_session_factory + + async with get_session_factory()() as s: + user = ( + await s.execute(select(User).where(User.provider_id == "google-42")) + ).scalar_one() + assert user.email == "new@user.co" + + +@pytest.mark.asyncio +async def test_google_callback_updates_existing(client, db_session): + # Seed an existing user. + profile = OAuthProfile( + provider="google", + provider_id="google-99", + email="old@user.co", + name="Old", + avatar_url=None, + ) + existing = await user_service.upsert_from_oauth(db_session, profile) + original_login = existing.last_login_at + + userinfo = { + "sub": "google-99", + "email": "old@user.co", + "name": "Updated Name", + "picture": "https://img/u.png", + } + app = client._transport.app # type: ignore[attr-defined] + app.state.google_oauth = SimpleNamespace(google=_MockGoogleClient(userinfo)) + + resp = await client.get("/auth/google/callback", follow_redirects=False) + assert resp.status_code == 302 + + from sqlalchemy import select + + from src.database import get_session_factory + from src.models.user import User + + async with get_session_factory()() as s: + refreshed = ( + await s.execute(select(User).where(User.id == existing.id)) + ).scalar_one() + assert refreshed.name == "Updated Name" + assert refreshed.last_login_at >= original_login diff --git a/services/auth/src/tests/test_rate_limit.py b/services/auth/src/tests/test_rate_limit.py new file mode 100644 index 00000000..b3f5994c --- /dev/null +++ b/services/auth/src/tests/test_rate_limit.py @@ -0,0 +1,29 @@ +"""Rate limiter applies 10/min to OAuth start routes.""" +from __future__ import annotations + +import pytest + + +@pytest.mark.asyncio +async def test_google_start_rate_limited(client, monkeypatch): + # Replace the OAuth client with a stub so /auth/google returns immediately. + async def _stub_redirect(request, redirect_uri): + from fastapi.responses import RedirectResponse + + return RedirectResponse(url=redirect_uri, status_code=302) + + class _StubProvider: + authorize_redirect = staticmethod(_stub_redirect) + + class _StubClient: + google = _StubProvider() + + app = client._transport.app # type: ignore[attr-defined] + app.state.google_oauth = _StubClient() + + # First 10 calls allowed, 11th should be rate-limited. + codes = [] + for _ in range(11): + resp = await client.get("/auth/google", follow_redirects=False) + codes.append(resp.status_code) + assert codes.count(429) >= 1 diff --git a/services/auth/src/tests/test_user_upsert.py b/services/auth/src/tests/test_user_upsert.py new file mode 100644 index 00000000..8e9f5860 --- /dev/null +++ b/services/auth/src/tests/test_user_upsert.py @@ -0,0 +1,68 @@ +"""User upsert flow invoked by OAuth callbacks.""" +from __future__ import annotations + +import pytest + +from src.services import user_service +from src.services.google_oauth import OAuthProfile + + +@pytest.mark.asyncio +async def test_upsert_creates_new_user(db_session): + profile = OAuthProfile( + provider="google", + provider_id="g-12345", + email="alice@example.com", + name="Alice", + avatar_url="https://img/alice.png", + ) + user = await user_service.upsert_from_oauth(db_session, profile) + assert user.id is not None + assert user.email == "alice@example.com" + assert user.provider == "google" + assert user.last_login_at is not None + + +@pytest.mark.asyncio +async def test_upsert_updates_existing_user(db_session): + first = OAuthProfile( + provider="google", + provider_id="g-12345", + email="alice@example.com", + name="Alice", + avatar_url=None, + ) + u1 = await user_service.upsert_from_oauth(db_session, first) + original_login = u1.last_login_at + + # Second login with updated display name + avatar — same provider identity. + second = OAuthProfile( + provider="google", + provider_id="g-12345", + email="alice@example.com", + name="Alice Smith", + avatar_url="https://img/alice2.png", + ) + u2 = await user_service.upsert_from_oauth(db_session, second) + + assert u2.id == u1.id # same row + assert u2.name == "Alice Smith" + assert u2.avatar_url == "https://img/alice2.png" + assert u2.last_login_at >= original_login + + +@pytest.mark.asyncio +async def test_update_profile(db_session): + profile = OAuthProfile( + provider="github", + provider_id="42", + email="bob@example.com", + name="Bob", + avatar_url=None, + ) + user = await user_service.upsert_from_oauth(db_session, profile) + updated = await user_service.update_profile( + db_session, user, name="Robert", avatar_url="https://img/bob.png" + ) + assert updated.name == "Robert" + assert updated.avatar_url == "https://img/bob.png" diff --git a/services/auth/src/tests/test_validate_endpoint.py b/services/auth/src/tests/test_validate_endpoint.py new file mode 100644 index 00000000..1851f18f --- /dev/null +++ b/services/auth/src/tests/test_validate_endpoint.py @@ -0,0 +1,62 @@ +"""Internal /auth/validate endpoint used by the Chat API service.""" +from __future__ import annotations + +import pytest + +from src.services import user_service +from src.services.google_oauth import OAuthProfile +from src.services.jwt_service import JWTService + + +@pytest.mark.asyncio +async def test_validate_requires_internal_key(client, db_session): + profile = OAuthProfile( + provider="google", provider_id="123", email="v@x.co", name="V", avatar_url=None + ) + user = await user_service.upsert_from_oauth(db_session, profile) + token, _ = JWTService().issue_access_token( + user_id=str(user.id), email=user.email, name=user.name + ) + + missing = await client.post("/auth/validate", json={"token": token}) + assert missing.status_code == 403 + + wrong = await client.post( + "/auth/validate", + json={"token": token}, + headers={"X-Internal-API-Key": "nope"}, + ) + assert wrong.status_code == 403 + + +@pytest.mark.asyncio +async def test_validate_returns_user_for_valid_token(client, db_session): + profile = OAuthProfile( + provider="google", provider_id="456", email="v2@x.co", name="V2", avatar_url=None + ) + user = await user_service.upsert_from_oauth(db_session, profile) + token, _ = JWTService().issue_access_token( + user_id=str(user.id), email=user.email, name=user.name + ) + + resp = await client.post( + "/auth/validate", + json={"token": token}, + headers={"X-Internal-API-Key": "test-internal-key"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["valid"] is True + assert body["user"]["email"] == "v2@x.co" + assert body["claims"]["sub"] == str(user.id) + + +@pytest.mark.asyncio +async def test_validate_rejects_tampered_token(client): + resp = await client.post( + "/auth/validate", + json={"token": "not-a-jwt"}, + headers={"X-Internal-API-Key": "test-internal-key"}, + ) + assert resp.status_code == 401 + assert resp.json()["valid"] is False