Merge pull request #16 from manmohan659/feat/auth-service

feat(auth): OAuth2 + JWT auth service with Alembic migrations (#5 #7)
This commit is contained in:
Manmohan 2026-04-16 14:56:51 -04:00 committed by GitHub
commit 4297817cfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 1586 additions and 49 deletions

41
db/alembic.ini Normal file
View File

@ -0,0 +1,41 @@
[alembic]
script_location = %(here)s/migrations
prepend_sys_path = .
version_path_separator = os
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@ -1,42 +0,0 @@
CREATE EXTENSION IF NOT EXISTS pgcrypto;
CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
email TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
avatar_url TEXT,
provider TEXT NOT NULL,
provider_id TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_login_at TIMESTAMPTZ
);
CREATE UNIQUE INDEX IF NOT EXISTS users_provider_lookup_idx
ON users (provider, provider_id);
CREATE TABLE IF NOT EXISTS conversations (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
title TEXT NOT NULL,
model_tag TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS conversations_user_id_idx
ON conversations (user_id);
CREATE TABLE IF NOT EXISTS messages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
role TEXT NOT NULL,
content TEXT NOT NULL,
token_count INTEGER NOT NULL DEFAULT 0,
model_tag TEXT NOT NULL,
inference_time_ms INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS messages_conversation_id_idx
ON messages (conversation_id, created_at);

67
db/migrations/env.py Normal file
View File

@ -0,0 +1,67 @@
"""Alembic environment configuration for async PostgreSQL."""
from __future__ import annotations
import asyncio
import os
from logging.config import fileConfig
from alembic import context
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = None
def _database_url() -> str:
url = os.environ.get("DATABASE_URL")
if not url:
raise RuntimeError("DATABASE_URL environment variable is required for Alembic")
if url.startswith("postgresql://"):
url = url.replace("postgresql://", "postgresql+asyncpg://", 1)
return url
def run_migrations_offline() -> None:
context.configure(
url=_database_url(),
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
configuration = config.get_section(config.config_ini_section) or {}
configuration["sqlalchemy.url"] = _database_url()
connectable = async_engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@ -0,0 +1,27 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,57 @@
"""create users table
Revision ID: 001_create_users
Revises:
Create Date: 2026-04-16
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "001_create_users"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")
op.create_table(
"users",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
primary_key=True,
server_default=sa.text("gen_random_uuid()"),
),
sa.Column("email", sa.String(length=255), nullable=False, unique=True),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("avatar_url", sa.Text(), nullable=True),
sa.Column("provider", sa.String(length=50), nullable=False),
sa.Column("provider_id", sa.String(length=255), nullable=False),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("NOW()"),
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("NOW()"),
),
sa.Column(
"last_login_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("NOW()"),
),
sa.UniqueConstraint("provider", "provider_id", name="uq_users_provider_identity"),
)
def downgrade() -> None:
op.drop_table("users")

View File

@ -0,0 +1,64 @@
"""create conversations table
Revision ID: 002_create_conversations
Revises: 001_create_users
Create Date: 2026-04-16
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "002_create_conversations"
down_revision: Union[str, None] = "001_create_users"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"conversations",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
primary_key=True,
server_default=sa.text("gen_random_uuid()"),
),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("title", sa.String(length=500), nullable=True),
sa.Column(
"model_tag",
sa.String(length=100),
nullable=True,
server_default=sa.text("'default'"),
),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("NOW()"),
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("NOW()"),
),
)
op.create_index(
"idx_conversations_user",
"conversations",
["user_id", sa.text("updated_at DESC")],
)
def downgrade() -> None:
op.drop_index("idx_conversations_user", table_name="conversations")
op.drop_table("conversations")

View File

