mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-09 01:10:10 +00:00
Merge pull request #16 from manmohan659/feat/auth-service
feat(auth): OAuth2 + JWT auth service with Alembic migrations (#5 #7)
This commit is contained in:
commit
4297817cfb
41
db/alembic.ini
Normal file
41
db/alembic.ini
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
[alembic]
|
||||
script_location = %(here)s/migrations
|
||||
prepend_sys_path = .
|
||||
version_path_separator = os
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
CREATE EXTENSION IF NOT EXISTS pgcrypto;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
avatar_url TEXT,
|
||||
provider TEXT NOT NULL,
|
||||
provider_id TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
last_login_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS users_provider_lookup_idx
|
||||
ON users (provider, provider_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
model_tag TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS conversations_user_id_idx
|
||||
ON conversations (user_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
token_count INTEGER NOT NULL DEFAULT 0,
|
||||
model_tag TEXT NOT NULL,
|
||||
inference_time_ms INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS messages_conversation_id_idx
|
||||
ON messages (conversation_id, created_at);
|
||||
67
db/migrations/env.py
Normal file
67
db/migrations/env.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
"""Alembic environment configuration for async PostgreSQL."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
config = context.config
|
||||
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
target_metadata = None
|
||||
|
||||
|
||||
def _database_url() -> str:
|
||||
url = os.environ.get("DATABASE_URL")
|
||||
if not url:
|
||||
raise RuntimeError("DATABASE_URL environment variable is required for Alembic")
|
||||
if url.startswith("postgresql://"):
|
||||
url = url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
return url
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
context.configure(
|
||||
url=_database_url(),
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
configuration = config.get_section(config.config_ini_section) or {}
|
||||
configuration["sqlalchemy.url"] = _database_url()
|
||||
connectable = async_engine_from_config(
|
||||
configuration,
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
27
db/migrations/script.py.mako
Normal file
27
db/migrations/script.py.mako
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
57
db/migrations/versions/001_create_users.py
Normal file
57
db/migrations/versions/001_create_users.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
"""create users table
|
||||
|
||||
Revision ID: 001_create_users
|
||||
Revises:
|
||||
Create Date: 2026-04-16
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "001_create_users"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
),
|
||||
sa.Column("email", sa.String(length=255), nullable=False, unique=True),
|
||||
sa.Column("name", sa.String(length=255), nullable=True),
|
||||
sa.Column("avatar_url", sa.Text(), nullable=True),
|
||||
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||
sa.Column("provider_id", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("NOW()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("NOW()"),
|
||||
),
|
||||
sa.Column(
|
||||
"last_login_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("NOW()"),
|
||||
),
|
||||
sa.UniqueConstraint("provider", "provider_id", name="uq_users_provider_identity"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("users")
|
||||
64
db/migrations/versions/002_create_conversations.py
Normal file
64
db/migrations/versions/002_create_conversations.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"""create conversations table
|
||||
|
||||
Revision ID: 002_create_conversations
|
||||
Revises: 001_create_users
|
||||
Create Date: 2026-04-16
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "002_create_conversations"
|
||||
down_revision: Union[str, None] = "001_create_users"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"conversations",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("title", sa.String(length=500), nullable=True),
|
||||
sa.Column(
|
||||
"model_tag",
|
||||
sa.String(length=100),
|
||||
nullable=True,
|
||||
server_default=sa.text("'default'"),
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("NOW()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("NOW()"),
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_conversations_user",
|
||||
"conversations",
|
||||
["user_id", sa.text("updated_at DESC")],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("idx_conversations_user", table_name="conversations")
|
||||
op.drop_table("conversations")
|
||||
61
db/migrations/versions/003_create_messages.py
Normal file
61
db/migrations/versions/003_create_messages.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
"""create messages table
|
||||
|
||||
Revision ID: 003_create_messages
|
||||
Revises: 002_create_conversations
|
||||
Create Date: 2026-04-16
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "003_create_messages"
|
||||
down_revision: Union[str, None] = "002_create_conversations"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"messages",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
),
|
||||
sa.Column(
|
||||
"conversation_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("conversations.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("role", sa.String(length=20), nullable=False),
|
||||
sa.Column("content", sa.Text(), nullable=False),
|
||||
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||
sa.Column("model_tag", sa.String(length=100), nullable=True),
|
||||
sa.Column("inference_time_ms", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("NOW()"),
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"role IN ('user','assistant','system')",
|
||||
name="ck_messages_role",
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_messages_conversation",
|
||||
"messages",
|
||||
["conversation_id", sa.text("created_at ASC")],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("idx_messages_conversation", table_name="messages")
|
||||
op.drop_table("messages")
|
||||
34
db/migrations/versions/004_add_favorited.py
Normal file
34
db/migrations/versions/004_add_favorited.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""add is_favorited column to conversations (Day 2 demo)
|
||||
|
||||
Revision ID: 004_add_favorited
|
||||
Revises: 003_create_messages
|
||||
Create Date: 2026-04-16
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "004_add_favorited"
|
||||
down_revision: Union[str, None] = "003_create_messages"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"conversations",
|
||||
sa.Column(
|
||||
"is_favorited",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("conversations", "is_favorited")
|
||||
|
|
@ -1,9 +1,38 @@
|
|||
FROM python:3.12-slim
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
UV_SYSTEM_PYTHON=1 \
|
||||
UV_LINK_MODE=copy
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends build-essential curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install --no-cache-dir uv==0.4.30
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY README.md /app/README.md
|
||||
COPY pyproject.toml README.md /app/
|
||||
|
||||
RUN uv pip install --system --no-cache \
|
||||
"fastapi>=0.117.1" \
|
||||
"uvicorn[standard]>=0.36.0" \
|
||||
"pydantic>=2.8.0" \
|
||||
"pydantic-settings>=2.4.0" \
|
||||
"sqlalchemy[asyncio]>=2.0.36" \
|
||||
"asyncpg>=0.29.0" \
|
||||
"alembic>=1.13.0" \
|
||||
"authlib>=1.3.2" \
|
||||
"httpx>=0.27.0" \
|
||||
"itsdangerous>=2.2.0" \
|
||||
"pyjwt>=2.9.0" \
|
||||
"cryptography>=43.0.0" \
|
||||
"slowapi>=0.1.9" \
|
||||
"python-multipart>=0.0.9"
|
||||
|
||||
COPY src /app/src
|
||||
|
||||
EXPOSE 8001
|
||||
|
||||
CMD ["python", "-m", "http.server", "8001", "--bind", "0.0.0.0"]
|
||||
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8001"]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,46 @@
|
|||
# Auth Service
|
||||
# samosaChaat Auth Service
|
||||
|
||||
Scaffold placeholder for Issue #5.
|
||||
FastAPI microservice providing OAuth2 login (Google + GitHub) and JWT session
|
||||
management for samosaChaat (Issue #5).
|
||||
|
||||
The monorepo branch provisions this directory and a minimal Docker image so
|
||||
local `docker compose up` remains viable before the real auth service
|
||||
implementation lands.
|
||||
## Endpoints
|
||||
|
||||
| Method | Path | Purpose |
|
||||
| ------ | ---- | ------- |
|
||||
| GET | `/auth/google` | Redirect to Google consent |
|
||||
| GET | `/auth/google/callback` | Complete Google flow, upsert user, issue tokens |
|
||||
| GET | `/auth/github` | Redirect to GitHub consent |
|
||||
| GET | `/auth/github/callback` | Complete GitHub flow, upsert user, issue tokens |
|
||||
| POST | `/auth/refresh` | Exchange refresh cookie for new access token |
|
||||
| GET | `/auth/me` | Current user profile (Bearer JWT) |
|
||||
| PUT | `/auth/me` | Update name / avatar (Bearer JWT) |
|
||||
| POST | `/auth/validate` | Internal JWT validation (service-to-service) |
|
||||
| GET | `/auth/health` | Liveness probe |
|
||||
|
||||
## Environment
|
||||
|
||||
```
|
||||
DATABASE_URL=postgresql+asyncpg://user:pass@host/db
|
||||
GOOGLE_CLIENT_ID=...
|
||||
GOOGLE_CLIENT_SECRET=...
|
||||
GITHUB_CLIENT_ID=...
|
||||
GITHUB_CLIENT_SECRET=...
|
||||
JWT_PRIVATE_KEY=<RS256 PEM>
|
||||
JWT_PUBLIC_KEY=<RS256 PEM>
|
||||
FRONTEND_URL=http://localhost:3000
|
||||
INTERNAL_API_KEY=<shared secret for /auth/validate>
|
||||
```
|
||||
|
||||
## Local development
|
||||
|
||||
```
|
||||
uv sync
|
||||
uv run uvicorn src.main:app --reload --port 8001
|
||||
uv run pytest
|
||||
```
|
||||
|
||||
Database schema is managed by Alembic at `db/migrations`:
|
||||
|
||||
```
|
||||
DATABASE_URL=... uv run alembic -c db/alembic.ini upgrade head
|
||||
```
|
||||
|
|
|
|||
40
services/auth/pyproject.toml
Normal file
40
services/auth/pyproject.toml
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
[project]
|
||||
name = "samosachaat-auth"
|
||||
version = "0.1.0"
|
||||
description = "samosaChaat authentication service (OAuth2 + JWT)"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi>=0.117.1",
|
||||
"uvicorn[standard]>=0.36.0",
|
||||
"pydantic>=2.8.0",
|
||||
"pydantic-settings>=2.4.0",
|
||||
"sqlalchemy[asyncio]>=2.0.36",
|
||||
"asyncpg>=0.29.0",
|
||||
"alembic>=1.13.0",
|
||||
"authlib>=1.3.2",
|
||||
"httpx>=0.27.0",
|
||||
"itsdangerous>=2.2.0",
|
||||
"pyjwt>=2.9.0",
|
||||
"cryptography>=43.0.0",
|
||||
"slowapi>=0.1.9",
|
||||
"python-multipart>=0.0.9",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"httpx>=0.27.0",
|
||||
"aiosqlite>=0.20.0",
|
||||
"respx>=0.21.1",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["src/tests"]
|
||||
python_files = ["test_*.py"]
|
||||
pythonpath = ["."]
|
||||
|
||||
[tool.uv]
|
||||
package = false
|
||||
0
services/auth/src/__init__.py
Normal file
0
services/auth/src/__init__.py
Normal file
47
services/auth/src/config.py
Normal file
47
services/auth/src/config.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
"""Runtime configuration for the auth service.
|
||||
|
||||
All configuration is loaded from environment variables using pydantic-settings.
|
||||
Private/public keys are PEM-encoded RSA material used for RS256 JWTs.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
|
||||
|
||||
database_url: str = Field(default="postgresql+asyncpg://localhost/samosachaat")
|
||||
|
||||
google_client_id: str = Field(default="")
|
||||
google_client_secret: str = Field(default="")
|
||||
|
||||
github_client_id: str = Field(default="")
|
||||
github_client_secret: str = Field(default="")
|
||||
|
||||
jwt_private_key: str = Field(default="")
|
||||
jwt_public_key: str = Field(default="")
|
||||
jwt_issuer: str = Field(default="samosachaat-auth")
|
||||
jwt_access_ttl_seconds: int = Field(default=3600)
|
||||
jwt_refresh_ttl_seconds: int = Field(default=7 * 24 * 3600)
|
||||
|
||||
frontend_url: str = Field(default="http://localhost:3000")
|
||||
internal_api_key: str = Field(default="")
|
||||
|
||||
auth_base_url: str = Field(default="http://localhost:8001")
|
||||
session_secret: str = Field(default="dev-session-secret-change-me")
|
||||
|
||||
cookie_secure: bool = Field(default=False)
|
||||
cookie_domain: str | None = Field(default=None)
|
||||
|
||||
@property
|
||||
def refresh_cookie_name(self) -> str:
|
||||
return "samosachaat_refresh"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
49
services/auth/src/database.py
Normal file
49
services/auth/src/database.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
"""Async SQLAlchemy engine + session factory."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from .config import get_settings
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Shared declarative base for all ORM models."""
|
||||
|
||||
|
||||
_engine = None
|
||||
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
def _build_engine():
|
||||
global _engine, _session_factory
|
||||
settings = get_settings()
|
||||
_engine = create_async_engine(settings.database_url, pool_pre_ping=True)
|
||||
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
|
||||
|
||||
|
||||
def get_engine():
|
||||
if _engine is None:
|
||||
_build_engine()
|
||||
return _engine
|
||||
|
||||
|
||||
def get_session_factory() -> async_sessionmaker[AsyncSession]:
|
||||
if _session_factory is None:
|
||||
_build_engine()
|
||||
assert _session_factory is not None
|
||||
return _session_factory
|
||||
|
||||
|
||||
async def get_session() -> AsyncIterator[AsyncSession]:
|
||||
factory = get_session_factory()
|
||||
async with factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
def override_session_factory(factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
"""Testing hook: swap the session factory for an in-memory engine."""
|
||||
global _session_factory
|
||||
_session_factory = factory
|
||||
50
services/auth/src/main.py
Normal file
50
services/auth/src/main.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""FastAPI entrypoint for the samosaChaat auth service."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.middleware import SlowAPIMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from .config import get_settings
|
||||
from .rate_limit import limiter
|
||||
from .routes import oauth, session, users
|
||||
|
||||
|
||||
def _rate_limit_handler(request, exc: RateLimitExceeded):
|
||||
return JSONResponse(status_code=429, content={"detail": "rate limit exceeded"})
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
settings = get_settings()
|
||||
app = FastAPI(title="samosaChaat Auth", version="0.1.0")
|
||||
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_handler)
|
||||
app.add_middleware(SlowAPIMiddleware)
|
||||
|
||||
# SessionMiddleware is required by authlib for the OAuth state cookie.
|
||||
app.add_middleware(SessionMiddleware, secret_key=settings.session_secret)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[settings.frontend_url],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(oauth.router)
|
||||
app.include_router(session.router)
|
||||
app.include_router(users.router)
|
||||
|
||||
@app.get("/auth/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
0
services/auth/src/middleware/__init__.py
Normal file
0
services/auth/src/middleware/__init__.py
Normal file
44
services/auth/src/middleware/auth_middleware.py
Normal file
44
services/auth/src/middleware/auth_middleware.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""Bearer-token auth dependency."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..database import get_session
|
||||
from ..models.user import User
|
||||
from ..services import user_service
|
||||
from ..services.jwt_service import JWTError, JWTService
|
||||
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthContext:
|
||||
user: User
|
||||
payload: dict
|
||||
|
||||
|
||||
async def require_user(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AuthContext:
|
||||
if credentials is None or credentials.scheme.lower() != "bearer":
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing bearer token")
|
||||
|
||||
jwt_service = JWTService()
|
||||
try:
|
||||
payload = jwt_service.decode(credentials.credentials, expected_type="access")
|
||||
except JWTError as exc:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, str(exc)) from exc
|
||||
|
||||
user = await user_service.get_by_id(session, payload["sub"])
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "user not found")
|
||||
|
||||
ctx = AuthContext(user=user, payload=payload)
|
||||
request.state.auth = ctx
|
||||
return ctx
|
||||
7
services/auth/src/rate_limit.py
Normal file
7
services/auth/src/rate_limit.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""Shared slowapi limiter for login-facing routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
0
services/auth/src/routes/__init__.py
Normal file
0
services/auth/src/routes/__init__.py
Normal file
130
services/auth/src/routes/oauth.py
Normal file
130
services/auth/src/routes/oauth.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
"""OAuth start/callback routes for Google and GitHub."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..config import get_settings
|
||||
from ..database import get_session
|
||||
from ..rate_limit import limiter
|
||||
from ..services import github_oauth as github_provider
|
||||
from ..services import google_oauth as google_provider
|
||||
from ..services import user_service
|
||||
from ..services.jwt_service import JWTService
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["oauth"])
|
||||
|
||||
|
||||
def _set_refresh_cookie(response: RedirectResponse, token: str, max_age: int) -> None:
|
||||
settings = get_settings()
|
||||
response.set_cookie(
|
||||
key=settings.refresh_cookie_name,
|
||||
value=token,
|
||||
max_age=max_age,
|
||||
httponly=True,
|
||||
secure=settings.cookie_secure,
|
||||
samesite="lax",
|
||||
domain=settings.cookie_domain,
|
||||
path="/",
|
||||
)
|
||||
|
||||
|
||||
def _google_oauth(request: Request):
|
||||
client = getattr(request.app.state, "google_oauth", None)
|
||||
if client is None:
|
||||
client = google_provider.build_google_client()
|
||||
request.app.state.google_oauth = client
|
||||
return client.google
|
||||
|
||||
|
||||
def _github_oauth(request: Request):
|
||||
client = getattr(request.app.state, "github_oauth", None)
|
||||
if client is None:
|
||||
client = github_provider.build_github_client()
|
||||
request.app.state.github_oauth = client
|
||||
return client.github
|
||||
|
||||
|
||||
@router.get("/google")
|
||||
@limiter.limit("10/minute")
|
||||
async def google_start(request: Request):
|
||||
settings = get_settings()
|
||||
redirect_uri = f"{settings.auth_base_url.rstrip('/')}/auth/google/callback"
|
||||
return await _google_oauth(request).authorize_redirect(request, redirect_uri)
|
||||
|
||||
|
||||
@router.get("/google/callback")
|
||||
@limiter.limit("10/minute")
|
||||
async def google_callback(
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
oauth = _google_oauth(request)
|
||||
try:
|
||||
token = await oauth.authorize_access_token(request)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"google oauth failed: {exc}") from exc
|
||||
|
||||
userinfo = token.get("userinfo")
|
||||
if not userinfo:
|
||||
userinfo = await oauth.userinfo(token=token)
|
||||
|
||||
profile = google_provider.profile_from_userinfo(dict(userinfo))
|
||||
user = await user_service.upsert_from_oauth(session, profile)
|
||||
|
||||
jwt_service = JWTService()
|
||||
pair = jwt_service.issue_pair(user_id=str(user.id), email=user.email, name=user.name)
|
||||
|
||||
settings = get_settings()
|
||||
redirect = RedirectResponse(
|
||||
url=f"{settings.frontend_url.rstrip('/')}/chat?access_token={pair.access_token}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
_set_refresh_cookie(redirect, pair.refresh_token, pair.refresh_expires_in)
|
||||
return redirect
|
||||
|
||||
|
||||
@router.get("/github")
|
||||
@limiter.limit("10/minute")
|
||||
async def github_start(request: Request):
|
||||
settings = get_settings()
|
||||
redirect_uri = f"{settings.auth_base_url.rstrip('/')}/auth/github/callback"
|
||||
return await _github_oauth(request).authorize_redirect(request, redirect_uri)
|
||||
|
||||
|
||||
@router.get("/github/callback")
|
||||
@limiter.limit("10/minute")
|
||||
async def github_callback(
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
oauth = _github_oauth(request)
|
||||
try:
|
||||
token = await oauth.authorize_access_token(request)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"github oauth failed: {exc}") from exc
|
||||
|
||||
user_resp = await oauth.get("user", token=token)
|
||||
user_resp.raise_for_status()
|
||||
userinfo = user_resp.json()
|
||||
|
||||
emails: list[dict] | None = None
|
||||
if not userinfo.get("email"):
|
||||
emails_resp = await oauth.get("user/emails", token=token)
|
||||
if emails_resp.status_code == 200:
|
||||
emails = emails_resp.json()
|
||||
|
||||
profile = github_provider.profile_from_userinfo(dict(userinfo), emails)
|
||||
user = await user_service.upsert_from_oauth(session, profile)
|
||||
|
||||
jwt_service = JWTService()
|
||||
pair = jwt_service.issue_pair(user_id=str(user.id), email=user.email, name=user.name)
|
||||
|
||||
settings = get_settings()
|
||||
redirect = RedirectResponse(
|
||||
url=f"{settings.frontend_url.rstrip('/')}/chat?access_token={pair.access_token}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
_set_refresh_cookie(redirect, pair.refresh_token, pair.refresh_expires_in)
|
||||
return redirect
|
||||
80
services/auth/src/routes/session.py
Normal file
80
services/auth/src/routes/session.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
"""Session/token refresh routes + internal JWT validation."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..config import get_settings
|
||||
from ..database import get_session
|
||||
from ..rate_limit import limiter
|
||||
from ..services import user_service
|
||||
from ..services.jwt_service import JWTError, JWTService
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["session"])
|
||||
|
||||
|
||||
class RefreshResponse(BaseModel):
|
||||
access_token: str
|
||||
expires_in: int
|
||||
|
||||
|
||||
class ValidateRequest(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=RefreshResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def refresh(
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
settings = get_settings()
|
||||
cookie = request.cookies.get(settings.refresh_cookie_name)
|
||||
if not cookie:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing refresh cookie")
|
||||
|
||||
jwt_service = JWTService()
|
||||
try:
|
||||
payload = jwt_service.decode(cookie, expected_type="refresh")
|
||||
except JWTError as exc:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, str(exc)) from exc
|
||||
|
||||
user = await user_service.get_by_id(session, payload["sub"])
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "user not found")
|
||||
|
||||
access, ttl = jwt_service.issue_access_token(
|
||||
user_id=str(user.id), email=user.email, name=user.name
|
||||
)
|
||||
return RefreshResponse(access_token=access, expires_in=ttl)
|
||||
|
||||
|
||||
@router.post("/validate")
|
||||
async def validate(
|
||||
payload: ValidateRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
x_internal_api_key: str | None = Header(default=None, alias="X-Internal-API-Key"),
|
||||
):
|
||||
settings = get_settings()
|
||||
if not settings.internal_api_key or x_internal_api_key != settings.internal_api_key:
|
||||
raise HTTPException(status.HTTP_403_FORBIDDEN, "invalid internal api key")
|
||||
|
||||
jwt_service = JWTService()
|
||||
try:
|
||||
claims = jwt_service.decode(payload.token, expected_type="access")
|
||||
except JWTError as exc:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"valid": False, "reason": str(exc)},
|
||||
)
|
||||
|
||||
user = await user_service.get_by_id(session, claims["sub"])
|
||||
if user is None:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"valid": False, "reason": "user not found"},
|
||||
)
|
||||
|
||||
return {"valid": True, "user": user.to_dict(), "claims": claims}
|
||||
46
services/auth/src/routes/users.py
Normal file
46
services/auth/src/routes/users.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
"""User profile routes (GET /auth/me, PUT /auth/me)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..database import get_session
|
||||
from ..middleware.auth_middleware import AuthContext, require_user
|
||||
from ..services import user_service
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["users"])
|
||||
|
||||
|
||||
class UserProfile(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
name: str | None
|
||||
avatar_url: str | None
|
||||
provider: str
|
||||
provider_id: str
|
||||
created_at: str | None
|
||||
updated_at: str | None
|
||||
last_login_at: str | None
|
||||
|
||||
|
||||
class ProfileUpdate(BaseModel):
|
||||
name: str | None = Field(default=None, max_length=255)
|
||||
avatar_url: str | None = Field(default=None, max_length=2048)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserProfile)
|
||||
async def me(ctx: AuthContext = Depends(require_user)) -> UserProfile:
|
||||
return UserProfile(**ctx.user.to_dict())
|
||||
|
||||
|
||||
@router.put("/me", response_model=UserProfile)
|
||||
async def update_me(
|
||||
payload: ProfileUpdate,
|
||||
ctx: AuthContext = Depends(require_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
user = await user_service.update_profile(
|
||||
session, ctx.user, name=payload.name, avatar_url=payload.avatar_url
|
||||
)
|
||||
return UserProfile(**user.to_dict())
|
||||
0
services/auth/src/services/__init__.py
Normal file
0
services/auth/src/services/__init__.py
Normal file
56
services/auth/src/services/github_oauth.py
Normal file
56
services/auth/src/services/github_oauth.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""GitHub OAuth provider via authlib.
|
||||
|
||||
Authorization URL: https://github.com/login/oauth/authorize
|
||||
Token URL: https://github.com/login/oauth/access_token
|
||||
User API: https://api.github.com/user
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
|
||||
from ..config import Settings, get_settings
|
||||
from .google_oauth import OAuthProfile
|
||||
|
||||
AUTHORIZE_URL = "https://github.com/login/oauth/authorize"
|
||||
TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||
USERINFO_URL = "https://api.github.com/user"
|
||||
EMAILS_URL = "https://api.github.com/user/emails"
|
||||
|
||||
|
||||
def build_github_client(settings: Settings | None = None) -> OAuth:
|
||||
settings = settings or get_settings()
|
||||
oauth = OAuth()
|
||||
oauth.register(
|
||||
name="github",
|
||||
client_id=settings.github_client_id,
|
||||
client_secret=settings.github_client_secret,
|
||||
access_token_url=TOKEN_URL,
|
||||
authorize_url=AUTHORIZE_URL,
|
||||
api_base_url="https://api.github.com/",
|
||||
client_kwargs={"scope": "read:user user:email"},
|
||||
)
|
||||
return oauth
|
||||
|
||||
|
||||
def profile_from_userinfo(userinfo: dict[str, Any], emails: list[dict[str, Any]] | None) -> OAuthProfile:
|
||||
provider_id = userinfo.get("id")
|
||||
if provider_id is None:
|
||||
raise ValueError("github userinfo missing id")
|
||||
|
||||
email = userinfo.get("email")
|
||||
if not email and emails:
|
||||
primary = next((e for e in emails if e.get("primary") and e.get("verified")), None)
|
||||
if primary:
|
||||
email = primary.get("email")
|
||||
if not email:
|
||||
raise ValueError("github userinfo missing verified email")
|
||||
|
||||
return OAuthProfile(
|
||||
provider="github",
|
||||
provider_id=str(provider_id),
|
||||
email=str(email),
|
||||
name=userinfo.get("name") or userinfo.get("login"),
|
||||
avatar_url=userinfo.get("avatar_url"),
|
||||
)
|
||||
50
services/auth/src/services/google_oauth.py
Normal file
50
services/auth/src/services/google_oauth.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""Google OAuth provider via authlib.
|
||||
|
||||
Discovery URL: https://accounts.google.com/.well-known/openid-configuration
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
|
||||
from ..config import Settings, get_settings
|
||||
|
||||
DISCOVERY_URL = "https://accounts.google.com/.well-known/openid-configuration"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthProfile:
|
||||
provider: str
|
||||
provider_id: str
|
||||
email: str
|
||||
name: str | None
|
||||
avatar_url: str | None
|
||||
|
||||
|
||||
def build_google_client(settings: Settings | None = None) -> OAuth:
|
||||
settings = settings or get_settings()
|
||||
oauth = OAuth()
|
||||
oauth.register(
|
||||
name="google",
|
||||
client_id=settings.google_client_id,
|
||||
client_secret=settings.google_client_secret,
|
||||
server_metadata_url=DISCOVERY_URL,
|
||||
client_kwargs={"scope": "openid email profile"},
|
||||
)
|
||||
return oauth
|
||||
|
||||
|
||||
def profile_from_userinfo(userinfo: dict[str, Any]) -> OAuthProfile:
|
||||
provider_id = userinfo.get("sub") or userinfo.get("id")
|
||||
email = userinfo.get("email")
|
||||
if not provider_id or not email:
|
||||
raise ValueError("google userinfo missing sub/email")
|
||||
return OAuthProfile(
|
||||
provider="google",
|
||||
provider_id=str(provider_id),
|
||||
email=str(email),
|
||||
name=userinfo.get("name"),
|
||||
avatar_url=userinfo.get("picture"),
|
||||
)
|
||||
87
services/auth/src/services/jwt_service.py
Normal file
87
services/auth/src/services/jwt_service.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""RS256 JWT issuance + validation.
|
||||
|
||||
Access tokens (1h) are returned to the client for Bearer auth.
|
||||
Refresh tokens (7d) are stored in an httpOnly cookie.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
|
||||
from ..config import Settings, get_settings
|
||||
|
||||
|
||||
class JWTError(Exception):
|
||||
"""Raised when a token fails validation."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenPair:
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
access_expires_in: int
|
||||
refresh_expires_in: int
|
||||
|
||||
|
||||
class JWTService:
|
||||
def __init__(self, settings: Settings | None = None) -> None:
|
||||
self._settings = settings or get_settings()
|
||||
|
||||
def _now(self) -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
def issue_access_token(self, *, user_id: str, email: str, name: str | None) -> tuple[str, int]:
|
||||
now = self._now()
|
||||
exp = now + timedelta(seconds=self._settings.jwt_access_ttl_seconds)
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"name": name,
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int(exp.timestamp()),
|
||||
"iss": self._settings.jwt_issuer,
|
||||
"type": "access",
|
||||
}
|
||||
token = jwt.encode(payload, self._settings.jwt_private_key, algorithm="RS256")
|
||||
return token, self._settings.jwt_access_ttl_seconds
|
||||
|
||||
def issue_refresh_token(self, *, user_id: str) -> tuple[str, int]:
|
||||
now = self._now()
|
||||
exp = now + timedelta(seconds=self._settings.jwt_refresh_ttl_seconds)
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int(exp.timestamp()),
|
||||
"iss": self._settings.jwt_issuer,
|
||||
"type": "refresh",
|
||||
"jti": uuid.uuid4().hex,
|
||||
}
|
||||
token = jwt.encode(payload, self._settings.jwt_private_key, algorithm="RS256")
|
||||
return token, self._settings.jwt_refresh_ttl_seconds
|
||||
|
||||
def issue_pair(self, *, user_id: str, email: str, name: str | None) -> TokenPair:
|
||||
access, access_ttl = self.issue_access_token(user_id=user_id, email=email, name=name)
|
||||
refresh, refresh_ttl = self.issue_refresh_token(user_id=user_id)
|
||||
return TokenPair(access, refresh, access_ttl, refresh_ttl)
|
||||
|
||||
def decode(self, token: str, *, expected_type: str | None = None) -> dict[str, Any]:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self._settings.jwt_public_key,
|
||||
algorithms=["RS256"],
|
||||
issuer=self._settings.jwt_issuer,
|
||||
options={"require": ["exp", "iat", "sub", "iss"]},
|
||||
)
|
||||
except jwt.ExpiredSignatureError as exc:
|
||||
raise JWTError("token expired") from exc
|
||||
except jwt.InvalidTokenError as exc:
|
||||
raise JWTError(f"invalid token: {exc}") from exc
|
||||
|
||||
if expected_type and payload.get("type") != expected_type:
|
||||
raise JWTError(f"expected {expected_type} token, got {payload.get('type')!r}")
|
||||
return payload
|
||||
75
services/auth/src/services/user_service.py
Normal file
75
services/auth/src/services/user_service.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""User persistence helpers (upsert on OAuth callback, fetch by id)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models.user import User
|
||||
from .google_oauth import OAuthProfile
|
||||
|
||||
|
||||
async def upsert_from_oauth(session: AsyncSession, profile: OAuthProfile) -> User:
|
||||
"""Insert a new user, or update last_login_at on an existing one."""
|
||||
stmt = select(User).where(
|
||||
User.provider == profile.provider,
|
||||
User.provider_id == profile.provider_id,
|
||||
)
|
||||
existing = (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if existing is None:
|
||||
user = User(
|
||||
email=profile.email,
|
||||
name=profile.name,
|
||||
avatar_url=profile.avatar_url,
|
||||
provider=profile.provider,
|
||||
provider_id=profile.provider_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_login_at=now,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
existing.last_login_at = now
|
||||
existing.updated_at = now
|
||||
if profile.email:
|
||||
existing.email = profile.email
|
||||
if profile.name is not None:
|
||||
existing.name = profile.name
|
||||
if profile.avatar_url is not None:
|
||||
existing.avatar_url = profile.avatar_url
|
||||
await session.commit()
|
||||
await session.refresh(existing)
|
||||
return existing
|
||||
|
||||
|
||||
async def get_by_id(session: AsyncSession, user_id: str | uuid.UUID) -> User | None:
|
||||
if isinstance(user_id, str):
|
||||
try:
|
||||
user_id = uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
return None
|
||||
return await session.get(User, user_id)
|
||||
|
||||
|
||||
async def update_profile(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
*,
|
||||
name: str | None,
|
||||
avatar_url: str | None,
|
||||
) -> User:
|
||||
if name is not None:
|
||||
user.name = name
|
||||
if avatar_url is not None:
|
||||
user.avatar_url = avatar_url
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
0
services/auth/src/tests/__init__.py
Normal file
0
services/auth/src/tests/__init__.py
Normal file
84
services/auth/src/tests/conftest.py
Normal file
84
services/auth/src/tests/conftest.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
"""Shared pytest fixtures: RSA keys, in-memory DB, FastAPI test client."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
|
||||
def _make_rsa_pair() -> tuple[str, str]:
|
||||
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
private = key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
).decode()
|
||||
public = key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
).decode()
|
||||
return private, public
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _test_environment():
|
||||
private, public = _make_rsa_pair()
|
||||
os.environ["JWT_PRIVATE_KEY"] = private
|
||||
os.environ["JWT_PUBLIC_KEY"] = public
|
||||
os.environ["INTERNAL_API_KEY"] = "test-internal-key"
|
||||
os.environ["FRONTEND_URL"] = "http://localhost:3000"
|
||||
os.environ["AUTH_BASE_URL"] = "http://localhost:8001"
|
||||
os.environ["SESSION_SECRET"] = "test-session-secret"
|
||||
os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
|
||||
os.environ["GOOGLE_CLIENT_ID"] = "g-id"
|
||||
os.environ["GOOGLE_CLIENT_SECRET"] = "g-secret"
|
||||
os.environ["GITHUB_CLIENT_ID"] = "gh-id"
|
||||
os.environ["GITHUB_CLIENT_SECRET"] = "gh-secret"
|
||||
|
||||
from src.config import get_settings
|
||||
|
||||
get_settings.cache_clear()
|
||||
yield
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def session_factory() -> AsyncIterator[async_sessionmaker[AsyncSession]]:
|
||||
from src.database import Base, override_session_factory
|
||||
from src.models.user import User # noqa: F401 — register on Base
|
||||
|
||||
engine = create_async_engine(
|
||||
f"sqlite+aiosqlite:///file:test_{uuid.uuid4().hex}?mode=memory&cache=shared&uri=true",
|
||||
connect_args={"uri": True},
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
override_session_factory(factory)
|
||||
try:
|
||||
yield factory
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(session_factory) -> AsyncIterator[AsyncSession]:
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(session_factory) -> AsyncIterator[AsyncClient]:
|
||||
from src.main import create_app
|
||||
|
||||
app = create_app()
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
|
||||
yield ac
|
||||
59
services/auth/src/tests/test_me_endpoint.py
Normal file
59
services/auth/src/tests/test_me_endpoint.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
"""GET /auth/me and PUT /auth/me via Bearer JWT."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services import user_service
|
||||
from src.services.google_oauth import OAuthProfile
|
||||
from src.services.jwt_service import JWTService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_me_returns_profile(client, db_session):
|
||||
profile = OAuthProfile(
|
||||
provider="github", provider_id="gh-1", email="me@x.co", name="Me", avatar_url=None
|
||||
)
|
||||
user = await user_service.upsert_from_oauth(db_session, profile)
|
||||
token, _ = JWTService().issue_access_token(
|
||||
user_id=str(user.id), email=user.email, name=user.name
|
||||
)
|
||||
|
||||
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["email"] == "me@x.co"
|
||||
assert body["provider"] == "github"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_me_requires_bearer(client):
|
||||
resp = await client.get("/auth/me")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me(client, db_session):
|
||||
profile = OAuthProfile(
|
||||
provider="google", provider_id="g-9", email="x@y.co", name="Old", avatar_url=None
|
||||
)
|
||||
user = await user_service.upsert_from_oauth(db_session, profile)
|
||||
token, _ = JWTService().issue_access_token(
|
||||
user_id=str(user.id), email=user.email, name=user.name
|
||||
)
|
||||
|
||||
resp = await client.put(
|
||||
"/auth/me",
|
||||
json={"name": "New Name", "avatar_url": "https://img/x.png"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["name"] == "New Name"
|
||||
assert body["avatar_url"] == "https://img/x.png"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health(client):
|
||||
resp = await client.get("/auth/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"status": "ok"}
|
||||
97
services/auth/src/tests/test_oauth_flow.py
Normal file
97
services/auth/src/tests/test_oauth_flow.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""OAuth callback end-to-end with mocked authlib providers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services import user_service
|
||||
from src.services.google_oauth import OAuthProfile
|
||||
|
||||
|
||||
class _MockGoogleClient:
|
||||
"""Stands in for authlib's StarletteOAuth2App."""
|
||||
|
||||
def __init__(self, userinfo):
|
||||
self._userinfo = userinfo
|
||||
|
||||
async def authorize_access_token(self, request):
|
||||
return {"access_token": "fake", "userinfo": self._userinfo}
|
||||
|
||||
async def userinfo(self, token=None):
|
||||
return self._userinfo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_callback_creates_user_and_sets_refresh_cookie(client, db_session):
|
||||
userinfo = {
|
||||
"sub": "google-42",
|
||||
"email": "new@user.co",
|
||||
"name": "New User",
|
||||
"picture": "https://img/new.png",
|
||||
}
|
||||
app = client._transport.app # type: ignore[attr-defined]
|
||||
app.state.google_oauth = SimpleNamespace(google=_MockGoogleClient(userinfo))
|
||||
|
||||
resp = await client.get("/auth/google/callback", follow_redirects=False)
|
||||
assert resp.status_code == 302
|
||||
assert "access_token=" in resp.headers["location"]
|
||||
|
||||
from src.config import get_settings
|
||||
|
||||
cookie_name = get_settings().refresh_cookie_name
|
||||
assert cookie_name in resp.cookies
|
||||
|
||||
# Verify user was persisted.
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.models.user import User
|
||||
|
||||
async with db_session.bind._async_engine.connect() if False else db_session as s: # type: ignore[attr-defined]
|
||||
pass
|
||||
# Simpler: just query through a fresh session.
|
||||
from src.database import get_session_factory
|
||||
|
||||
async with get_session_factory()() as s:
|
||||
user = (
|
||||
await s.execute(select(User).where(User.provider_id == "google-42"))
|
||||
).scalar_one()
|
||||
assert user.email == "new@user.co"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_callback_updates_existing(client, db_session):
|
||||
# Seed an existing user.
|
||||
profile = OAuthProfile(
|
||||
provider="google",
|
||||
provider_id="google-99",
|
||||
email="old@user.co",
|
||||
name="Old",
|
||||
avatar_url=None,
|
||||
)
|
||||
existing = await user_service.upsert_from_oauth(db_session, profile)
|
||||
original_login = existing.last_login_at
|
||||
|
||||
userinfo = {
|
||||
"sub": "google-99",
|
||||
"email": "old@user.co",
|
||||
"name": "Updated Name",
|
||||
"picture": "https://img/u.png",
|
||||
}
|
||||
app = client._transport.app # type: ignore[attr-defined]
|
||||
app.state.google_oauth = SimpleNamespace(google=_MockGoogleClient(userinfo))
|
||||
|
||||
resp = await client.get("/auth/google/callback", follow_redirects=False)
|
||||
assert resp.status_code == 302
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.database import get_session_factory
|
||||
from src.models.user import User
|
||||
|
||||
async with get_session_factory()() as s:
|
||||
refreshed = (
|
||||
await s.execute(select(User).where(User.id == existing.id))
|
||||
).scalar_one()
|
||||
assert refreshed.name == "Updated Name"
|
||||
assert refreshed.last_login_at >= original_login
|
||||
29
services/auth/src/tests/test_rate_limit.py
Normal file
29
services/auth/src/tests/test_rate_limit.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""Rate limiter applies 10/min to OAuth start routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_start_rate_limited(client, monkeypatch):
|
||||
# Replace the OAuth client with a stub so /auth/google returns immediately.
|
||||
async def _stub_redirect(request, redirect_uri):
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
return RedirectResponse(url=redirect_uri, status_code=302)
|
||||
|
||||
class _StubProvider:
|
||||
authorize_redirect = staticmethod(_stub_redirect)
|
||||
|
||||
class _StubClient:
|
||||
google = _StubProvider()
|
||||
|
||||
app = client._transport.app # type: ignore[attr-defined]
|
||||
app.state.google_oauth = _StubClient()
|
||||
|
||||
# First 10 calls allowed, 11th should be rate-limited.
|
||||
codes = []
|
||||
for _ in range(11):
|
||||
resp = await client.get("/auth/google", follow_redirects=False)
|
||||
codes.append(resp.status_code)
|
||||
assert codes.count(429) >= 1
|
||||
68
services/auth/src/tests/test_user_upsert.py
Normal file
68
services/auth/src/tests/test_user_upsert.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
"""User upsert flow invoked by OAuth callbacks."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services import user_service
|
||||
from src.services.google_oauth import OAuthProfile
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_creates_new_user(db_session):
|
||||
profile = OAuthProfile(
|
||||
provider="google",
|
||||
provider_id="g-12345",
|
||||
email="alice@example.com",
|
||||
name="Alice",
|
||||
avatar_url="https://img/alice.png",
|
||||
)
|
||||
user = await user_service.upsert_from_oauth(db_session, profile)
|
||||
assert user.id is not None
|
||||
assert user.email == "alice@example.com"
|
||||
assert user.provider == "google"
|
||||
assert user.last_login_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_updates_existing_user(db_session):
|
||||
first = OAuthProfile(
|
||||
provider="google",
|
||||
provider_id="g-12345",
|
||||
email="alice@example.com",
|
||||
name="Alice",
|
||||
avatar_url=None,
|
||||
)
|
||||
u1 = await user_service.upsert_from_oauth(db_session, first)
|
||||
original_login = u1.last_login_at
|
||||
|
||||
# Second login with updated display name + avatar — same provider identity.
|
||||
second = OAuthProfile(
|
||||
provider="google",
|
||||
provider_id="g-12345",
|
||||
email="alice@example.com",
|
||||
name="Alice Smith",
|
||||
avatar_url="https://img/alice2.png",
|
||||
)
|
||||
u2 = await user_service.upsert_from_oauth(db_session, second)
|
||||
|
||||
assert u2.id == u1.id # same row
|
||||
assert u2.name == "Alice Smith"
|
||||
assert u2.avatar_url == "https://img/alice2.png"
|
||||
assert u2.last_login_at >= original_login
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile(db_session):
|
||||
profile = OAuthProfile(
|
||||
provider="github",
|
||||
provider_id="42",
|
||||
email="bob@example.com",
|
||||
name="Bob",
|
||||
avatar_url=None,
|
||||
)
|
||||
user = await user_service.upsert_from_oauth(db_session, profile)
|
||||
updated = await user_service.update_profile(
|
||||
db_session, user, name="Robert", avatar_url="https://img/bob.png"
|
||||
)
|
||||
assert updated.name == "Robert"
|
||||
assert updated.avatar_url == "https://img/bob.png"
|
||||
62
services/auth/src/tests/test_validate_endpoint.py
Normal file
62
services/auth/src/tests/test_validate_endpoint.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""Internal /auth/validate endpoint used by the Chat API service."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services import user_service
|
||||
from src.services.google_oauth import OAuthProfile
|
||||
from src.services.jwt_service import JWTService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_requires_internal_key(client, db_session):
|
||||
profile = OAuthProfile(
|
||||
provider="google", provider_id="123", email="v@x.co", name="V", avatar_url=None
|
||||
)
|
||||
user = await user_service.upsert_from_oauth(db_session, profile)
|
||||
token, _ = JWTService().issue_access_token(
|
||||
user_id=str(user.id), email=user.email, name=user.name
|
||||
)
|
||||
|
||||
missing = await client.post("/auth/validate", json={"token": token})
|
||||
assert missing.status_code == 403
|
||||
|
||||
wrong = await client.post(
|
||||
"/auth/validate",
|
||||
json={"token": token},
|
||||
headers={"X-Internal-API-Key": "nope"},
|
||||
)
|
||||
assert wrong.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_returns_user_for_valid_token(client, db_session):
|
||||
profile = OAuthProfile(
|
||||
provider="google", provider_id="456", email="v2@x.co", name="V2", avatar_url=None
|
||||
)
|
||||
user = await user_service.upsert_from_oauth(db_session, profile)
|
||||
token, _ = JWTService().issue_access_token(
|
||||
user_id=str(user.id), email=user.email, name=user.name
|
||||
)
|
||||
|
||||
resp = await client.post(
|
||||
"/auth/validate",
|
||||
json={"token": token},
|
||||
headers={"X-Internal-API-Key": "test-internal-key"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["valid"] is True
|
||||
assert body["user"]["email"] == "v2@x.co"
|
||||
assert body["claims"]["sub"] == str(user.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_rejects_tampered_token(client):
|
||||
resp = await client.post(
|
||||
"/auth/validate",
|
||||
json={"token": "not-a-jwt"},
|
||||
headers={"X-Internal-API-Key": "test-internal-key"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert resp.json()["valid"] is False
|
||||
Loading…
Reference in New Issue
Block a user