@ -0,0 +1,61 @@
"""create messages table
Revision ID: 003_create_messages
Revises: 002_create_conversations
Create Date: 2026-04-16
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "003_create_messages"
down_revision: Union[str, None] = "002_create_conversations"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"messages",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
primary_key=True,
server_default=sa.text("gen_random_uuid()"),
),
sa.Column(
"conversation_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("conversations.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("role", sa.String(length=20), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("model_tag", sa.String(length=100), nullable=True),
sa.Column("inference_time_ms", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("NOW()"),
),
sa.CheckConstraint(
"role IN ('user','assistant','system')",
name="ck_messages_role",
),
)
op.create_index(
"idx_messages_conversation",
"messages",
["conversation_id", sa.text("created_at ASC")],
)
def downgrade() -> None:
op.drop_index("idx_messages_conversation", table_name="messages")
op.drop_table("messages")

View File

@ -0,0 +1,34 @@
"""add is_favorited column to conversations (Day 2 demo)
Revision ID: 004_add_favorited
Revises: 003_create_messages
Create Date: 2026-04-16
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "004_add_favorited"
down_revision: Union[str, None] = "003_create_messages"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"conversations",
sa.Column(
"is_favorited",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
def downgrade() -> None:
op.drop_column("conversations", "is_favorited")

View File

@ -1,9 +1,38 @@
FROM python:3.12-slim
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
UV_SYSTEM_PYTHON=1 \
UV_LINK_MODE=copy
RUN apt-get update \
&& apt-get install -y --no-install-recommends build-essential curl \
&& rm -rf /var/lib/apt/lists/*
RUN pip install --no-cache-dir uv==0.4.30
WORKDIR /app
COPY README.md /app/README.md
COPY pyproject.toml README.md /app/
RUN uv pip install --system --no-cache \
"fastapi>=0.117.1" \
"uvicorn[standard]>=0.36.0" \
"pydantic>=2.8.0" \
"pydantic-settings>=2.4.0" \
"sqlalchemy[asyncio]>=2.0.36" \
"asyncpg>=0.29.0" \
"alembic>=1.13.0" \
"authlib>=1.3.2" \
"httpx>=0.27.0" \
"itsdangerous>=2.2.0" \
"pyjwt>=2.9.0" \
"cryptography>=43.0.0" \
"slowapi>=0.1.9" \
"python-multipart>=0.0.9"
COPY src /app/src
EXPOSE 8001
CMD ["python", "-m", "http.server", "8001", "--bind", "0.0.0.0"]
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8001"]

View File

@ -1,7 +1,46 @@
# Auth Service
# samosaChaat Auth Service
Scaffold placeholder for Issue #5.
FastAPI microservice providing OAuth2 login (Google + GitHub) and JWT session
management for samosaChaat (Issue #5).
The monorepo branch provisions this directory and a minimal Docker image so
local `docker compose up` remains viable before the real auth service
implementation lands.
## Endpoints
| Method | Path | Purpose |
| ------ | ---- | ------- |
| GET | `/auth/google` | Redirect to Google consent |
| GET | `/auth/google/callback` | Complete Google flow, upsert user, issue tokens |
| GET | `/auth/github` | Redirect to GitHub consent |
| GET | `/auth/github/callback` | Complete GitHub flow, upsert user, issue tokens |
| POST | `/auth/refresh` | Exchange refresh cookie for new access token |
| GET | `/auth/me` | Current user profile (Bearer JWT) |
| PUT | `/auth/me` | Update name / avatar (Bearer JWT) |
| POST | `/auth/validate` | Internal JWT validation (service-to-service) |
| GET | `/auth/health` | Liveness probe |
## Environment
```
DATABASE_URL=postgresql+asyncpg://user:pass@host/db
GOOGLE_CLIENT_ID=...
GOOGLE_CLIENT_SECRET=...
GITHUB_CLIENT_ID=...
GITHUB_CLIENT_SECRET=...
JWT_PRIVATE_KEY=<RS256 PEM>
JWT_PUBLIC_KEY=<RS256 PEM>
FRONTEND_URL=http://localhost:3000
INTERNAL_API_KEY=<shared secret for /auth/validate>
```
## Local development
```
uv sync
uv run uvicorn src.main:app --reload --port 8001
uv run pytest
```
Database schema is managed by Alembic at `db/migrations`:
```
DATABASE_URL=... uv run alembic -c db/alembic.ini upgrade head
```

View File

@ -0,0 +1,40 @@
[project]
name = "samosachaat-auth"
version = "0.1.0"
description = "samosaChaat authentication service (OAuth2 + JWT)"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"fastapi>=0.117.1",
"uvicorn[standard]>=0.36.0",
"pydantic>=2.8.0",
"pydantic-settings>=2.4.0",
"sqlalchemy[asyncio]>=2.0.36",
"asyncpg>=0.29.0",
"alembic>=1.13.0",
"authlib>=1.3.2",
"httpx>=0.27.0",
"itsdangerous>=2.2.0",
"pyjwt>=2.9.0",
"cryptography>=43.0.0",
"slowapi>=0.1.9",
"python-multipart>=0.0.9",
]
[dependency-groups]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.24.0",
"httpx>=0.27.0",
"aiosqlite>=0.20.0",
"respx>=0.21.1",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["src/tests"]
python_files = ["test_*.py"]
pythonpath = ["."]
[tool.uv]
package = false

View File

View File

@ -0,0 +1,47 @@
"""Runtime configuration for the auth service.
All configuration is loaded from environment variables using pydantic-settings.
Private/public keys are PEM-encoded RSA material used for RS256 JWTs.
"""
from __future__ import annotations
from functools import lru_cache
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
database_url: str = Field(default="postgresql+asyncpg://localhost/samosachaat")
google_client_id: str = Field(default="")
google_client_secret: str = Field(default="")
github_client_id: str = Field(default="")
github_client_secret: str = Field(default="")
jwt_private_key: str = Field(default="")
jwt_public_key: str = Field(default="")
jwt_issuer: str = Field(default="samosachaat-auth")
jwt_access_ttl_seconds: int = Field(default=3600)
jwt_refresh_ttl_seconds: int = Field(default=7 * 24 * 3600)
frontend_url: str = Field(default="http://localhost:3000")
internal_api_key: str = Field(default="")
auth_base_url: str = Field(default="http://localhost:8001")
session_secret: str = Field(default="dev-session-secret-change-me")
cookie_secure: bool = Field(default=False)
cookie_domain: str | None = Field(default=None)
@property
def refresh_cookie_name(self) -> str:
return "samosachaat_refresh"
@lru_cache(maxsize=1)
def get_settings() -> Settings:
return Settings()

View File

@ -0,0 +1,49 @@
"""Async SQLAlchemy engine + session factory."""
from __future__ import annotations
from collections.abc import AsyncIterator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from .config import get_settings
class Base(DeclarativeBase):
"""Shared declarative base for all ORM models."""
_engine = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
def _build_engine():
global _engine, _session_factory
settings = get_settings()
_engine = create_async_engine(settings.database_url, pool_pre_ping=True)
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
def get_engine():
if _engine is None:
_build_engine()
return _engine
def get_session_factory() -> async_sessionmaker[AsyncSession]:
if _session_factory is None:
_build_engine()
assert _session_factory is not None
return _session_factory
async def get_session() -> AsyncIterator[AsyncSession]:
factory = get_session_factory()
async with factory() as session:
yield session
def override_session_factory(factory: async_sessionmaker[AsyncSession]) -> None:
"""Testing hook: swap the session factory for an in-memory engine."""
global _session_factory
_session_factory = factory

50
services/auth/src/main.py Normal file
View File

@ -0,0 +1,50 @@
"""FastAPI entrypoint for the samosaChaat auth service."""
from __future__ import annotations
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from starlette.middleware.sessions import SessionMiddleware
from .config import get_settings
from .rate_limit import limiter
from .routes import oauth, session, users
def _rate_limit_handler(request, exc: RateLimitExceeded):
return JSONResponse(status_code=429, content={"detail": "rate limit exceeded"})
def create_app() -> FastAPI:
settings = get_settings()
app = FastAPI(title="samosaChaat Auth", version="0.1.0")
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_handler)
app.add_middleware(SlowAPIMiddleware)
# SessionMiddleware is required by authlib for the OAuth state cookie.
app.add_middleware(SessionMiddleware, secret_key=settings.session_secret)
app.add_middleware(
CORSMiddleware,
allow_origins=[settings.frontend_url],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(oauth.router)
app.include_router(session.router)
app.include_router(users.router)
@app.get("/auth/health")
async def health():
return {"status": "ok"}
return app
app = create_app()

View File

View File

@ -0,0 +1,44 @@
"""Bearer-token auth dependency."""
from __future__ import annotations
from dataclasses import dataclass
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.ext.asyncio import AsyncSession
from ..database import get_session
from ..models.user import User
from ..services import user_service
from ..services.jwt_service import JWTError, JWTService
bearer_scheme = HTTPBearer(auto_error=False)
@dataclass
class AuthContext:
user: User
payload: dict
async def require_user(
request: Request,
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
session: AsyncSession = Depends(get_session),
) -> AuthContext:
if credentials is None or credentials.scheme.lower() != "bearer":
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing bearer token")
jwt_service = JWTService()
try:
payload = jwt_service.decode(credentials.credentials, expected_type="access")
except JWTError as exc:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, str(exc)) from exc
user = await user_service.get_by_id(session, payload["sub"])
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "user not found")
ctx = AuthContext(user=user, payload=payload)
request.state.auth = ctx
return ctx

View File

@ -0,0 +1,7 @@
"""Shared slowapi limiter for login-facing routes."""
from __future__ import annotations
from slowapi import Limiter
from slowapi.util import get_remote_address
limiter = Limiter(key_func=get_remote_address)

View File

View File

@ -0,0 +1,130 @@
"""OAuth start/callback routes for Google and GitHub."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from sqlalchemy.ext.asyncio import AsyncSession
from ..config import get_settings
from ..database import get_session
from ..rate_limit import limiter
from ..services import github_oauth as github_provider
from ..services import google_oauth as google_provider
from ..services import user_service
from ..services.jwt_service import JWTService
router = APIRouter(prefix="/auth", tags=["oauth"])
def _set_refresh_cookie(response: RedirectResponse, token: str, max_age: int) -> None:
settings = get_settings()
response.set_cookie(
key=settings.refresh_cookie_name,
value=token,
max_age=max_age,
httponly=True,
secure=settings.cookie_secure,
samesite="lax",
domain=settings.cookie_domain,
path="/",
)
def _google_oauth(request: Request):
client = getattr(request.app.state, "google_oauth", None)
if client is None:
client = google_provider.build_google_client()
request.app.state.google_oauth = client
return client.google
def _github_oauth(request: Request):
client = getattr(request.app.state, "github_oauth", None)
if client is None:
client = github_provider.build_github_client()
request.app.state.github_oauth = client
return client.github
@router.get("/google")
@limiter.limit("10/minute")
async def google_start(request: Request):
settings = get_settings()
redirect_uri = f"{settings.auth_base_url.rstrip('/')}/auth/google/callback"
return await _google_oauth(request).authorize_redirect(request, redirect_uri)
@router.get("/google/callback")
@limiter.limit("10/minute")
async def google_callback(
request: Request,
session: AsyncSession = Depends(get_session),
):
oauth = _google_oauth(request)
try:
token = await oauth.authorize_access_token(request)
except Exception as exc:
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"google oauth failed: {exc}") from exc
userinfo = token.get("userinfo")
if not userinfo:
userinfo = await oauth.userinfo(token=token)
profile = google_provider.profile_from_userinfo(dict(userinfo))
user = await user_service.upsert_from_oauth(session, profile)
jwt_service = JWTService()
pair = jwt_service.issue_pair(user_id=str(user.id), email=user.email, name=user.name)
settings = get_settings()
redirect = RedirectResponse(
url=f"{settings.frontend_url.rstrip('/')}/chat?access_token={pair.access_token}",
status_code=status.HTTP_302_FOUND,
)
_set_refresh_cookie(redirect, pair.refresh_token, pair.refresh_expires_in)
return redirect
@router.get("/github")
@limiter.limit("10/minute")
async def github_start(request: Request):
settings = get_settings()
redirect_uri = f"{settings.auth_base_url.rstrip('/')}/auth/github/callback"
return await _github_oauth(request).authorize_redirect(request, redirect_uri)
@router.get("/github/callback")
@limiter.limit("10/minute")
async def github_callback(
request: Request,
session: AsyncSession = Depends(get_session),
):
oauth = _github_oauth(request)
try:
token = await oauth.authorize_access_token(request)
except Exception as exc:
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"github oauth failed: {exc}") from exc
user_resp = await oauth.get("user", token=token)
user_resp.raise_for_status()
userinfo = user_resp.json()
emails: list[dict] | None = None
if not userinfo.get("email"):
emails_resp = await oauth.get("user/emails", token=token)
if emails_resp.status_code == 200:
emails = emails_resp.json()
profile = github_provider.profile_from_userinfo(dict(userinfo), emails)
user = await user_service.upsert_from_oauth(session, profile)
jwt_service = JWTService()
pair = jwt_service.issue_pair(user_id=str(user.id), email=user.email, name=user.name)
settings = get_settings()
redirect = RedirectResponse(
url=f"{settings.frontend_url.rstrip('/')}/chat?access_token={pair.access_token}",
status_code=status.HTTP_302_FOUND,
)
_set_refresh_cookie(redirect, pair.refresh_token, pair.refresh_expires_in)
return redirect

View File

@ -0,0 +1,80 @@
"""Session/token refresh routes + internal JWT validation."""
from __future__ import annotations
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from ..config import get_settings
from ..database import get_session
from ..rate_limit import limiter
from ..services import user_service
from ..services.jwt_service import JWTError, JWTService
router = APIRouter(prefix="/auth", tags=["session"])
class RefreshResponse(BaseModel):
access_token: str
expires_in: int
class ValidateRequest(BaseModel):
token: str
@router.post("/refresh", response_model=RefreshResponse)
@limiter.limit("30/minute")
async def refresh(
request: Request,
session: AsyncSession = Depends(get_session),
):
settings = get_settings()
cookie = request.cookies.get(settings.refresh_cookie_name)
if not cookie:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing refresh cookie")
jwt_service = JWTService()
try:
payload = jwt_service.decode(cookie, expected_type="refresh")
except JWTError as exc:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, str(exc)) from exc
user = await user_service.get_by_id(session, payload["sub"])
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "user not found")
access, ttl = jwt_service.issue_access_token(
user_id=str(user.id), email=user.email, name=user.name
)
return RefreshResponse(access_token=access, expires_in=ttl)
@router.post("/validate")
async def validate(
payload: ValidateRequest,
session: AsyncSession = Depends(get_session),
x_internal_api_key: str | None = Header(default=None, alias="X-Internal-API-Key"),
):
settings = get_settings()
if not settings.internal_api_key or x_internal_api_key != settings.internal_api_key:
raise HTTPException(status.HTTP_403_FORBIDDEN, "invalid internal api key")
jwt_service = JWTService()
try:
claims = jwt_service.decode(payload.token, expected_type="access")
except JWTError as exc:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"valid": False, "reason": str(exc)},
)
user = await user_service.get_by_id(session, claims["sub"])
if user is None:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"valid": False, "reason": "user not found"},
)
return {"valid": True, "user": user.to_dict(), "claims": claims}

View File

@ -0,0 +1,46 @@
"""User profile routes (GET /auth/me, PUT /auth/me)."""
from __future__ import annotations
from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from ..database import get_session
from ..middleware.auth_middleware import AuthContext, require_user
from ..services import user_service
router = APIRouter(prefix="/auth", tags=["users"])
class UserProfile(BaseModel):
id: str
email: str
name: str | None
avatar_url: str | None
provider: str
provider_id: str
created_at: str | None
updated_at: str | None
last_login_at: str | None
class ProfileUpdate(BaseModel):
name: str | None = Field(default=None, max_length=255)
avatar_url: str | None = Field(default=None, max_length=2048)
@router.get("/me", response_model=UserProfile)
async def me(ctx: AuthContext = Depends(require_user)) -> UserProfile:
return UserProfile(**ctx.user.to_dict())
@router.put("/me", response_model=UserProfile)
async def update_me(
payload: ProfileUpdate,
ctx: AuthContext = Depends(require_user),
session: AsyncSession = Depends(get_session),
) -> UserProfile:
user = await user_service.update_profile(
session, ctx.user, name=payload.name, avatar_url=payload.avatar_url
)
return UserProfile(**user.to_dict())

View File

View File

@ -0,0 +1,56 @@
"""GitHub OAuth provider via authlib.
Authorization URL: https://github.com/login/oauth/authorize
Token URL: https://github.com/login/oauth/access_token
User API: https://api.github.com/user
"""
from __future__ import annotations
from typing import Any
from authlib.integrations.starlette_client import OAuth
from ..config import Settings, get_settings
from .google_oauth import OAuthProfile
AUTHORIZE_URL = "https://github.com/login/oauth/authorize"
TOKEN_URL = "https://github.com/login/oauth/access_token"
USERINFO_URL = "https://api.github.com/user"
EMAILS_URL = "https://api.github.com/user/emails"
def build_github_client(settings: Settings | None = None) -> OAuth:
settings = settings or get_settings()
oauth = OAuth()
oauth.register(
name="github",
client_id=settings.github_client_id,
client_secret=settings.github_client_secret,
access_token_url=TOKEN_URL,
authorize_url=AUTHORIZE_URL,
api_base_url="https://api.github.com/",
client_kwargs={"scope": "read:user user:email"},
)
return oauth
def profile_from_userinfo(userinfo: dict[str, Any], emails: list[dict[str, Any]] | None) -> OAuthProfile:
provider_id = userinfo.get("id")
if provider_id is None:
raise ValueError("github userinfo missing id")
email = userinfo.get("email")
if not email and emails:
primary = next((e for e in emails if e.get("primary") and e.get("verified")), None)
if primary:
email = primary.get("email")
if not email:
raise ValueError("github userinfo missing verified email")
return OAuthProfile(
provider="github",
provider_id=str(provider_id),
email=str(email),
name=userinfo.get("name") or userinfo.get("login"),
avatar_url=userinfo.get("avatar_url"),
)

View File

@ -0,0 +1,50 @@
"""Google OAuth provider via authlib.
Discovery URL: https://accounts.google.com/.well-known/openid-configuration
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from authlib.integrations.starlette_client import OAuth
from ..config import Settings, get_settings
DISCOVERY_URL = "https://accounts.google.com/.well-known/openid-configuration"
@dataclass
class OAuthProfile:
provider: str
provider_id: str
email: str
name: str | None
avatar_url: str | None
def build_google_client(settings: Settings | None = None) -> OAuth:
settings = settings or get_settings()
oauth = OAuth()
oauth.register(
name="google",
client_id=settings.google_client_id,
client_secret=settings.google_client_secret,
server_metadata_url=DISCOVERY_URL,
client_kwargs={"scope": "openid email profile"},
)
return oauth
def profile_from_userinfo(userinfo: dict[str, Any]) -> OAuthProfile:
provider_id = userinfo.get("sub") or userinfo.get("id")
email = userinfo.get("email")
if not provider_id or not email:
raise ValueError("google userinfo missing sub/email")
return OAuthProfile(
provider="google",
provider_id=str(provider_id),
email=str(email),
name=userinfo.get("name"),
avatar_url=userinfo.get("picture"),
)

View File

@ -0,0 +1,87 @@
"""RS256 JWT issuance + validation.
Access tokens (1h) are returned to the client for Bearer auth.
Refresh tokens (7d) are stored in an httpOnly cookie.
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any
import jwt
from ..config import Settings, get_settings
class JWTError(Exception):
"""Raised when a token fails validation."""
@dataclass
class TokenPair:
access_token: str
refresh_token: str
access_expires_in: int
refresh_expires_in: int
class JWTService:
def __init__(self, settings: Settings | None = None) -> None:
self._settings = settings or get_settings()
def _now(self) -> datetime:
return datetime.now(timezone.utc)
def issue_access_token(self, *, user_id: str, email: str, name: str | None) -> tuple[str, int]:
now = self._now()
exp = now + timedelta(seconds=self._settings.jwt_access_ttl_seconds)
payload = {
"sub": user_id,
"email": email,
"name": name,
"iat": int(now.timestamp()),
"exp": int(exp.timestamp()),
"iss": self._settings.jwt_issuer,
"type": "access",
}
token = jwt.encode(payload, self._settings.jwt_private_key, algorithm="RS256")
return token, self._settings.jwt_access_ttl_seconds
def issue_refresh_token(self, *, user_id: str) -> tuple[str, int]:
now = self._now()
exp = now + timedelta(seconds=self._settings.jwt_refresh_ttl_seconds)
payload = {
"sub": user_id,
"iat": int(now.timestamp()),
"exp": int(exp.timestamp()),
"iss": self._settings.jwt_issuer,
"type": "refresh",
"jti": uuid.uuid4().hex,
}
token = jwt.encode(payload, self._settings.jwt_private_key, algorithm="RS256")
return token, self._settings.jwt_refresh_ttl_seconds
def issue_pair(self, *, user_id: str, email: str, name: str | None) -> TokenPair:
access, access_ttl = self.issue_access_token(user_id=user_id, email=email, name=name)
refresh, refresh_ttl = self.issue_refresh_token(user_id=user_id)
return TokenPair(access, refresh, access_ttl, refresh_ttl)
def decode(self, token: str, *, expected_type: str | None = None) -> dict[str, Any]:
try:
payload = jwt.decode(
token,
self._settings.jwt_public_key,
algorithms=["RS256"],
issuer=self._settings.jwt_issuer,
options={"require": ["exp", "iat", "sub", "iss"]},
)
except jwt.ExpiredSignatureError as exc:
raise JWTError("token expired") from exc
except jwt.InvalidTokenError as exc:
raise JWTError(f"invalid token: {exc}") from exc
if expected_type and payload.get("type") != expected_type:
raise JWTError(f"expected {expected_type} token, got {payload.get('type')!r}")
return payload

View File

@ -0,0 +1,75 @@
"""User persistence helpers (upsert on OAuth callback, fetch by id)."""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from ..models.user import User
from .google_oauth import OAuthProfile
async def upsert_from_oauth(session: AsyncSession, profile: OAuthProfile) -> User:
"""Insert a new user, or update last_login_at on an existing one."""
stmt = select(User).where(
User.provider == profile.provider,
User.provider_id == profile.provider_id,
)
existing = (await session.execute(stmt)).scalar_one_or_none()
now = datetime.now(timezone.utc)
if existing is None:
user = User(
email=profile.email,
name=profile.name,
avatar_url=profile.avatar_url,
provider=profile.provider,
provider_id=profile.provider_id,
created_at=now,
updated_at=now,
last_login_at=now,
)
session.add(user)
await session.commit()
await session.refresh(user)
return user
existing.last_login_at = now
existing.updated_at = now
if profile.email:
existing.email = profile.email
if profile.name is not None:
existing.name = profile.name
if profile.avatar_url is not None:
existing.avatar_url = profile.avatar_url
await session.commit()
await session.refresh(existing)
return existing
async def get_by_id(session: AsyncSession, user_id: str | uuid.UUID) -> User | None:
if isinstance(user_id, str):
try:
user_id = uuid.UUID(user_id)
except ValueError:
return None
return await session.get(User, user_id)
async def update_profile(
session: AsyncSession,
user: User,
*,
name: str | None,
avatar_url: str | None,
) -> User:
if name is not None:
user.name = name
if avatar_url is not None:
user.avatar_url = avatar_url
user.updated_at = datetime.now(timezone.utc)
await session.commit()
await session.refresh(user)
return user

View File

View File

@ -0,0 +1,84 @@
"""Shared pytest fixtures: RSA keys, in-memory DB, FastAPI test client."""
from __future__ import annotations
import os
import uuid
from collections.abc import AsyncIterator
import pytest
import pytest_asyncio
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
def _make_rsa_pair() -> tuple[str, str]:
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
private = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
).decode()
public = key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
).decode()
return private, public
@pytest.fixture(scope="session", autouse=True)
def _test_environment():
private, public = _make_rsa_pair()
os.environ["JWT_PRIVATE_KEY"] = private
os.environ["JWT_PUBLIC_KEY"] = public
os.environ["INTERNAL_API_KEY"] = "test-internal-key"
os.environ["FRONTEND_URL"] = "http://localhost:3000"
os.environ["AUTH_BASE_URL"] = "http://localhost:8001"
os.environ["SESSION_SECRET"] = "test-session-secret"
os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
os.environ["GOOGLE_CLIENT_ID"] = "g-id"
os.environ["GOOGLE_CLIENT_SECRET"] = "g-secret"
os.environ["GITHUB_CLIENT_ID"] = "gh-id"
os.environ["GITHUB_CLIENT_SECRET"] = "gh-secret"
from src.config import get_settings
get_settings.cache_clear()
yield
@pytest_asyncio.fixture
async def session_factory() -> AsyncIterator[async_sessionmaker[AsyncSession]]:
from src.database import Base, override_session_factory
from src.models.user import User # noqa: F401 — register on Base
engine = create_async_engine(
f"sqlite+aiosqlite:///file:test_{uuid.uuid4().hex}?mode=memory&cache=shared&uri=true",
connect_args={"uri": True},
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
factory = async_sessionmaker(engine, expire_on_commit=False)
override_session_factory(factory)
try:
yield factory
finally:
await engine.dispose()
@pytest_asyncio.fixture
async def db_session(session_factory) -> AsyncIterator[AsyncSession]:
async with session_factory() as session:
yield session
@pytest_asyncio.fixture
async def client(session_factory) -> AsyncIterator[AsyncClient]:
from src.main import create_app
app = create_app()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
yield ac

View File

@ -0,0 +1,59 @@
"""GET /auth/me and PUT /auth/me via Bearer JWT."""
from __future__ import annotations
import pytest
from src.services import user_service
from src.services.google_oauth import OAuthProfile
from src.services.jwt_service import JWTService
@pytest.mark.asyncio
async def test_me_returns_profile(client, db_session):
profile = OAuthProfile(
provider="github", provider_id="gh-1", email="me@x.co", name="Me", avatar_url=None
)
user = await user_service.upsert_from_oauth(db_session, profile)
token, _ = JWTService().issue_access_token(
user_id=str(user.id), email=user.email, name=user.name
)
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == 200
body = resp.json()
assert body["email"] == "me@x.co"
assert body["provider"] == "github"
@pytest.mark.asyncio
async def test_me_requires_bearer(client):
resp = await client.get("/auth/me")
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_update_me(client, db_session):
profile = OAuthProfile(
provider="google", provider_id="g-9", email="x@y.co", name="Old", avatar_url=None
)
user = await user_service.upsert_from_oauth(db_session, profile)
token, _ = JWTService().issue_access_token(
user_id=str(user.id), email=user.email, name=user.name
)
resp = await client.put(
"/auth/me",
json={"name": "New Name", "avatar_url": "https://img/x.png"},
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
body = resp.json()
assert body["name"] == "New Name"
assert body["avatar_url"] == "https://img/x.png"
@pytest.mark.asyncio
async def test_health(client):
resp = await client.get("/auth/health")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}

View File

@ -0,0 +1,97 @@
"""OAuth callback end-to-end with mocked authlib providers."""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from src.services import user_service
from src.services.google_oauth import OAuthProfile
class _MockGoogleClient:
"""Stands in for authlib's StarletteOAuth2App."""
def __init__(self, userinfo):
self._userinfo = userinfo
async def authorize_access_token(self, request):
return {"access_token": "fake", "userinfo": self._userinfo}
async def userinfo(self, token=None):
return self._userinfo
@pytest.mark.asyncio
async def test_google_callback_creates_user_and_sets_refresh_cookie(client, db_session):
userinfo = {
"sub": "google-42",
"email": "new@user.co",
"name": "New User",
"picture": "https://img/new.png",
}
app = client._transport.app # type: ignore[attr-defined]
app.state.google_oauth = SimpleNamespace(google=_MockGoogleClient(userinfo))
resp = await client.get("/auth/google/callback", follow_redirects=False)
assert resp.status_code == 302
assert "access_token=" in resp.headers["location"]
from src.config import get_settings
cookie_name = get_settings().refresh_cookie_name
assert cookie_name in resp.cookies
# Verify user was persisted.
from sqlalchemy import select
from src.models.user import User
async with db_session.bind._async_engine.connect() if False else db_session as s: # type: ignore[attr-defined]
pass
# Simpler: just query through a fresh session.
from src.database import get_session_factory
async with get_session_factory()() as s:
user = (
await s.execute(select(User).where(User.provider_id == "google-42"))
).scalar_one()
assert user.email == "new@user.co"
@pytest.mark.asyncio
async def test_google_callback_updates_existing(client, db_session):
# Seed an existing user.
profile = OAuthProfile(
provider="google",
provider_id="google-99",
email="old@user.co",
name="Old",
avatar_url=None,
)
existing = await user_service.upsert_from_oauth(db_session, profile)
original_login = existing.last_login_at
userinfo = {
"sub": "google-99",
"email": "old@user.co",
"name": "Updated Name",
"picture": "https://img/u.png",
}
app = client._transport.app # type: ignore[attr-defined]
app.state.google_oauth = SimpleNamespace(google=_MockGoogleClient(userinfo))
resp = await client.get("/auth/google/callback", follow_redirects=False)
assert resp.status_code == 302
from sqlalchemy import select
from src.database import get_session_factory
from src.models.user import User
async with get_session_factory()() as s:
refreshed = (
await s.execute(select(User).where(User.id == existing.id))
).scalar_one()
assert refreshed.name == "Updated Name"
assert refreshed.last_login_at >= original_login

View File

@ -0,0 +1,29 @@
"""Rate limiter applies 10/min to OAuth start routes."""
from __future__ import annotations
import pytest
@pytest.mark.asyncio
async def test_google_start_rate_limited(client, monkeypatch):
# Replace the OAuth client with a stub so /auth/google returns immediately.
async def _stub_redirect(request, redirect_uri):
from fastapi.responses import RedirectResponse
return RedirectResponse(url=redirect_uri, status_code=302)
class _StubProvider:
authorize_redirect = staticmethod(_stub_redirect)
class _StubClient:
google = _StubProvider()
app = client._transport.app # type: ignore[attr-defined]
app.state.google_oauth = _StubClient()
# First 10 calls allowed, 11th should be rate-limited.
codes = []
for _ in range(11):
resp = await client.get("/auth/google", follow_redirects=False)
codes.append(resp.status_code)
assert codes.count(429) >= 1

View File

@ -0,0 +1,68 @@
"""User upsert flow invoked by OAuth callbacks."""
from __future__ import annotations
import pytest
from src.services import user_service
from src.services.google_oauth import OAuthProfile
@pytest.mark.asyncio
async def test_upsert_creates_new_user(db_session):
profile = OAuthProfile(
provider="google",
provider_id="g-12345",
email="alice@example.com",
name="Alice",
avatar_url="https://img/alice.png",
)
user = await user_service.upsert_from_oauth(db_session, profile)
assert user.id is not None
assert user.email == "alice@example.com"
assert user.provider == "google"
assert user.last_login_at is not None
@pytest.mark.asyncio
async def test_upsert_updates_existing_user(db_session):
first = OAuthProfile(
provider="google",
provider_id="g-12345",
email="alice@example.com",
name="Alice",
avatar_url=None,
)
u1 = await user_service.upsert_from_oauth(db_session, first)
original_login = u1.last_login_at
# Second login with updated display name + avatar — same provider identity.
second = OAuthProfile(
provider="google",
provider_id="g-12345",
email="alice@example.com",
name="Alice Smith",
avatar_url="https://img/alice2.png",
)
u2 = await user_service.upsert_from_oauth(db_session, second)
assert u2.id == u1.id # same row
assert u2.name == "Alice Smith"
assert u2.avatar_url == "https://img/alice2.png"
assert u2.last_login_at >= original_login
@pytest.mark.asyncio
async def test_update_profile(db_session):
profile = OAuthProfile(
provider="github",
provider_id="42",
email="bob@example.com",
name="Bob",
avatar_url=None,
)
user = await user_service.upsert_from_oauth(db_session, profile)
updated = await user_service.update_profile(
db_session, user, name="Robert", avatar_url="https://img/bob.png"
)
assert updated.name == "Robert"
assert updated.avatar_url == "https://img/bob.png"

View File

@ -0,0 +1,62 @@
"""Internal /auth/validate endpoint used by the Chat API service."""
from __future__ import annotations
import pytest
from src.services import user_service
from src.services.google_oauth import OAuthProfile
from src.services.jwt_service import JWTService
@pytest.mark.asyncio
async def test_validate_requires_internal_key(client, db_session):
profile = OAuthProfile(
provider="google", provider_id="123", email="v@x.co", name="V", avatar_url=None
)
user = await user_service.upsert_from_oauth(db_session, profile)
token, _ = JWTService().issue_access_token(
user_id=str(user.id), email=user.email, name=user.name
)
missing = await client.post("/auth/validate", json={"token": token})
assert missing.status_code == 403
wrong = await client.post(
"/auth/validate",
json={"token": token},
headers={"X-Internal-API-Key": "nope"},
)
assert wrong.status_code == 403
@pytest.mark.asyncio
async def test_validate_returns_user_for_valid_token(client, db_session):
profile = OAuthProfile(
provider="google", provider_id="456", email="v2@x.co", name="V2", avatar_url=None
)
user = await user_service.upsert_from_oauth(db_session, profile)
token, _ = JWTService().issue_access_token(
user_id=str(user.id), email=user.email, name=user.name
)
resp = await client.post(
"/auth/validate",
json={"token": token},
headers={"X-Internal-API-Key": "test-internal-key"},
)
assert resp.status_code == 200
body = resp.json()
assert body["valid"] is True
assert body["user"]["email"] == "v2@x.co"
assert body["claims"]["sub"] == str(user.id)
@pytest.mark.asyncio
async def test_validate_rejects_tampered_token(client):
resp = await client.post(
"/auth/validate",
json={"token": "not-a-jwt"},
headers={"X-Internal-API-Key": "test-internal-key"},
)
assert resp.status_code == 401
assert resp.json()["valid"] is False