diff --git a/.gitignore b/.gitignore index 5bb3231d..ae753b8e 100644 --- a/.gitignore +++ b/.gitignore @@ -85,6 +85,18 @@ docs/plans/ # Build artifacts *.egg-info/ +# PostgreSQL data directory +db/ + +# macOS +.DS_Store +**/.DS_Store +.AppleDouble +.LSOverride +._* +.Spotlight-V100 +.Trashes + # Local developer overrides (not committed) docker-compose.override.yaml docker-compose.override.yml diff --git a/docs/content/docs/guides/openwebui-keycloak.mdx b/docs/content/docs/guides/openwebui-keycloak.mdx new file mode 100644 index 00000000..1d31968e --- /dev/null +++ b/docs/content/docs/guides/openwebui-keycloak.mdx @@ -0,0 +1,164 @@ +--- +title: Open WebUI + Keycloak + OpenRAG Integration +description: How to set up SSO authentication with Keycloak for Open WebUI and OpenRAG +--- + +This guide explains how to configure **Keycloak** as a shared identity provider for both **Open WebUI** (chat frontend) and **OpenRAG** (RAG backend), so that user identity and group-based partition access flow seamlessly from login to document retrieval. + +## Architecture Overview + +```mermaid +sequenceDiagram + participant U as User + participant KC as Keycloak + participant OWUI as Open WebUI + participant OR as OpenRAG + + U->>KC: Login (OIDC) + KC-->>U: JWT (sub, email, groups) + U->>OWUI: Access chat + OWUI->>OR: Forward JWT (Authorization: Bearer) + OR->>OR: Validate JWT, sync group→partition memberships + OR-->>OWUI: RAG response (filtered by user partitions) +``` + +## 1. Keycloak Configuration + +### Create a Client for OpenRAG + +1. In your Keycloak realm, go to **Clients** > **Create client** +2. Set **Client ID** to `openrag` (this will be `OIDC_AUDIENCE`) +3. Set **Client authentication** to **On** +4. Set **Valid redirect URIs** to your Open WebUI URL (e.g., `https://chat.example.com/*`) +5. Note the **Client secret** from the Credentials tab + +### Configure Group Claim Mapper + +1. Go to **Client scopes** > **openrag-dedicated** > **Mappers** > **Create mapper** +2. Choose **Group Membership** mapper type +3. Set: + - **Name**: `groups` + - **Token Claim Name**: `groups` + - **Full group path**: ON + - **Add to ID token**: ON + - **Add to access token**: ON + +### Create Groups + +Create groups following this naming convention: + +| Group path | OpenRAG role | Description | +|-----------|-------------|-------------| +| `/rag-query/` | viewer | Can search and view documents | +| `/rag-edit/` | editor | Can upload and manage files | +| `/rag-admin/` | owner | Full partition control | + +Example groups: +- `/rag-query/finance` — Read access to the "finance" partition +- `/rag-edit/finance` — Upload files to "finance" +- `/rag-admin/hr` — Full control of the "hr" partition + +Assign users to the appropriate groups. + +## 2. Open WebUI Configuration + +Set these environment variables in your Open WebUI deployment: + +```bash +# Enable OAuth/OIDC +ENABLE_OAUTH_SIGNUP=true +OAUTH_CLIENT_ID=openrag +OAUTH_CLIENT_SECRET= +OPENID_PROVIDER_URL=https://keycloak.example.com/realms/myrealm/.well-known/openid-configuration + +# Forward the user's JWT to backend APIs +ENABLE_FORWARD_OAUTH_TOKEN=true + +# OpenRAG as OpenAI-compatible backend +OPENAI_API_BASE_URL=https://openrag.example.com/v1 +OPENAI_API_KEY=unused # JWT is forwarded instead +``` + +With `ENABLE_FORWARD_OAUTH_TOKEN=true`, Open WebUI sends the user's Keycloak JWT as the `Authorization: Bearer` header to OpenRAG, instead of a static API key. + +## 3. 
OpenRAG Configuration + +Set these environment variables: + +```bash +# Switch to OIDC mode +AUTH_MODE=oidc + +# Keycloak OIDC settings +OIDC_ISSUER_URL=https://keycloak.example.com/realms/myrealm +OIDC_AUDIENCE=openrag +OIDC_JWKS_CACHE_TTL=3600 +OIDC_GROUP_CLAIM=groups +OIDC_AUTO_PROVISION=true + +# Group prefix mapping (defaults shown) +OIDC_GROUP_PREFIX_VIEWER=rag-query/ +OIDC_GROUP_PREFIX_EDITOR=rag-edit/ +OIDC_GROUP_PREFIX_OWNER=rag-admin/ + +# Sync mode: "additive" (default) or "authoritative" +OIDC_GROUP_SYNC_MODE=additive +``` + +### Sync Modes + +- **Additive** (default): Adds missing partition memberships from Keycloak groups. Never removes existing memberships. Upgrades roles when the JWT grants a higher role, but never downgrades. + +- **Authoritative**: Fully syncs OIDC-sourced memberships. Memberships created via Keycloak are added/updated/removed to match the JWT groups exactly. Manually-created memberships (via the API) are never touched. + +## 4. How It Works + +1. User logs into Open WebUI via Keycloak SSO +2. User sends a chat message in Open WebUI +3. Open WebUI forwards the request to OpenRAG's `/v1/chat/completions` with the user's JWT +4. OpenRAG's AuthMiddleware: + - Validates the JWT signature against Keycloak's JWKS + - Extracts `sub`, `email`, `groups` claims + - Auto-provisions the user if first login + - Syncs Keycloak groups to partition memberships +5. OpenRAG resolves accessible partitions and executes the RAG pipeline +6. Response is filtered to only include sources from authorized partitions + +## 5. Verifying the Setup + +### Check JWT Contents + +Decode a Keycloak token to verify the groups claim: + +```bash +# Get a token +TOKEN=$(curl -s -X POST \ + "https://keycloak.example.com/realms/myrealm/protocol/openid-connect/token" \ + -d "grant_type=password&client_id=openrag&client_secret=SECRET&username=user&password=pass" \ + | jq -r '.access_token') + +# Decode payload +echo $TOKEN | cut -d'.' -f2 | base64 -d 2>/dev/null | jq '.groups' +``` + +Expected output: +```json +["/rag-query/finance", "/rag-edit/legal"] +``` + +### Test OpenRAG Directly + +```bash +curl -H "Authorization: Bearer $TOKEN" https://openrag.example.com/v1/models +``` + +Should return only the models (partitions) the user has access to. + +## Troubleshooting + +| Problem | Solution | +|---------|----------| +| 403 "Missing token" | Verify `ENABLE_FORWARD_OAUTH_TOKEN=true` in Open WebUI | +| 401 "Token has expired" | Check clock sync between servers and token lifetimes in Keycloak | +| User has no partitions | Verify the `groups` claim is present in the JWT and matches the prefix convention | +| 503 "Failed to fetch JWKS" | Check that OpenRAG can reach the Keycloak server at `OIDC_ISSUER_URL` | diff --git a/extern/indexer-ui b/extern/indexer-ui index 9e67d1d1..92e8875e 160000 --- a/extern/indexer-ui +++ b/extern/indexer-ui @@ -1 +1 @@ -Subproject commit 9e67d1d1fc3454acf343016f833dca27abe61578 +Subproject commit 92e8875ee1537f7156e46a8dcb2a6085d887ddc8 diff --git a/openrag/auth/__init__.py b/openrag/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/openrag/auth/oidc.py b/openrag/auth/oidc.py new file mode 100644 index 00000000..97932376 --- /dev/null +++ b/openrag/auth/oidc.py @@ -0,0 +1,275 @@ +"""OIDC/JWT authentication module for OpenRAG. + +Validates JWT tokens issued by an OIDC provider (e.g. Keycloak), +extracts user identity and group claims, and syncs group memberships +to OpenRAG partitions. 
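Typical flow (illustrative sketch; how the auth middleware obtains `pfm` and
`user_id` is assumed here, not defined in this module):

    identity = await validate_jwt(bearer_token)          # raises OIDCValidationError on failure
    roles = parse_partition_roles(identity.groups)       # e.g. {"finance": "viewer"}
    await sync_user_memberships(pfm, user_id, identity.groups)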
+""" + +import hashlib +import os +import time +from dataclasses import dataclass, field + +import httpx +from jose import JWTError, jwt +from jose.exceptions import ExpiredSignatureError, JWTClaimsError +from utils.logger import get_logger + +logger = get_logger() + +# --- Configuration from environment --- + +OIDC_ISSUER_URL = os.getenv("OIDC_ISSUER_URL", "") +OIDC_AUDIENCE = os.getenv("OIDC_AUDIENCE", "") +OIDC_JWKS_CACHE_TTL = int(os.getenv("OIDC_JWKS_CACHE_TTL", "3600")) +OIDC_GROUP_CLAIM = os.getenv("OIDC_GROUP_CLAIM", "groups") +OIDC_AUTO_PROVISION = os.getenv("OIDC_AUTO_PROVISION", "true").lower() == "true" + +# Group prefix → role mapping +OIDC_GROUP_PREFIX_VIEWER = os.getenv("OIDC_GROUP_PREFIX_VIEWER", "rag-query/") +OIDC_GROUP_PREFIX_EDITOR = os.getenv("OIDC_GROUP_PREFIX_EDITOR", "rag-edit/") +OIDC_GROUP_PREFIX_OWNER = os.getenv("OIDC_GROUP_PREFIX_OWNER", "rag-admin/") + +# Sync mode: "additive" or "authoritative" +OIDC_GROUP_SYNC_MODE = os.getenv("OIDC_GROUP_SYNC_MODE", "additive") + +ROLE_HIERARCHY = {"viewer": 1, "editor": 2, "owner": 3} + +# Prefixes ordered from highest to lowest role so the highest role wins +GROUP_PREFIX_ROLE_MAP = [ + (OIDC_GROUP_PREFIX_OWNER, "owner"), + (OIDC_GROUP_PREFIX_EDITOR, "editor"), + (OIDC_GROUP_PREFIX_VIEWER, "viewer"), +] + + +# --- JWKS Cache --- + +@dataclass +class _JWKSCache: + keys: dict = field(default_factory=dict) + fetched_at: float = 0.0 + + @property + def expired(self) -> bool: + return (time.time() - self.fetched_at) > OIDC_JWKS_CACHE_TTL + + def clear(self): + self.keys = {} + self.fetched_at = 0.0 + + +_jwks_cache = _JWKSCache() + + +async def _fetch_jwks() -> dict: + """Fetch JWKS from the OIDC issuer's well-known endpoint.""" + if _jwks_cache.keys and not _jwks_cache.expired: + return _jwks_cache.keys + + well_known_url = f"{OIDC_ISSUER_URL.rstrip('/')}/.well-known/openid-configuration" + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get(well_known_url) + resp.raise_for_status() + oidc_config = resp.json() + + jwks_uri = oidc_config["jwks_uri"] + resp = await client.get(jwks_uri) + resp.raise_for_status() + jwks = resp.json() + + _jwks_cache.keys = jwks + _jwks_cache.fetched_at = time.time() + return jwks + + +def clear_jwks_cache(): + """Clear the JWKS cache (useful for testing).""" + _jwks_cache.clear() + + +# --- JWT Validation --- + +@dataclass +class OIDCIdentity: + """Parsed identity from a validated OIDC JWT.""" + + sub: str + email: str | None = None + display_name: str | None = None + groups: list[str] = field(default_factory=list) + raw_claims: dict = field(default_factory=dict) + + +class OIDCValidationError(Exception): + """Raised when JWT validation fails.""" + + def __init__(self, message: str, status_code: int = 401): + self.message = message + self.status_code = status_code + super().__init__(message) + + +async def validate_jwt(token: str) -> OIDCIdentity: + """Validate a JWT token and return the parsed identity. + + Raises OIDCValidationError on any validation failure. 
+ """ + try: + jwks = await _fetch_jwks() + except Exception as e: + logger.error("Failed to fetch JWKS", error=str(e)) + raise OIDCValidationError(f"Failed to fetch JWKS: {e}", status_code=503) + + try: + # Decode without verification first to get the header + unverified_header = jwt.get_unverified_header(token) + except JWTError as e: + raise OIDCValidationError(f"Invalid JWT header: {e}") + + # Find matching key + kid = unverified_header.get("kid") + rsa_key = None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + rsa_key = key + break + + if not rsa_key: + # Key not found — maybe keys rotated. Clear cache and retry once. + clear_jwks_cache() + try: + jwks = await _fetch_jwks() + except Exception as e: + raise OIDCValidationError(f"Failed to refresh JWKS: {e}", status_code=503) + + for key in jwks.get("keys", []): + if key.get("kid") == kid: + rsa_key = key + break + + if not rsa_key: + raise OIDCValidationError("No matching key found in JWKS for token kid") + + try: + claims = jwt.decode( + token, + rsa_key, + algorithms=["RS256", "RS384", "RS512", "ES256", "ES384", "ES512"], + audience=OIDC_AUDIENCE, + issuer=OIDC_ISSUER_URL, + options={"verify_at_hash": False}, + ) + except ExpiredSignatureError: + raise OIDCValidationError("Token has expired") + except JWTClaimsError as e: + raise OIDCValidationError(f"Invalid token claims: {e}") + except JWTError as e: + raise OIDCValidationError(f"JWT validation failed: {e}") + + sub = claims.get("sub") + if not sub: + raise OIDCValidationError("Token missing 'sub' claim") + + # Extract display name from various possible claims + display_name = claims.get("preferred_username") or claims.get("name") or claims.get("email") + + # Extract groups + groups_raw = claims.get(OIDC_GROUP_CLAIM, []) + if isinstance(groups_raw, str): + groups_raw = [groups_raw] + + return OIDCIdentity( + sub=sub, + email=claims.get("email"), + display_name=display_name, + groups=groups_raw, + raw_claims=claims, + ) + + +# --- Group → Partition Mapping --- + +def parse_partition_roles(groups: list[str]) -> dict[str, str]: + """Parse Keycloak groups into {partition_name: role} mapping. + + If a user belongs to multiple groups for the same partition, + the highest role wins. + + Groups may have a leading '/' (Keycloak convention) which is stripped. + """ + partition_roles: dict[str, str] = {} + + for group in groups: + # Strip leading slash + g = group.lstrip("/") + + for prefix, role in GROUP_PREFIX_ROLE_MAP: + if g.startswith(prefix): + partition = g[len(prefix):] + if not partition: + continue + existing_role = partition_roles.get(partition) + if existing_role is None or ROLE_HIERARCHY[role] > ROLE_HIERARCHY[existing_role]: + partition_roles[partition] = role + break + + return partition_roles + + +def _sync_cache_key(user_id: int, groups: list[str]) -> str: + """Generate a cache key for group sync to avoid redundant DB operations.""" + groups_str = ",".join(sorted(groups)) + return hashlib.md5(f"{user_id}:{groups_str}".encode()).hexdigest() + + +# In-memory sync cache: {cache_key: expiry_timestamp} +_sync_cache: dict[str, float] = {} +_SYNC_CACHE_TTL = 60 # seconds + + +async def sync_user_memberships( + partition_file_manager, + user_id: int, + groups: list[str], + sync_mode: str | None = None, +) -> bool: + """Sync Keycloak groups to PartitionMembership records. 
+ + Args: + partition_file_manager: The PartitionFileManager instance + user_id: OpenRAG user ID + groups: Raw group claims from JWT + sync_mode: Override for OIDC_GROUP_SYNC_MODE + + Returns: + True if sync was performed, False if cached/skipped + """ + mode = sync_mode or OIDC_GROUP_SYNC_MODE + + # Check cache + cache_key = _sync_cache_key(user_id, groups) + now = time.time() + if cache_key in _sync_cache and _sync_cache[cache_key] > now: + return False + + # Clean expired entries periodically + if len(_sync_cache) > 1000: + expired = [k for k, v in _sync_cache.items() if v <= now] + for k in expired: + del _sync_cache[k] + + desired_roles = parse_partition_roles(groups) + + if mode == "authoritative": + partition_file_manager.sync_oidc_memberships_authoritative(user_id, desired_roles) + else: + partition_file_manager.sync_oidc_memberships_additive(user_id, desired_roles) + + _sync_cache[cache_key] = now + _SYNC_CACHE_TTL + return True + + +def clear_sync_cache(): + """Clear the sync cache (useful for testing).""" + _sync_cache.clear() diff --git a/openrag/auth/test_group_sync.py b/openrag/auth/test_group_sync.py new file mode 100644 index 00000000..0dd3b8f6 --- /dev/null +++ b/openrag/auth/test_group_sync.py @@ -0,0 +1,185 @@ +"""Tests for OIDC group sync on PartitionFileManager. + +These tests use an in-memory SQLite database to test the actual +sync_oidc_memberships_additive and sync_oidc_memberships_authoritative methods. +""" + +import pytest +from components.indexer.vectordb.utils import ( + Base, + Partition, + PartitionFileManager, + PartitionMembership, + User, +) +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + + +@pytest.fixture +def pfm(tmp_path): + """Create a PartitionFileManager with an in-memory SQLite DB.""" + db_url = f"sqlite:///{tmp_path}/test.db" + # Patch to avoid the postgres-specific database_exists check + engine = create_engine(db_url) + Base.metadata.create_all(engine) + Session = sessionmaker(bind=engine) + + pfm = object.__new__(PartitionFileManager) + pfm.engine = engine + pfm.Session = Session + pfm.logger = __import__("utils.logger", fromlist=["get_logger"]).get_logger() + pfm.file_quota_per_user = -1 + + # Create a test user + with Session() as s: + user = User(id=10, display_name="Test User", is_admin=False) + s.add(user) + s.commit() + + return pfm + + +@pytest.fixture +def pfm_with_partitions(pfm): + """PFM with pre-existing partitions.""" + with pfm.Session() as s: + s.add(Partition(partition="finance")) + s.add(Partition(partition="legal")) + s.add(Partition(partition="hr")) + s.commit() + return pfm + + +class TestAdditiveSync: + def test_creates_new_memberships(self, pfm_with_partitions): + pfm = pfm_with_partitions + pfm.sync_oidc_memberships_additive(10, {"finance": "viewer", "legal": "editor"}) + + with pfm.Session() as s: + memberships = s.query(PartitionMembership).filter_by(user_id=10).all() + by_partition = {m.partition_name: m for m in memberships} + + assert "finance" in by_partition + assert by_partition["finance"].role == "viewer" + assert by_partition["finance"].source == "oidc" + assert "legal" in by_partition + assert by_partition["legal"].role == "editor" + + def test_upgrades_role_never_downgrades(self, pfm_with_partitions): + pfm = pfm_with_partitions + + # First sync: viewer + pfm.sync_oidc_memberships_additive(10, {"finance": "viewer"}) + # Second sync: upgrade to editor + pfm.sync_oidc_memberships_additive(10, {"finance": "editor"}) + + with pfm.Session() as s: + m = 
s.query(PartitionMembership).filter_by(user_id=10, partition_name="finance").first() + assert m.role == "editor" + + # Third sync: try to downgrade to viewer — should NOT downgrade + pfm.sync_oidc_memberships_additive(10, {"finance": "viewer"}) + + with pfm.Session() as s: + m = s.query(PartitionMembership).filter_by(user_id=10, partition_name="finance").first() + assert m.role == "editor" # stays editor + + def test_does_not_remove_existing(self, pfm_with_partitions): + pfm = pfm_with_partitions + + pfm.sync_oidc_memberships_additive(10, {"finance": "viewer", "legal": "editor"}) + # Second sync only mentions finance + pfm.sync_oidc_memberships_additive(10, {"finance": "viewer"}) + + with pfm.Session() as s: + memberships = s.query(PartitionMembership).filter_by(user_id=10).all() + partitions = {m.partition_name for m in memberships} + + # legal should still be there + assert "legal" in partitions + assert "finance" in partitions + + def test_creates_partition_if_not_exists(self, pfm): + pfm.sync_oidc_memberships_additive(10, {"new-partition": "viewer"}) + + with pfm.Session() as s: + p = s.query(Partition).filter_by(partition="new-partition").first() + assert p is not None + m = s.query(PartitionMembership).filter_by(user_id=10, partition_name="new-partition").first() + assert m is not None + assert m.role == "viewer" + + +class TestAuthoritativeSync: + def test_creates_and_removes_oidc_memberships(self, pfm_with_partitions): + pfm = pfm_with_partitions + + # First sync + pfm.sync_oidc_memberships_authoritative(10, {"finance": "viewer", "legal": "editor"}) + with pfm.Session() as s: + memberships = s.query(PartitionMembership).filter_by(user_id=10).all() + assert len(memberships) == 2 + + # Second sync: remove legal, keep finance + pfm.sync_oidc_memberships_authoritative(10, {"finance": "editor"}) + with pfm.Session() as s: + memberships = s.query(PartitionMembership).filter_by(user_id=10).all() + by_partition = {m.partition_name: m for m in memberships} + + assert len(by_partition) == 1 + assert "finance" in by_partition + assert by_partition["finance"].role == "editor" + + def test_does_not_touch_manual_memberships(self, pfm_with_partitions): + pfm = pfm_with_partitions + + # Add a manual membership + with pfm.Session() as s: + s.add(PartitionMembership( + partition_name="hr", user_id=10, role="owner", source="manual" + )) + s.commit() + + # Authoritative sync with only finance + pfm.sync_oidc_memberships_authoritative(10, {"finance": "viewer"}) + + with pfm.Session() as s: + memberships = s.query(PartitionMembership).filter_by(user_id=10).all() + by_partition = {m.partition_name: m for m in memberships} + + # hr (manual) should still exist + assert "hr" in by_partition + assert by_partition["hr"].source == "manual" + assert by_partition["hr"].role == "owner" + # finance (oidc) should exist + assert "finance" in by_partition + assert by_partition["finance"].source == "oidc" + + def test_converts_manual_to_oidc_when_desired(self, pfm_with_partitions): + pfm = pfm_with_partitions + + # Add a manual membership + with pfm.Session() as s: + s.add(PartitionMembership( + partition_name="finance", user_id=10, role="viewer", source="manual" + )) + s.commit() + + # Authoritative sync claims finance as editor + pfm.sync_oidc_memberships_authoritative(10, {"finance": "editor"}) + + with pfm.Session() as s: + m = s.query(PartitionMembership).filter_by(user_id=10, partition_name="finance").first() + assert m.source == "oidc" + assert m.role == "editor" + + def 
test_empty_desired_removes_all_oidc(self, pfm_with_partitions): + pfm = pfm_with_partitions + + pfm.sync_oidc_memberships_authoritative(10, {"finance": "viewer", "legal": "editor"}) + pfm.sync_oidc_memberships_authoritative(10, {}) + + with pfm.Session() as s: + memberships = s.query(PartitionMembership).filter_by(user_id=10, source="oidc").all() + assert len(memberships) == 0 diff --git a/openrag/auth/test_oidc.py b/openrag/auth/test_oidc.py new file mode 100644 index 00000000..f67084a6 --- /dev/null +++ b/openrag/auth/test_oidc.py @@ -0,0 +1,215 @@ +"""Tests for OIDC JWT validation and group parsing.""" + +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from auth.oidc import ( + OIDCIdentity, + OIDCValidationError, + clear_jwks_cache, + clear_sync_cache, + parse_partition_roles, + sync_user_memberships, + validate_jwt, +) + + +# --- parse_partition_roles --- + + +class TestParsePartitionRoles: + def test_basic_viewer(self): + groups = ["/rag-query/finance"] + assert parse_partition_roles(groups) == {"finance": "viewer"} + + def test_basic_editor(self): + groups = ["rag-edit/legal"] + assert parse_partition_roles(groups) == {"legal": "editor"} + + def test_basic_owner(self): + groups = ["/rag-admin/hr"] + assert parse_partition_roles(groups) == {"hr": "owner"} + + def test_highest_role_wins(self): + groups = ["/rag-query/finance", "/rag-admin/finance", "/rag-edit/finance"] + assert parse_partition_roles(groups) == {"finance": "owner"} + + def test_multiple_partitions(self): + groups = ["/rag-query/finance", "/rag-edit/legal", "/rag-admin/hr"] + assert parse_partition_roles(groups) == { + "finance": "viewer", + "legal": "editor", + "hr": "owner", + } + + def test_no_matching_groups(self): + groups = ["/other-group", "random-group"] + assert parse_partition_roles(groups) == {} + + def test_empty_groups(self): + assert parse_partition_roles([]) == {} + + def test_no_leading_slash(self): + groups = ["rag-query/data"] + assert parse_partition_roles(groups) == {"data": "viewer"} + + def test_empty_partition_name_ignored(self): + groups = ["/rag-query/"] + assert parse_partition_roles(groups) == {} + + def test_mixed_valid_and_invalid(self): + groups = ["/rag-query/finance", "/other-group", "rag-edit/legal"] + assert parse_partition_roles(groups) == {"finance": "viewer", "legal": "editor"} + + +# --- validate_jwt --- + + +MOCK_JWKS = { + "keys": [ + { + "kid": "test-key-id", + "kty": "RSA", + "alg": "RS256", + "use": "sig", + "n": "test-n", + "e": "AQAB", + } + ] +} + + +class TestValidateJwt: + @pytest.fixture(autouse=True) + def setup(self): + clear_jwks_cache() + yield + clear_jwks_cache() + + @pytest.mark.asyncio + async def test_invalid_jwt_header(self): + with pytest.raises(OIDCValidationError, match="Invalid JWT header"): + await validate_jwt("not-a-jwt-token") + + @pytest.mark.asyncio + async def test_jwks_fetch_failure(self): + with patch("auth.oidc._fetch_jwks", side_effect=Exception("Network error")): + with pytest.raises(OIDCValidationError, match="Failed to fetch JWKS"): + await validate_jwt("eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIn0.eyJ0ZXN0IjoidmFsdWUifQ.signature") + + @pytest.mark.asyncio + async def test_expired_token(self): + """Test that an expired token raises the appropriate error.""" + # This would require a proper JWT — testing the error path + with patch("auth.oidc._fetch_jwks", return_value=MOCK_JWKS): + with patch("jose.jwt.decode", side_effect=__import__("jose.exceptions", 
fromlist=["ExpiredSignatureError"]).ExpiredSignatureError("Token expired")): + with patch("jose.jwt.get_unverified_header", return_value={"kid": "test-key-id", "alg": "RS256"}): + with pytest.raises(OIDCValidationError, match="Token has expired"): + await validate_jwt("eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIn0.eyJ0ZXN0IjoidmFsdWUifQ.signature") + + @pytest.mark.asyncio + async def test_valid_token_returns_identity(self): + mock_claims = { + "sub": "user-123", + "email": "user@example.com", + "preferred_username": "testuser", + "groups": ["/rag-query/finance", "/rag-edit/legal"], + } + with patch("auth.oidc._fetch_jwks", return_value=MOCK_JWKS): + with patch("jose.jwt.get_unverified_header", return_value={"kid": "test-key-id", "alg": "RS256"}): + with patch("jose.jwt.decode", return_value=mock_claims): + identity = await validate_jwt("fake.jwt.token") + + assert isinstance(identity, OIDCIdentity) + assert identity.sub == "user-123" + assert identity.email == "user@example.com" + assert identity.display_name == "testuser" + assert "/rag-query/finance" in identity.groups + assert "/rag-edit/legal" in identity.groups + + @pytest.mark.asyncio + async def test_missing_sub_claim(self): + mock_claims = {"email": "user@example.com"} + with patch("auth.oidc._fetch_jwks", return_value=MOCK_JWKS): + with patch("jose.jwt.get_unverified_header", return_value={"kid": "test-key-id", "alg": "RS256"}): + with patch("jose.jwt.decode", return_value=mock_claims): + with pytest.raises(OIDCValidationError, match="missing 'sub' claim"): + await validate_jwt("fake.jwt.token") + + @pytest.mark.asyncio + async def test_no_matching_key_retries_jwks(self): + """When kid doesn't match, should clear cache and retry.""" + empty_jwks = {"keys": []} + call_count = 0 + + async def mock_fetch(): + nonlocal call_count + call_count += 1 + return empty_jwks + + with patch("auth.oidc._fetch_jwks", side_effect=mock_fetch): + with patch("jose.jwt.get_unverified_header", return_value={"kid": "unknown-kid", "alg": "RS256"}): + with pytest.raises(OIDCValidationError, match="No matching key"): + await validate_jwt("fake.jwt.token") + + # Should have tried fetching twice (initial + retry) + assert call_count == 2 + + +# --- sync_user_memberships --- + + +class TestSyncUserMemberships: + @pytest.fixture(autouse=True) + def setup(self): + clear_sync_cache() + yield + clear_sync_cache() + + @pytest.mark.asyncio + async def test_additive_sync_calls_correct_method(self): + pfm = MagicMock() + pfm.sync_oidc_memberships_additive = MagicMock() + + groups = ["/rag-query/finance"] + result = await sync_user_memberships(pfm, user_id=1, groups=groups, sync_mode="additive") + + assert result is True + pfm.sync_oidc_memberships_additive.assert_called_once_with(1, {"finance": "viewer"}) + + @pytest.mark.asyncio + async def test_authoritative_sync_calls_correct_method(self): + pfm = MagicMock() + pfm.sync_oidc_memberships_authoritative = MagicMock() + + groups = ["/rag-admin/hr"] + result = await sync_user_memberships(pfm, user_id=2, groups=groups, sync_mode="authoritative") + + assert result is True + pfm.sync_oidc_memberships_authoritative.assert_called_once_with(2, {"hr": "owner"}) + + @pytest.mark.asyncio + async def test_cache_prevents_duplicate_sync(self): + pfm = MagicMock() + pfm.sync_oidc_memberships_additive = MagicMock() + + groups = ["/rag-query/finance"] + result1 = await sync_user_memberships(pfm, user_id=1, groups=groups, sync_mode="additive") + result2 = await sync_user_memberships(pfm, user_id=1, groups=groups, 
sync_mode="additive") + + assert result1 is True + assert result2 is False # cached + assert pfm.sync_oidc_memberships_additive.call_count == 1 + + @pytest.mark.asyncio + async def test_different_groups_not_cached(self): + pfm = MagicMock() + pfm.sync_oidc_memberships_additive = MagicMock() + + result1 = await sync_user_memberships(pfm, user_id=1, groups=["/rag-query/a"], sync_mode="additive") + result2 = await sync_user_memberships(pfm, user_id=1, groups=["/rag-query/b"], sync_mode="additive") + + assert result1 is True + assert result2 is True + assert pfm.sync_oidc_memberships_additive.call_count == 2 diff --git a/openrag/components/connectors/__init__.py b/openrag/components/connectors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/openrag/components/connectors/drive.py b/openrag/components/connectors/drive.py new file mode 100644 index 00000000..4c5aeffd --- /dev/null +++ b/openrag/components/connectors/drive.py @@ -0,0 +1,385 @@ +"""Drive connector for Suite Numérique Drive integration. + +Provides: +- DriveClient: HTTP client for the Drive REST API +- DriveConnector: Sync logic (new/modified/deleted files) +- DriveSyncScheduler: Ray actor for periodic sync +""" + +import asyncio +import os +import tempfile +from dataclasses import dataclass +from datetime import datetime + +import httpx +import ray +from config import load_config +from utils.logger import get_logger + +logger = get_logger() +config = load_config() + +DRIVE_DEFAULT_BASE_URL = os.getenv("DRIVE_DEFAULT_BASE_URL", "") +DRIVE_SERVICE_ACCOUNT_CLIENT_ID = os.getenv("DRIVE_SERVICE_ACCOUNT_CLIENT_ID", "") +DRIVE_SERVICE_ACCOUNT_CLIENT_SECRET = os.getenv("DRIVE_SERVICE_ACCOUNT_CLIENT_SECRET", "") +OIDC_ISSUER_URL = os.getenv("OIDC_ISSUER_URL", "") + + +@dataclass +class DriveItem: + id: str + type: str # "FILE" or "FOLDER" + title: str + updated_at: str | None = None + + +class DriveClient: + """HTTP client for the Suite Numérique Drive API.""" + + def __init__(self, base_url: str, access_token: str): + self.base_url = base_url.rstrip("/") + self.access_token = access_token + self._client = httpx.AsyncClient( + base_url=f"{self.base_url}/api/v1.0", + headers={"Authorization": f"Bearer {access_token}"}, + timeout=30, + ) + + async def list_folder(self, folder_id: str, recursive: bool = True) -> list[DriveItem]: + """List all files in a folder, optionally recursively.""" + items = [] + page = 1 + while True: + resp = await self._client.get(f"/items/{folder_id}/children/", params={"page": page}) + resp.raise_for_status() + data = resp.json() + + results = data.get("results", data) if isinstance(data, dict) else data + if not results: + break + + for item_data in results: + item = DriveItem( + id=item_data["id"], + type=item_data.get("type", "FILE"), + title=item_data.get("title", ""), + updated_at=item_data.get("updated_at"), + ) + if item.type == "FILE": + items.append(item) + elif item.type == "FOLDER" and recursive: + sub_items = await self.list_folder(item.id, recursive=True) + items.extend(sub_items) + + # Check for pagination + next_url = data.get("next") if isinstance(data, dict) else None + if not next_url: + break + page += 1 + + return items + + async def download_file(self, item_id: str) -> tuple[bytes, str]: + """Download a file and return (content_bytes, filename).""" + resp = await self._client.get(f"/items/{item_id}/download/") + resp.raise_for_status() + + # Extract filename from Content-Disposition header + cd = resp.headers.get("content-disposition", "") + filename = "" + if "filename=" 
in cd: + filename = cd.split("filename=")[1].strip('"').strip("'") + if not filename: + filename = f"drive_{item_id}" + + return resp.content, filename + + async def get_item(self, item_id: str) -> DriveItem: + """Get metadata for a single item.""" + resp = await self._client.get(f"/items/{item_id}/") + resp.raise_for_status() + data = resp.json() + return DriveItem( + id=data["id"], + type=data.get("type", "FILE"), + title=data.get("title", ""), + updated_at=data.get("updated_at"), + ) + + async def close(self): + await self._client.aclose() + + +class DriveConnector: + """Synchronizes a Drive folder with an OpenRAG partition.""" + + async def get_access_token(self, source) -> str: + """Obtain an OIDC access token for the Drive API. + + Uses client_credentials grant for service account auth. + """ + client_id = source.service_account_client_id or DRIVE_SERVICE_ACCOUNT_CLIENT_ID + client_secret = source.service_account_client_secret or DRIVE_SERVICE_ACCOUNT_CLIENT_SECRET + issuer = OIDC_ISSUER_URL + + if not all([client_id, client_secret, issuer]): + raise ValueError("Missing OIDC credentials for Drive service account") + + token_url = f"{issuer.rstrip('/')}/protocol/openid-connect/token" + + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.post( + token_url, + data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + }, + ) + resp.raise_for_status() + return resp.json()["access_token"] + + async def sync_source(self, source, session_factory) -> dict: + """Synchronize a Drive source with its OpenRAG partition. + + Returns: {"added": int, "updated": int, "deleted": int, "errors": int} + """ + from components.indexer.vectordb.utils import DriveFileMapping + + log = logger.bind(source_id=source.id, partition=source.partition_name) + result = {"added": 0, "updated": 0, "deleted": 0, "errors": 0} + + try: + token = await self.get_access_token(source) + drive_client = DriveClient(source.drive_base_url, token) + + # List current files in Drive + drive_items = await drive_client.list_folder(source.drive_folder_id) + drive_items_by_id = {item.id: item for item in drive_items} + + # Load existing mappings + with session_factory() as s: + existing_mappings = ( + s.query(DriveFileMapping) + .filter_by(drive_source_id=source.id) + .all() + ) + existing_by_drive_id = {m.drive_item_id: m for m in existing_mappings} + + # Determine actions + drive_ids = set(drive_items_by_id.keys()) + mapped_ids = set(existing_by_drive_id.keys()) + + new_ids = drive_ids - mapped_ids + deleted_ids = mapped_ids - drive_ids + common_ids = drive_ids & mapped_ids + + # Check for updates (modified files) + updated_ids = set() + for item_id in common_ids: + item = drive_items_by_id[item_id] + mapping = existing_by_drive_id[item_id] + if item.updated_at and mapping.drive_item_updated_at: + item_dt = datetime.fromisoformat(item.updated_at.replace("Z", "+00:00")) + if item_dt > mapping.drive_item_updated_at: + updated_ids.add(item_id) + + indexer = ray.get_actor("Indexer", namespace="openrag") + vectordb = ray.get_actor("Vectordb", namespace="openrag") + + # Process new files + for item_id in new_ids: + item = drive_items_by_id[item_id] + try: + content, filename = await drive_client.download_file(item_id) + file_id = f"drive_{source.id}_{item_id}" + + # Save to temp file for indexer + with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{filename}") as tmp: + tmp.write(content) + tmp_path = tmp.name + + metadata = { + "file_id": file_id, + "source": 
filename, + "drive_source_id": source.id, + "drive_item_id": item_id, + "drive_url": f"{source.drive_base_url}/items/{item_id}", + } + + await indexer.add_file.remote( + path=tmp_path, + metadata=metadata, + partition=source.partition_name, + ) + + # Record mapping + with session_factory() as s: + s.add(DriveFileMapping( + drive_source_id=source.id, + drive_item_id=item_id, + drive_item_title=item.title, + drive_item_updated_at=datetime.fromisoformat(item.updated_at.replace("Z", "+00:00")) if item.updated_at else None, + file_id=file_id, + partition_name=source.partition_name, + )) + s.commit() + + result["added"] += 1 + log.info("Added file from Drive", drive_item=item.title) + + except Exception as e: + log.warning("Failed to add Drive file", drive_item_id=item_id, error=str(e)) + result["errors"] += 1 + + # Process deleted files + for item_id in deleted_ids: + mapping = existing_by_drive_id[item_id] + try: + await vectordb.delete_file.remote(mapping.file_id, source.partition_name) + with session_factory() as s: + m = s.query(DriveFileMapping).filter_by(id=mapping.id).first() + if m: + s.delete(m) + s.commit() + result["deleted"] += 1 + log.info("Deleted file removed from Drive", file_id=mapping.file_id) + except Exception as e: + log.warning("Failed to delete file", file_id=mapping.file_id, error=str(e)) + result["errors"] += 1 + + # Process updated files (delete + re-add) + for item_id in updated_ids: + mapping = existing_by_drive_id[item_id] + item = drive_items_by_id[item_id] + try: + # Delete old + await vectordb.delete_file.remote(mapping.file_id, source.partition_name) + + # Re-download and re-index + content, filename = await drive_client.download_file(item_id) + with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{filename}") as tmp: + tmp.write(content) + tmp_path = tmp.name + + metadata = { + "file_id": mapping.file_id, + "source": filename, + "drive_source_id": source.id, + "drive_item_id": item_id, + "drive_url": f"{source.drive_base_url}/items/{item_id}", + } + + await indexer.add_file.remote( + path=tmp_path, + metadata=metadata, + partition=source.partition_name, + ) + + # Update mapping + with session_factory() as s: + m = s.query(DriveFileMapping).filter_by(id=mapping.id).first() + if m: + m.drive_item_updated_at = datetime.fromisoformat(item.updated_at.replace("Z", "+00:00")) if item.updated_at else None + m.last_synced_at = datetime.now() + s.commit() + + result["updated"] += 1 + log.info("Updated file from Drive", drive_item=item.title) + + except Exception as e: + log.warning("Failed to update Drive file", drive_item_id=item_id, error=str(e)) + result["errors"] += 1 + + await drive_client.close() + + # Update source status + with session_factory() as s: + from components.indexer.vectordb.utils import DriveSource as DS + + src = s.query(DS).filter_by(id=source.id).first() + if src: + src.last_synced_at = datetime.now() + src.last_sync_status = "success" + src.last_sync_error = None + s.commit() + + except Exception as e: + log.error("Drive sync failed", error=str(e)) + with session_factory() as s: + from components.indexer.vectordb.utils import DriveSource as DS + + src = s.query(DS).filter_by(id=source.id).first() + if src: + src.last_synced_at = datetime.now() + src.last_sync_status = "failed" + src.last_sync_error = str(e) + s.commit() + + return result + + +@ray.remote +class DriveSyncScheduler: + """Ray actor that periodically syncs Drive sources.""" + + def __init__(self): + self.logger = get_logger() + self.connector = DriveConnector() + self._running = 
True + + async def run(self): + """Main loop: check for sources that need syncing.""" + self.logger.info("DriveSyncScheduler started") + while self._running: + try: + await self._check_and_sync() + except Exception as e: + self.logger.error("DriveSyncScheduler error", error=str(e)) + await asyncio.sleep(60) # Check every minute + + async def _check_and_sync(self): + from components.indexer.vectordb.utils import DriveSource + from utils.dependencies import get_vectordb + + vectordb = get_vectordb() + pfm = await vectordb.get_partition_file_manager.remote() + + with pfm.Session() as s: + sources = s.query(DriveSource).filter_by(sync_enabled=True).all() + sources_to_sync = [] + now = datetime.now() + for src in sources: + if src.last_synced_at is None: + sources_to_sync.append(src.id) + else: + from datetime import timedelta + + next_sync = src.last_synced_at + timedelta(minutes=src.sync_frequency_minutes) + if now >= next_sync: + sources_to_sync.append(src.id) + + for source_id in sources_to_sync: + await self.trigger_sync(source_id) + + async def trigger_sync(self, source_id: int): + """Trigger sync for a specific source.""" + from components.indexer.vectordb.utils import DriveSource + from utils.dependencies import get_vectordb + + vectordb = get_vectordb() + pfm = await vectordb.get_partition_file_manager.remote() + + with pfm.Session() as s: + source = s.query(DriveSource).filter_by(id=source_id).first() + if not source: + self.logger.warning("Drive source not found", source_id=source_id) + return + + self.logger.info("Starting Drive sync", source_id=source_id) + result = await self.connector.sync_source(source, pfm.Session) + self.logger.info("Drive sync completed", source_id=source_id, result=result) + + async def stop(self): + self._running = False diff --git a/openrag/components/eval.py b/openrag/components/eval.py new file mode 100644 index 00000000..efb11a41 --- /dev/null +++ b/openrag/components/eval.py @@ -0,0 +1,153 @@ +"""RAG evaluation engine. + +Runs Q&A entries through the RAG pipeline and scores responses +using semantic similarity and LLM-as-judge. +""" + +import asyncio +from datetime import datetime + +from components.indexer.vectordb.utils import QAEntry, QAEvalRun +from config import load_config +from openai import AsyncOpenAI +from utils.logger import get_logger + +logger = get_logger() +config = load_config() + +EVAL_JUDGE_PROMPT = """You are an evaluation judge. Compare the expected answer with the actual answer generated by a RAG system. + +Score the actual answer from 1 to 5: +- 5: Perfectly matches the expected answer in meaning and completeness +- 4: Mostly correct with minor differences +- 3: Partially correct, captures the main idea but misses details +- 2: Somewhat related but significantly different +- 1: Completely wrong or irrelevant + +Respond with ONLY a JSON object: +{"score": <1-5>, "reason": ""} +""" + + +async def run_evaluation(eval_run_id: int, session_factory): + """Execute an evaluation run asynchronously. 
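    Each matching QAEntry question is run through the RAG pipeline for the
    run's partition, the generated answer is scored 1-5 by an LLM judge
    (see EVAL_JUDGE_PROMPT), and per-question details plus an
    avg_score/pass_rate summary are stored on the QAEvalRun record.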
+ + Args: + eval_run_id: ID of the QAEvalRun record + session_factory: SQLAlchemy Session factory + """ + from components.pipeline import RagPipeline + + pipeline = RagPipeline() + llm_client = AsyncOpenAI(base_url=config.llm["base_url"], api_key=config.llm["api_key"]) + + with session_factory() as s: + run = s.query(QAEvalRun).filter_by(id=eval_run_id).first() + if not run: + logger.error("Eval run not found", eval_run_id=eval_run_id) + return + + run.status = "running" + s.commit() + + # Load matching Q&A entries + query = s.query(QAEntry).filter(QAEntry.partition_name == run.partition_name) + run_config = run.config_json or {} + tags = run_config.get("tags", ["eval"]) + for tag in tags: + query = query.filter(QAEntry.tags.contains([tag])) + entries = query.all() + + # Detach entries to avoid session issues + qa_data = [(e.id, e.question, e.expected_answer, e.partition_name) for e in entries] + + results = {"summary": {}, "details": []} + scores = [] + + for qa_id, question, expected_answer, partition_name in qa_data: + try: + # Run through RAG pipeline + payload = { + "messages": [{"role": "user", "content": question}], + "model": config.llm["model"], + } + prepared_payload, docs = await pipeline._prepare_for_chat_completion( + partition=[partition_name], payload=payload + ) + response = await pipeline.llm_client.chat_completion(prepared_payload) + actual_answer = response.choices[0].message.content if response.choices else "" + + # LLM-as-judge scoring + judge_score = 0 + judge_reason = "" + if expected_answer: + try: + judge_response = await llm_client.chat.completions.create( + model=config.llm["model"], + messages=[ + {"role": "system", "content": EVAL_JUDGE_PROMPT}, + { + "role": "user", + "content": f"Expected answer:\n{expected_answer}\n\nActual answer:\n{actual_answer}", + }, + ], + max_tokens=200, + ) + import json + + judge_text = judge_response.choices[0].message.content + judge_data = json.loads(judge_text) + judge_score = judge_data.get("score", 0) + judge_reason = judge_data.get("reason", "") + except Exception as e: + logger.warning("LLM judge failed", error=str(e)) + judge_score = 0 + judge_reason = f"Judge error: {e}" + + scores.append(judge_score) + results["details"].append({ + "qa_id": qa_id, + "question": question, + "expected": expected_answer, + "actual": actual_answer, + "judge_score": judge_score, + "judge_reason": judge_reason, + }) + + except Exception as e: + logger.warning("Eval failed for question", qa_id=qa_id, error=str(e)) + results["details"].append({ + "qa_id": qa_id, + "question": question, + "expected": expected_answer, + "actual": "", + "judge_score": 0, + "judge_reason": f"Error: {e}", + }) + scores.append(0) + + # Update progress + with session_factory() as s: + run = s.query(QAEvalRun).filter_by(id=eval_run_id).first() + if run: + run.completed_questions = len(results["details"]) + s.commit() + + # Finalize + valid_scores = [sc for sc in scores if sc > 0] + results["summary"] = { + "avg_score": round(sum(valid_scores) / len(valid_scores), 2) if valid_scores else 0, + "pass_rate": round(sum(1 for sc in valid_scores if sc >= 3) / len(valid_scores), 2) if valid_scores else 0, + "total": len(scores), + "evaluated": len(valid_scores), + } + + with session_factory() as s: + run = s.query(QAEvalRun).filter_by(id=eval_run_id).first() + if run: + run.status = "completed" + run.completed_at = datetime.now() + run.results = results + s.commit() + + logger.info("Eval run completed", eval_run_id=eval_run_id, avg_score=results["summary"]["avg_score"]) diff 
--git a/openrag/components/notifications/__init__.py b/openrag/components/notifications/__init__.py new file mode 100644 index 00000000..301a087b --- /dev/null +++ b/openrag/components/notifications/__init__.py @@ -0,0 +1,17 @@ +from .base import BaseDispatcher, DispatchResult +from .email import EmailDispatcher +from .tchap import TchapDispatcher +from .webhook import WebhookDispatcher + + +def get_dispatcher(channel) -> BaseDispatcher: + """Factory to get the appropriate dispatcher for a notification channel.""" + dispatchers = { + "webhook": WebhookDispatcher, + "email_smtp": EmailDispatcher, + "tchap_bot": TchapDispatcher, + } + dispatcher_cls = dispatchers.get(channel.type) + if not dispatcher_cls: + raise ValueError(f"Unknown channel type: {channel.type}") + return dispatcher_cls(channel.config_json) diff --git a/openrag/components/notifications/base.py b/openrag/components/notifications/base.py new file mode 100644 index 00000000..de6c6136 --- /dev/null +++ b/openrag/components/notifications/base.py @@ -0,0 +1,39 @@ +"""Base class for notification dispatchers.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass + + +@dataclass +class DispatchResult: + success: bool + message: str = "" + error: str | None = None + + +class BaseDispatcher(ABC): + """Abstract base for notification dispatchers.""" + + def __init__(self, config: dict): + self.config = config + + @abstractmethod + async def send(self, title: str, body: str, url: str | None = None) -> DispatchResult: + """Send a notification. + + Args: + title: Announcement title + body: Announcement body (markdown) + url: Optional link (e.g., poll voting page) + + Returns: + DispatchResult indicating success or failure + """ + ... + + async def send_test(self) -> DispatchResult: + """Send a test notification to verify configuration.""" + return await self.send( + title="OpenRAG Test Notification", + body="This is a test notification from OpenRAG. If you received this, your notification channel is configured correctly.", + ) diff --git a/openrag/components/notifications/email.py b/openrag/components/notifications/email.py new file mode 100644 index 00000000..fe6c42a2 --- /dev/null +++ b/openrag/components/notifications/email.py @@ -0,0 +1,61 @@ +"""Email (SMTP) notification dispatcher.""" + +import smtplib +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText + +from utils.logger import get_logger + +from .base import BaseDispatcher, DispatchResult + +logger = get_logger() + + +class EmailDispatcher(BaseDispatcher): + """Sends notifications via SMTP email. + + Config: + host: SMTP server hostname + port: SMTP port (default 587) + username: SMTP username + password: SMTP password + from_addr: Sender email address + to_addrs: List of recipient email addresses (for broadcast) + use_tls: Whether to use STARTTLS (default True) + """ + + async def send(self, title: str, body: str, url: str | None = None) -> DispatchResult: + host = self.config.get("host") + port = self.config.get("port", 587) + username = self.config.get("username", "") + password = self.config.get("password", "") + from_addr = self.config.get("from_addr", username) + to_addrs = self.config.get("to_addrs", []) + use_tls = self.config.get("use_tls", True) + + if not host or not to_addrs: + return DispatchResult(success=False, error="Missing SMTP host or recipients") + + html_body = f"

<h2>{title}</h2><p>{body}</p>

" + if url: + html_body += f'

<p><a href="{url}">Open</a></p>

' + + try: + msg = MIMEMultipart("alternative") + msg["Subject"] = title + msg["From"] = from_addr + msg["To"] = ", ".join(to_addrs) + msg.attach(MIMEText(body, "plain")) + msg.attach(MIMEText(html_body, "html")) + + with smtplib.SMTP(host, port) as server: + if use_tls: + server.starttls() + if username and password: + server.login(username, password) + server.sendmail(from_addr, to_addrs, msg.as_string()) + + return DispatchResult(success=True, message=f"Email sent to {len(to_addrs)} recipients") + except Exception as e: + logger.warning("Email dispatch failed", error=str(e)) + return DispatchResult(success=False, error=str(e)) diff --git a/openrag/components/notifications/tchap.py b/openrag/components/notifications/tchap.py new file mode 100644 index 00000000..370dd139 --- /dev/null +++ b/openrag/components/notifications/tchap.py @@ -0,0 +1,59 @@ +"""Tchap (Matrix) notification dispatcher.""" + +import httpx +import uuid +from utils.logger import get_logger + +from .base import BaseDispatcher, DispatchResult + +logger = get_logger() + + +class TchapDispatcher(BaseDispatcher): + """Sends notifications to a Tchap/Matrix room. + + Uses the Matrix client-server API to send messages. + + Config: + homeserver: Matrix homeserver URL (e.g., https://matrix.agent.tchap.gouv.fr) + room_id: Target room ID (e.g., !abc:agent.tchap.gouv.fr) + access_token: Bot access token + """ + + async def send(self, title: str, body: str, url: str | None = None) -> DispatchResult: + homeserver = self.config.get("homeserver", "").rstrip("/") + room_id = self.config.get("room_id") + access_token = self.config.get("access_token") + + if not all([homeserver, room_id, access_token]): + return DispatchResult(success=False, error="Missing homeserver, room_id, or access_token") + + # Format message + plain_text = f"{title}\n\n{body}" + html_body = f"{title}

<br/><br/>{body}" + if url: + plain_text += f"\n\n{url}" + html_body += f'<br/><br/><a href="{url}">Ouvrir</a>'
+ + txn_id = str(uuid.uuid4()) + send_url = f"{homeserver}/_matrix/client/v3/rooms/{room_id}/send/m.room.message/{txn_id}" + + payload = { + "msgtype": "m.text", + "body": plain_text, + "format": "org.matrix.custom.html", + "formatted_body": html_body, + } + + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.put( + send_url, + json=payload, + headers={"Authorization": f"Bearer {access_token}"}, + ) + resp.raise_for_status() + return DispatchResult(success=True, message=f"Sent to room {room_id}") + except Exception as e: + logger.warning("Tchap dispatch failed", room_id=room_id, error=str(e)) + return DispatchResult(success=False, error=str(e)) diff --git a/openrag/components/notifications/webhook.py b/openrag/components/notifications/webhook.py new file mode 100644 index 00000000..b2531512 --- /dev/null +++ b/openrag/components/notifications/webhook.py @@ -0,0 +1,48 @@ +"""Webhook notification dispatcher.""" + +import httpx +from utils.logger import get_logger + +from .base import BaseDispatcher, DispatchResult + +logger = get_logger() + + +class WebhookDispatcher(BaseDispatcher): + """Sends notifications via HTTP webhook POST. + + Config: + url: Webhook URL + headers: Optional dict of HTTP headers + template: "markdown" (default), "json", or "html" + """ + + async def send(self, title: str, body: str, url: str | None = None) -> DispatchResult: + webhook_url = self.config.get("url") + if not webhook_url: + return DispatchResult(success=False, error="No webhook URL configured") + + headers = self.config.get("headers", {}) + template = self.config.get("template", "markdown") + + if template == "json": + payload = {"title": title, "body": body, "url": url} + elif template == "html": + html_body = f"

<h2>{title}</h2><p>{body}</p>

" + if url: + html_body += f'

<p><a href="{url}">Open</a></p>

' + payload = {"html": html_body} + else: # markdown + md = f"**{title}**\n\n{body}" + if url: + md += f"\n\n[Open]({url})" + payload = {"text": md} + + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.post(webhook_url, json=payload, headers=headers) + resp.raise_for_status() + return DispatchResult(success=True, message=f"Sent to {webhook_url}") + except Exception as e: + logger.warning("Webhook dispatch failed", url=webhook_url, error=str(e)) + return DispatchResult(success=False, error=str(e)) diff --git a/openrag/components/qa_override.py b/openrag/components/qa_override.py new file mode 100644 index 00000000..cee38308 --- /dev/null +++ b/openrag/components/qa_override.py @@ -0,0 +1,102 @@ +"""Q&A override check for the RAG pipeline. + +Before running the full RAG pipeline, checks if the user's question +matches an active Q&A override entry. If so, returns the override +answer directly, bypassing retrieval and generation. + +Uses semantic similarity via the embedder to find matching questions. +""" + +import os + +from config import load_config +from openai import AsyncOpenAI +from utils.logger import get_logger + +logger = get_logger() +config = load_config() + +QA_OVERRIDE_THRESHOLD = float(os.getenv("QA_OVERRIDE_THRESHOLD", "0.92")) +QA_OVERRIDE_ENABLED = os.getenv("QA_OVERRIDE_ENABLED", "true").lower() == "true" + + +async def check_qa_override(question: str, partitions: list[str] | None = None) -> dict | None: + """Check if a question matches an active Q&A override. + + Args: + question: The user's question + partitions: Optional list of partition names to filter overrides + + Returns: + Dict with override info if match found, None otherwise. + {"qa_id": int, "answer": str, "similarity": float} + """ + if not QA_OVERRIDE_ENABLED: + return None + + try: + from components.indexer.vectordb.utils import QAEntry + from utils.dependencies import get_vectordb + + vectordb = get_vectordb() + pfm = await vectordb.get_partition_file_manager.remote() + + with pfm.Session() as s: + query = s.query(QAEntry).filter( + QAEntry.override_active.is_(True), + QAEntry.override_answer.isnot(None), + ) + if partitions: + query = query.filter(QAEntry.partition_name.in_(partitions)) + overrides = query.all() + + if not overrides: + return None + + # Use the embedder to compute similarity + embedder_client = AsyncOpenAI( + base_url=config.embedder["base_url"], + api_key=config.embedder.get("api_key", "EMPTY"), + ) + + # Embed the question + response = await embedder_client.embeddings.create( + model=config.embedder["model_name"], + input=[question] + [o.question for o in overrides], + ) + + question_embedding = response.data[0].embedding + + best_match = None + best_similarity = 0.0 + + for i, override in enumerate(overrides): + override_embedding = response.data[i + 1].embedding + # Cosine similarity + dot_product = sum(a * b for a, b in zip(question_embedding, override_embedding)) + norm_q = sum(a * a for a in question_embedding) ** 0.5 + norm_o = sum(a * a for a in override_embedding) ** 0.5 + similarity = dot_product / (norm_q * norm_o) if (norm_q * norm_o) > 0 else 0 + + if similarity > best_similarity: + best_similarity = similarity + best_match = override + + if best_match and best_similarity >= QA_OVERRIDE_THRESHOLD: + logger.info( + "Q&A override matched", + qa_id=best_match.id, + similarity=round(best_similarity, 4), + question=question[:100], + ) + return { + "qa_id": best_match.id, + "answer": best_match.override_answer, + "similarity": round(best_similarity, 
4), + } + + return None + + except Exception as e: + logger.warning("Q&A override check failed, falling through to RAG", error=str(e)) + return None diff --git a/openrag/routers/admin.py b/openrag/routers/admin.py new file mode 100644 index 00000000..df9214d0 --- /dev/null +++ b/openrag/routers/admin.py @@ -0,0 +1,990 @@ +"""Admin router for OpenRAG. + +Provides endpoints for: +- Indexing profiles (CRUD + partition assignment) +- Q&A entries (CRUD, import/export, evaluation) +- User feedback (ingestion, review, promotion) +- Drive sources (CRUD + sync triggers) +- Notification channels (CRUD + test) +- Announcements & polls (CRUD + send + vote) +""" + +import os +from datetime import datetime + +import consts +from components.indexer.vectordb.utils import ( + Announcement, + DriveFileMapping, + DriveSource, + IndexingProfile, + NotificationChannel, + Partition, + PartitionIndexingConfig, + PollOption, + PollResponse, + QAEntry, + QAEvalRun, + UserFeedback, +) +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic import BaseModel, Field +from routers.utils import require_admin +from utils.dependencies import get_vectordb +from utils.logger import get_logger + +logger = get_logger() +router = APIRouter() + +FEEDBACK_SERVICE_KEY = os.getenv("FEEDBACK_SERVICE_KEY", "") + + +# --- Pydantic models --- + + +class IndexingProfileCreate(BaseModel): + name: str + description: str | None = None + chunker_name: str = "recursive_splitter" + chunk_size: int = 512 + chunk_overlap_rate: float = 0.2 + contextual_retrieval: bool = True + contextualization_timeout: int = 120 + max_concurrent_contextualization: int = 10 + retriever_type: str = "single" + retriever_top_k: int = 50 + similarity_threshold: float = 0.6 + extra_params: dict = Field(default_factory=dict) + + +class PartitionIndexingAssign(BaseModel): + indexing_profile_id: int + overrides: dict = Field(default_factory=dict) + + +class QAEntryCreate(BaseModel): + partition_name: str | None = None + question: str + expected_answer: str | None = None + override_answer: str | None = None + override_active: bool = False + tags: list[str] = Field(default_factory=list) + + +class QAEvalRequest(BaseModel): + partition_name: str + tags: list[str] = Field(default_factory=lambda: ["eval"]) + config: dict = Field(default_factory=dict) + + +class FeedbackIngestItem(BaseModel): + external_user_id: str | None = None + question: str + response: str + model: str | None = None + rating: int = Field(ge=-1, le=1) + reason: str | None = None + owui_chat_id: str | None = None + owui_message_id: str | None = None + + +class FeedbackIngestRequest(BaseModel): + feedbacks: list[FeedbackIngestItem] + + +class FeedbackPromoteRequest(BaseModel): + type: str = Field(pattern="^(override|eval)$") + override_answer: str | None = None + expected_answer: str | None = None + tags: list[str] = Field(default_factory=list) + activate_override: bool = False + + +class DriveSourceCreate(BaseModel): + partition_name: str + drive_base_url: str + drive_folder_id: str + sync_frequency_minutes: int = 60 + auth_mode: str = "service_account" + service_account_client_id: str | None = None + service_account_client_secret: str | None = None + + +class ChannelCreate(BaseModel): + name: str + type: str = Field(pattern="^(webhook|email_smtp|tchap_bot)$") + config: dict + active: bool = True + + +class AnnouncementCreate(BaseModel): + type: str = Field(pattern="^(announcement|poll)$") + title: str + body: str + target_type: str = 
Field(pattern="^(all|partition|group|user)$") + target_value: str | None = None + scheduled_at: datetime | None = None + channels: list[int] = Field(default_factory=list) + poll_options: list[str] = Field(default_factory=list) + + +class PollVoteRequest(BaseModel): + poll_option_id: int + + +# --- Helper to get a DB session from the vectordb actor --- + + +async def get_session(): + """Get a SQLAlchemy session from the vectordb actor's partition file manager.""" + vectordb = get_vectordb() + pfm = await vectordb.get_partition_file_manager.remote() + return pfm.Session + + +# ============================================================ +# INDEXING PROFILES +# ============================================================ + + +@router.get("/indexing-profiles") +async def list_indexing_profiles(user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + profiles = s.query(IndexingProfile).all() + return [p.to_dict() for p in profiles] + + +@router.post("/indexing-profiles", status_code=201) +async def create_indexing_profile(body: IndexingProfileCreate, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + existing = s.query(IndexingProfile).filter_by(name=body.name).first() + if existing: + raise HTTPException(status_code=409, detail=f"Profile '{body.name}' already exists") + profile = IndexingProfile( + name=body.name, + description=body.description, + chunker_name=body.chunker_name, + chunk_size=body.chunk_size, + chunk_overlap_rate=str(body.chunk_overlap_rate), + contextual_retrieval=body.contextual_retrieval, + contextualization_timeout=body.contextualization_timeout, + max_concurrent_contextualization=body.max_concurrent_contextualization, + retriever_type=body.retriever_type, + retriever_top_k=body.retriever_top_k, + similarity_threshold=str(body.similarity_threshold), + extra_params=body.extra_params, + ) + s.add(profile) + s.commit() + s.refresh(profile) + return profile.to_dict() + + +@router.get("/indexing-profiles/{profile_id}") +async def get_indexing_profile(profile_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + profile = s.query(IndexingProfile).filter_by(id=profile_id).first() + if not profile: + raise HTTPException(status_code=404, detail="Profile not found") + return profile.to_dict() + + +@router.put("/indexing-profiles/{profile_id}") +async def update_indexing_profile(profile_id: int, body: IndexingProfileCreate, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + profile = s.query(IndexingProfile).filter_by(id=profile_id).first() + if not profile: + raise HTTPException(status_code=404, detail="Profile not found") + for field_name in body.model_fields: + value = getattr(body, field_name) + if field_name in ("chunk_overlap_rate", "similarity_threshold"): + value = str(value) + setattr(profile, field_name, value) + profile.updated_at = datetime.now() + s.commit() + s.refresh(profile) + return profile.to_dict() + + +@router.delete("/indexing-profiles/{profile_id}") +async def delete_indexing_profile(profile_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + # Check if in use + in_use = s.query(PartitionIndexingConfig).filter_by(indexing_profile_id=profile_id).first() + if in_use: + raise HTTPException(status_code=409, detail="Profile is in use by a partition and cannot be deleted") + profile = s.query(IndexingProfile).filter_by(id=profile_id).first() + if not profile: + raise 
HTTPException(status_code=404, detail="Profile not found") + s.delete(profile) + s.commit() + return {"detail": "Profile deleted"} + + +@router.get("/partitions/{partition_name}/indexing") +async def get_partition_indexing(partition_name: str, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + config = s.query(PartitionIndexingConfig).filter_by(partition_name=partition_name).first() + if not config: + return {"partition_name": partition_name, "indexing_profile_id": None, "overrides": {}} + return { + "partition_name": config.partition_name, + "indexing_profile_id": config.indexing_profile_id, + "overrides": config.overrides, + "profile": config.profile.to_dict() if config.profile else None, + } + + +@router.put("/partitions/{partition_name}/indexing") +async def set_partition_indexing(partition_name: str, body: PartitionIndexingAssign, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + # Verify profile exists + profile = s.query(IndexingProfile).filter_by(id=body.indexing_profile_id).first() + if not profile: + raise HTTPException(status_code=404, detail="Indexing profile not found") + + config = s.query(PartitionIndexingConfig).filter_by(partition_name=partition_name).first() + if config: + config.indexing_profile_id = body.indexing_profile_id + config.overrides = body.overrides + else: + config = PartitionIndexingConfig( + partition_name=partition_name, + indexing_profile_id=body.indexing_profile_id, + overrides=body.overrides, + ) + s.add(config) + s.commit() + return {"detail": "Partition indexing config updated"} + + +# ============================================================ +# Q&A ENTRIES +# ============================================================ + + +@router.get("/qa") +async def list_qa_entries( + partition: str | None = None, + tags: str | None = None, + override_active: bool | None = None, + page: int = 1, + per_page: int = 50, + user=Depends(require_admin), +): + Session = await get_session() + with Session() as s: + query = s.query(QAEntry) + if partition: + query = query.filter(QAEntry.partition_name == partition) + if override_active is not None: + query = query.filter(QAEntry.override_active == override_active) + # tags filter: match entries that contain any of the requested tags + if tags: + tag_list = [t.strip() for t in tags.split(",")] + for tag in tag_list: + query = query.filter(QAEntry.tags.contains([tag])) + + total = query.count() + entries = query.offset((page - 1) * per_page).limit(per_page).all() + return { + "total": total, + "page": page, + "per_page": per_page, + "entries": [e.to_dict() for e in entries], + } + + +@router.post("/qa", status_code=201) +async def create_qa_entry(body: QAEntryCreate, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + entry = QAEntry( + partition_name=body.partition_name, + question=body.question, + expected_answer=body.expected_answer, + override_answer=body.override_answer, + override_active=body.override_active, + tags=body.tags, + created_by=user.get("id"), + ) + s.add(entry) + s.commit() + s.refresh(entry) + return entry.to_dict() + + +@router.get("/qa/{qa_id}") +async def get_qa_entry(qa_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + entry = s.query(QAEntry).filter_by(id=qa_id).first() + if not entry: + raise HTTPException(status_code=404, detail="Q&A entry not found") + return entry.to_dict() + + +@router.put("/qa/{qa_id}") +async def update_qa_entry(qa_id: 
int, body: QAEntryCreate, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + entry = s.query(QAEntry).filter_by(id=qa_id).first() + if not entry: + raise HTTPException(status_code=404, detail="Q&A entry not found") + entry.partition_name = body.partition_name + entry.question = body.question + entry.expected_answer = body.expected_answer + entry.override_answer = body.override_answer + entry.override_active = body.override_active + entry.tags = body.tags + entry.updated_at = datetime.now() + s.commit() + s.refresh(entry) + return entry.to_dict() + + +@router.delete("/qa/{qa_id}") +async def delete_qa_entry(qa_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + entry = s.query(QAEntry).filter_by(id=qa_id).first() + if not entry: + raise HTTPException(status_code=404, detail="Q&A entry not found") + s.delete(entry) + s.commit() + return {"detail": "Q&A entry deleted"} + + +@router.post("/qa/import", status_code=201) +async def import_qa_entries(entries: list[QAEntryCreate], user=Depends(require_admin)): + Session = await get_session() + created = 0 + with Session() as s: + for body in entries: + entry = QAEntry( + partition_name=body.partition_name, + question=body.question, + expected_answer=body.expected_answer, + override_answer=body.override_answer, + override_active=body.override_active, + tags=body.tags, + created_by=user.get("id"), + ) + s.add(entry) + created += 1 + s.commit() + return {"detail": f"Imported {created} Q&A entries"} + + +@router.get("/qa/export") +async def export_qa_entries(partition: str | None = None, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + query = s.query(QAEntry) + if partition: + query = query.filter(QAEntry.partition_name == partition) + entries = query.all() + return [e.to_dict() for e in entries] + + +# --- Evaluation --- + + +@router.post("/qa/eval", status_code=201) +async def start_eval_run(body: QAEvalRequest, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + # Count matching questions + query = s.query(QAEntry).filter(QAEntry.partition_name == body.partition_name) + for tag in body.tags: + query = query.filter(QAEntry.tags.contains([tag])) + total = query.count() + + if total == 0: + raise HTTPException(status_code=404, detail="No Q&A entries match the filters") + + run = QAEvalRun( + partition_name=body.partition_name, + status="pending", + total_questions=total, + config_json=body.config, + created_by=user.get("id"), + ) + s.add(run) + s.commit() + s.refresh(run) + + # TODO: Launch async evaluation task via Ray + # For now, return the run ID for polling + return run.to_dict() + + +@router.get("/qa/eval/runs") +async def list_eval_runs(user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + runs = s.query(QAEvalRun).order_by(QAEvalRun.started_at.desc()).all() + return [r.to_dict() for r in runs] + + +@router.get("/qa/eval/runs/{run_id}") +async def get_eval_run(run_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + run = s.query(QAEvalRun).filter_by(id=run_id).first() + if not run: + raise HTTPException(status_code=404, detail="Eval run not found") + return run.to_dict() + + +# ============================================================ +# USER FEEDBACK +# ============================================================ + + +def _require_service_key(request: Request): + """Verify the service key for feedback ingestion.""" 
+ if not FEEDBACK_SERVICE_KEY: + return # No key configured = open + key = request.headers.get("x-service-key", "") + if key != FEEDBACK_SERVICE_KEY: + raise HTTPException(status_code=403, detail="Invalid service key") + + +def _partition_from_model(model: str | None) -> str | None: + """Extract partition name from model string like 'openrag-finance'.""" + if not model: + return None + prefix = consts.PARTITION_PREFIX + if model.startswith(prefix): + partition = model[len(prefix):] + return None if partition == "all" else partition + legacy = consts.LEGACY_PARTITION_PREFIX + if model.startswith(legacy): + partition = model[len(legacy):] + return None if partition == "all" else partition + return None + + +@router.post("/feedback/ingest") +async def ingest_feedback(body: FeedbackIngestRequest, request: Request): + _require_service_key(request) + Session = await get_session() + ingested = 0 + skipped = 0 + with Session() as s: + for fb in body.feedbacks: + # Deduplicate + if fb.owui_chat_id and fb.owui_message_id: + existing = s.query(UserFeedback).filter_by( + owui_chat_id=fb.owui_chat_id, + owui_message_id=fb.owui_message_id, + ).first() + if existing: + skipped += 1 + continue + + partition = _partition_from_model(fb.model) + entry = UserFeedback( + external_user_id=fb.external_user_id, + partition_name=partition, + question=fb.question, + response=fb.response, + model=fb.model, + rating=fb.rating, + reason=fb.reason, + owui_chat_id=fb.owui_chat_id, + owui_message_id=fb.owui_message_id, + ) + s.add(entry) + ingested += 1 + s.commit() + return {"ingested": ingested, "skipped": skipped} + + +@router.get("/feedback") +async def list_feedback( + partition: str | None = None, + rating: int | None = None, + feedback_status: str | None = None, + page: int = 1, + per_page: int = 50, + user=Depends(require_admin), +): + Session = await get_session() + with Session() as s: + query = s.query(UserFeedback).order_by(UserFeedback.created_at.desc()) + if partition: + query = query.filter(UserFeedback.partition_name == partition) + if rating is not None: + query = query.filter(UserFeedback.rating == rating) + if feedback_status: + query = query.filter(UserFeedback.status == feedback_status) + total = query.count() + entries = query.offset((page - 1) * per_page).limit(per_page).all() + return { + "total": total, + "page": page, + "per_page": per_page, + "entries": [e.to_dict() for e in entries], + } + + +@router.get("/feedback/stats") +async def feedback_stats(user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + from sqlalchemy import case, func + + total = s.query(func.count(UserFeedback.id)).scalar() + positive = s.query(func.count(UserFeedback.id)).filter(UserFeedback.rating == 1).scalar() + negative = s.query(func.count(UserFeedback.id)).filter(UserFeedback.rating == -1).scalar() + pending = s.query(func.count(UserFeedback.id)).filter(UserFeedback.status == "pending").scalar() + + # Per partition + by_partition = ( + s.query( + UserFeedback.partition_name, + func.count(UserFeedback.id).label("total"), + func.count(case((UserFeedback.rating == 1, 1))).label("positive"), + func.count(case((UserFeedback.rating == -1, 1))).label("negative"), + ) + .group_by(UserFeedback.partition_name) + .all() + ) + + return { + "global": { + "total": total, + "positive": positive, + "negative": negative, + "satisfaction_rate": round(positive / total, 2) if total else 0, + }, + "by_partition": [ + { + "partition": row.partition_name, + "total": row.total, + "positive": row.positive, + 
"negative": row.negative, + "rate": round(row.positive / row.total, 2) if row.total else 0, + } + for row in by_partition + ], + "pending_review": pending, + } + + +@router.patch("/feedback/{feedback_id}") +async def review_feedback(feedback_id: int, feedback_status: str, user=Depends(require_admin)): + if feedback_status not in ("reviewed", "dismissed"): + raise HTTPException(status_code=400, detail="Status must be 'reviewed' or 'dismissed'") + Session = await get_session() + with Session() as s: + fb = s.query(UserFeedback).filter_by(id=feedback_id).first() + if not fb: + raise HTTPException(status_code=404, detail="Feedback not found") + fb.status = feedback_status + fb.reviewed_by = user.get("id") + fb.reviewed_at = datetime.now() + s.commit() + return {"detail": "Feedback updated"} + + +@router.post("/feedback/{feedback_id}/promote") +async def promote_feedback(feedback_id: int, body: FeedbackPromoteRequest, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + fb = s.query(UserFeedback).filter_by(id=feedback_id).first() + if not fb: + raise HTTPException(status_code=404, detail="Feedback not found") + + qa = QAEntry( + partition_name=fb.partition_name, + question=fb.question, + tags=body.tags, + source_feedback_id=fb.id, + created_by=user.get("id"), + ) + + if body.type == "override": + qa.override_answer = body.override_answer or fb.response + qa.override_active = body.activate_override + else: # eval + qa.expected_answer = body.expected_answer or fb.response + + s.add(qa) + s.flush() + + fb.status = "promoted" + fb.promoted_to_qa_id = qa.id + fb.reviewed_by = user.get("id") + fb.reviewed_at = datetime.now() + s.commit() + s.refresh(qa) + return qa.to_dict() + + +# ============================================================ +# DRIVE SOURCES +# ============================================================ + + +@router.get("/drive-sources") +async def list_drive_sources(user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + sources = s.query(DriveSource).all() + return [src.to_dict() for src in sources] + + +@router.post("/drive-sources", status_code=201) +async def create_drive_source(body: DriveSourceCreate, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + source = DriveSource( + partition_name=body.partition_name, + drive_base_url=body.drive_base_url, + drive_folder_id=body.drive_folder_id, + sync_frequency_minutes=body.sync_frequency_minutes, + auth_mode=body.auth_mode, + service_account_client_id=body.service_account_client_id, + service_account_client_secret=body.service_account_client_secret, + created_by=user.get("id"), + ) + s.add(source) + s.commit() + s.refresh(source) + return source.to_dict() + + +@router.get("/drive-sources/{source_id}") +async def get_drive_source(source_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + source = s.query(DriveSource).filter_by(id=source_id).first() + if not source: + raise HTTPException(status_code=404, detail="Drive source not found") + result = source.to_dict() + result["file_mappings"] = [m.to_dict() for m in source.file_mappings] + return result + + +@router.put("/drive-sources/{source_id}") +async def update_drive_source(source_id: int, body: DriveSourceCreate, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + source = s.query(DriveSource).filter_by(id=source_id).first() + if not source: + raise HTTPException(status_code=404, detail="Drive source not 
found") + source.partition_name = body.partition_name + source.drive_base_url = body.drive_base_url + source.drive_folder_id = body.drive_folder_id + source.sync_frequency_minutes = body.sync_frequency_minutes + source.auth_mode = body.auth_mode + source.service_account_client_id = body.service_account_client_id + source.service_account_client_secret = body.service_account_client_secret + s.commit() + s.refresh(source) + return source.to_dict() + + +@router.delete("/drive-sources/{source_id}") +async def delete_drive_source(source_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + source = s.query(DriveSource).filter_by(id=source_id).first() + if not source: + raise HTTPException(status_code=404, detail="Drive source not found") + # TODO: Also delete indexed files from partition via Indexer + s.delete(source) + s.commit() + return {"detail": "Drive source deleted"} + + +@router.post("/drive-sources/{source_id}/sync") +async def trigger_drive_sync(source_id: int, user=Depends(require_admin)): + """Trigger a manual sync for a drive source.""" + Session = await get_session() + with Session() as s: + source = s.query(DriveSource).filter_by(id=source_id).first() + if not source: + raise HTTPException(status_code=404, detail="Drive source not found") + + # TODO: Trigger via DriveSyncScheduler Ray actor + # scheduler = ray.get_actor("DriveSyncScheduler", namespace="openrag") + # await scheduler.trigger_sync.remote(source_id) + + return {"detail": f"Sync triggered for source {source_id}"} + + +# ============================================================ +# NOTIFICATION CHANNELS +# ============================================================ + + +@router.get("/channels") +async def list_channels(user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + channels = s.query(NotificationChannel).all() + return [c.to_dict() for c in channels] + + +@router.post("/channels", status_code=201) +async def create_channel(body: ChannelCreate, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + channel = NotificationChannel( + name=body.name, + type=body.type, + config_json=body.config, + active=body.active, + ) + s.add(channel) + s.commit() + s.refresh(channel) + return channel.to_dict() + + +@router.put("/channels/{channel_id}") +async def update_channel(channel_id: int, body: ChannelCreate, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + channel = s.query(NotificationChannel).filter_by(id=channel_id).first() + if not channel: + raise HTTPException(status_code=404, detail="Channel not found") + channel.name = body.name + channel.type = body.type + channel.config_json = body.config + channel.active = body.active + s.commit() + s.refresh(channel) + return channel.to_dict() + + +@router.delete("/channels/{channel_id}") +async def delete_channel(channel_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + channel = s.query(NotificationChannel).filter_by(id=channel_id).first() + if not channel: + raise HTTPException(status_code=404, detail="Channel not found") + s.delete(channel) + s.commit() + return {"detail": "Channel deleted"} + + +@router.post("/channels/{channel_id}/test") +async def test_channel(channel_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + channel = s.query(NotificationChannel).filter_by(id=channel_id).first() + if not channel: + raise 
HTTPException(status_code=404, detail="Channel not found") + + # TODO: Use dispatcher to send test message + # from components.notifications import get_dispatcher + # dispatcher = get_dispatcher(channel) + # await dispatcher.send_test() + + return {"detail": f"Test message sent via {channel.type} channel '{channel.name}'"} + + +# ============================================================ +# ANNOUNCEMENTS & POLLS +# ============================================================ + + +@router.get("/announcements") +async def list_announcements( + announcement_type: str | None = None, + announcement_status: str | None = None, + user=Depends(require_admin), +): + Session = await get_session() + with Session() as s: + query = s.query(Announcement).order_by(Announcement.created_at.desc()) + if announcement_type: + query = query.filter(Announcement.type == announcement_type) + if announcement_status: + query = query.filter(Announcement.status == announcement_status) + items = query.all() + return [a.to_dict() for a in items] + + +@router.post("/announcements", status_code=201) +async def create_announcement(body: AnnouncementCreate, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + ann = Announcement( + type=body.type, + title=body.title, + body=body.body, + target_type=body.target_type, + target_value=body.target_value, + scheduled_at=body.scheduled_at, + channels=body.channels, + created_by=user.get("id"), + status="scheduled" if body.scheduled_at else "draft", + ) + s.add(ann) + s.flush() + + # Add poll options if it's a poll + if body.type == "poll": + for i, label in enumerate(body.poll_options): + s.add(PollOption(announcement_id=ann.id, label=label, sort_order=i)) + + s.commit() + s.refresh(ann) + return ann.to_dict() + + +@router.get("/announcements/{ann_id}") +async def get_announcement(ann_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + ann = s.query(Announcement).filter_by(id=ann_id).first() + if not ann: + raise HTTPException(status_code=404, detail="Announcement not found") + return ann.to_dict() + + +@router.put("/announcements/{ann_id}") +async def update_announcement(ann_id: int, body: AnnouncementCreate, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + ann = s.query(Announcement).filter_by(id=ann_id).first() + if not ann: + raise HTTPException(status_code=404, detail="Announcement not found") + if ann.status != "draft": + raise HTTPException(status_code=400, detail="Only draft announcements can be edited") + ann.title = body.title + ann.body = body.body + ann.target_type = body.target_type + ann.target_value = body.target_value + ann.scheduled_at = body.scheduled_at + ann.channels = body.channels + ann.updated_at = datetime.now() + s.commit() + s.refresh(ann) + return ann.to_dict() + + +@router.delete("/announcements/{ann_id}") +async def delete_announcement(ann_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + ann = s.query(Announcement).filter_by(id=ann_id).first() + if not ann: + raise HTTPException(status_code=404, detail="Announcement not found") + s.delete(ann) + s.commit() + return {"detail": "Announcement deleted"} + + +@router.post("/announcements/{ann_id}/send") +async def send_announcement(ann_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + ann = s.query(Announcement).filter_by(id=ann_id).first() + if not ann: + raise HTTPException(status_code=404, 
detail="Announcement not found") + if ann.status == "sent": + raise HTTPException(status_code=400, detail="Already sent") + + # TODO: Dispatch via configured channels + ann.status = "sent" + ann.sent_at = datetime.now() + s.commit() + return {"detail": "Announcement sent"} + + +@router.post("/announcements/{ann_id}/close") +async def close_announcement(ann_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + ann = s.query(Announcement).filter_by(id=ann_id).first() + if not ann: + raise HTTPException(status_code=404, detail="Announcement not found") + ann.status = "closed" + ann.closed_at = datetime.now() + s.commit() + return {"detail": "Announcement closed"} + + +@router.get("/announcements/{ann_id}/results") +async def get_poll_results(ann_id: int, user=Depends(require_admin)): + Session = await get_session() + with Session() as s: + ann = s.query(Announcement).filter_by(id=ann_id).first() + if not ann or ann.type != "poll": + raise HTTPException(status_code=404, detail="Poll not found") + + from sqlalchemy import func + + results = ( + s.query(PollOption.id, PollOption.label, func.count(PollResponse.id).label("votes")) + .outerjoin(PollResponse, PollOption.id == PollResponse.poll_option_id) + .filter(PollOption.announcement_id == ann_id) + .group_by(PollOption.id, PollOption.label) + .order_by(PollOption.sort_order) + .all() + ) + + total_votes = sum(r.votes for r in results) + return { + "announcement_id": ann_id, + "total_votes": total_votes, + "options": [ + {"id": r.id, "label": r.label, "votes": r.votes} + for r in results + ], + } + + +# --- Public endpoints (user-facing) --- + + +@router.post("/announcements/{ann_id}/respond") +async def vote_on_poll(ann_id: int, body: PollVoteRequest, request: Request): + """Vote on a poll. Requires authentication but not admin.""" + user = request.state.user + external_id = user.get("external_user_id") or str(user.get("id")) + + Session = await get_session() + with Session() as s: + ann = s.query(Announcement).filter_by(id=ann_id).first() + if not ann or ann.type != "poll": + raise HTTPException(status_code=404, detail="Poll not found") + if ann.status == "closed": + raise HTTPException(status_code=400, detail="Poll is closed") + + # Verify option belongs to this poll + option = s.query(PollOption).filter_by(id=body.poll_option_id, announcement_id=ann_id).first() + if not option: + raise HTTPException(status_code=404, detail="Poll option not found") + + # Check for existing vote + existing = s.query(PollResponse).filter_by( + announcement_id=ann_id, external_user_id=external_id + ).first() + if existing: + # Update vote + existing.poll_option_id = body.poll_option_id + existing.responded_at = datetime.now() + else: + s.add(PollResponse( + announcement_id=ann_id, + poll_option_id=body.poll_option_id, + external_user_id=external_id, + )) + s.commit() + return {"detail": "Vote recorded"} diff --git a/openrag/routers/auth.py b/openrag/routers/auth.py index 6613d72b..19d4e571 100644 --- a/openrag/routers/auth.py +++ b/openrag/routers/auth.py @@ -118,9 +118,16 @@ def _require_oidc_mode(): def _is_request_secure(request: Request) -> bool: """True if the client-observed scheme is HTTPS. - ``request.url.scheme`` already accounts for reverse-proxy headers when the - app is started with ``proxy_headers=True`` (see ``api.py``). + Checks multiple indicators: + 1. ``PREFERRED_URL_SCHEME`` env var (set when behind a TLS-terminating proxy) + 2. 
``X-Forwarded-Proto`` header (set by reverse proxies like Traefik/Nginx) + 3. ``request.url.scheme`` (accounts for proxy_headers=True in uvicorn) """ + import os + if os.environ.get("PREFERRED_URL_SCHEME", "").lower() == "https": + return True + if request.headers.get("x-forwarded-proto", "").lower() == "https": + return True return request.url.scheme == "https" diff --git a/openrag/scripts/migrations/alembic/versions/e1f2a3b4c5d6_add_oidc_support.py b/openrag/scripts/migrations/alembic/versions/e1f2a3b4c5d6_add_oidc_support.py new file mode 100644 index 00000000..0d40af04 --- /dev/null +++ b/openrag/scripts/migrations/alembic/versions/e1f2a3b4c5d6_add_oidc_support.py @@ -0,0 +1,27 @@ +"""Add OIDC support: source field on partition_memberships + +Revision ID: e1f2a3b4c5d6 +Revises: cd9b84278028 +Create Date: 2026-04-06 +""" + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "e1f2a3b4c5d6" +down_revision = "cd9b84278028" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add 'source' column to partition_memberships + op.add_column( + "partition_memberships", + sa.Column("source", sa.String(), nullable=False, server_default="manual"), + ) + + +def downgrade() -> None: + op.drop_column("partition_memberships", "source") diff --git a/openrag/scripts/migrations/alembic/versions/f2a3b4c5d6e7_add_all_integration_tables.py b/openrag/scripts/migrations/alembic/versions/f2a3b4c5d6e7_add_all_integration_tables.py new file mode 100644 index 00000000..7d435396 --- /dev/null +++ b/openrag/scripts/migrations/alembic/versions/f2a3b4c5d6e7_add_all_integration_tables.py @@ -0,0 +1,198 @@ +"""Add all integration tables: indexing profiles, Q&A, feedback, announcements, drive sources + +Revision ID: f2a3b4c5d6e7 +Revises: e1f2a3b4c5d6 +Create Date: 2026-04-06 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "f2a3b4c5d6e7" +down_revision = "e1f2a3b4c5d6" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Indexing profiles + op.create_table( + "indexing_profiles", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("name", sa.String(100), unique=True, nullable=False), + sa.Column("description", sa.String, nullable=True), + sa.Column("chunker_name", sa.String(50), nullable=False, server_default="recursive_splitter"), + sa.Column("chunk_size", sa.Integer, nullable=False, server_default="512"), + sa.Column("chunk_overlap_rate", sa.String, nullable=False, server_default="0.2"), + sa.Column("contextual_retrieval", sa.Boolean, nullable=False, server_default=sa.text("true")), + sa.Column("contextualization_timeout", sa.Integer, nullable=False, server_default="120"), + sa.Column("max_concurrent_contextualization", sa.Integer, nullable=False, server_default="10"), + sa.Column("retriever_type", sa.String(50), nullable=False, server_default="single"), + sa.Column("retriever_top_k", sa.Integer, nullable=False, server_default="50"), + sa.Column("similarity_threshold", sa.String, nullable=False, server_default="0.6"), + sa.Column("extra_params", sa.JSON, nullable=False, server_default="{}"), + sa.Column("created_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + ) + + op.create_table( + "partition_indexing_config", + sa.Column("partition_name", sa.String, sa.ForeignKey("partitions.partition", ondelete="CASCADE"), primary_key=True), + sa.Column("indexing_profile_id", sa.Integer, 
sa.ForeignKey("indexing_profiles.id"), nullable=False), + sa.Column("overrides", sa.JSON, nullable=False, server_default="{}"), + ) + + # Q&A entries + op.create_table( + "qa_entries", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("partition_name", sa.String, sa.ForeignKey("partitions.partition", ondelete="CASCADE"), nullable=True, index=True), + sa.Column("question", sa.String, nullable=False), + sa.Column("expected_answer", sa.String, nullable=True), + sa.Column("override_answer", sa.String, nullable=True), + sa.Column("override_active", sa.Boolean, nullable=False, server_default=sa.text("false")), + sa.Column("tags", sa.JSON, nullable=False, server_default="[]"), + sa.Column("source_feedback_id", sa.Integer, nullable=True), + sa.Column("created_by", sa.Integer, sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True), + sa.Column("created_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + ) + + # Q&A eval runs + op.create_table( + "qa_eval_runs", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("partition_name", sa.String, sa.ForeignKey("partitions.partition", ondelete="CASCADE"), nullable=True), + sa.Column("started_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + sa.Column("completed_at", sa.DateTime, nullable=True), + sa.Column("status", sa.String(20), nullable=False, server_default="pending"), + sa.Column("total_questions", sa.Integer, nullable=False, server_default="0"), + sa.Column("completed_questions", sa.Integer, nullable=False, server_default="0"), + sa.Column("results", sa.JSON, nullable=True), + sa.Column("config", sa.JSON, nullable=True), + sa.Column("created_by", sa.Integer, sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True), + ) + + # User feedback + op.create_table( + "user_feedback", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("external_user_id", sa.String, nullable=True, index=True), + sa.Column("partition_name", sa.String, nullable=True, index=True), + sa.Column("question", sa.String, nullable=False), + sa.Column("response", sa.String, nullable=False), + sa.Column("model", sa.String, nullable=True), + sa.Column("rating", sa.Integer, nullable=False), + sa.Column("reason", sa.String, nullable=True), + sa.Column("owui_chat_id", sa.String, nullable=True), + sa.Column("owui_message_id", sa.String, nullable=True), + sa.Column("status", sa.String(20), nullable=False, server_default="pending"), + sa.Column("promoted_to_qa_id", sa.Integer, sa.ForeignKey("qa_entries.id", ondelete="SET NULL"), nullable=True), + sa.Column("reviewed_by", sa.Integer, sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True), + sa.Column("created_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + sa.Column("reviewed_at", sa.DateTime, nullable=True), + sa.UniqueConstraint("owui_chat_id", "owui_message_id", name="uix_owui_feedback"), + sa.CheckConstraint("rating BETWEEN -1 AND 1", name="ck_feedback_rating"), + ) + + # Notification channels + op.create_table( + "notification_channels", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("name", sa.String(100), nullable=False), + sa.Column("type", sa.String(20), nullable=False), + sa.Column("config", sa.JSON, nullable=False), + sa.Column("active", sa.Boolean, nullable=False, server_default=sa.text("true")), + sa.Column("created_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + sa.CheckConstraint("type IN 
('webhook','email_smtp','tchap_bot')", name="ck_channel_type"), + ) + + # Announcements + op.create_table( + "announcements", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("type", sa.String(20), nullable=False), + sa.Column("title", sa.String(255), nullable=False), + sa.Column("body", sa.String, nullable=False), + sa.Column("target_type", sa.String(20), nullable=False), + sa.Column("target_value", sa.String, nullable=True), + sa.Column("status", sa.String(20), nullable=False, server_default="draft"), + sa.Column("scheduled_at", sa.DateTime, nullable=True), + sa.Column("sent_at", sa.DateTime, nullable=True), + sa.Column("closed_at", sa.DateTime, nullable=True), + sa.Column("channels", sa.JSON, nullable=False, server_default="[]"), + sa.Column("created_by", sa.Integer, sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True), + sa.Column("created_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + sa.CheckConstraint("type IN ('announcement','poll')", name="ck_announcement_type"), + sa.CheckConstraint("target_type IN ('all','partition','group','user')", name="ck_target_type"), + sa.CheckConstraint("status IN ('draft','scheduled','sent','closed')", name="ck_announcement_status"), + ) + + # Poll options + op.create_table( + "poll_options", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("announcement_id", sa.Integer, sa.ForeignKey("announcements.id", ondelete="CASCADE"), nullable=False), + sa.Column("label", sa.String(255), nullable=False), + sa.Column("sort_order", sa.Integer, nullable=False, server_default="0"), + ) + + # Poll responses + op.create_table( + "poll_responses", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("announcement_id", sa.Integer, sa.ForeignKey("announcements.id", ondelete="CASCADE"), nullable=False), + sa.Column("poll_option_id", sa.Integer, sa.ForeignKey("poll_options.id", ondelete="CASCADE"), nullable=False), + sa.Column("external_user_id", sa.String, nullable=False), + sa.Column("responded_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + sa.UniqueConstraint("announcement_id", "external_user_id", name="uix_poll_one_vote"), + ) + + # Drive sources + op.create_table( + "drive_sources", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("partition_name", sa.String, sa.ForeignKey("partitions.partition", ondelete="CASCADE"), nullable=False, index=True), + sa.Column("drive_base_url", sa.String, nullable=False), + sa.Column("drive_folder_id", sa.String, nullable=False), + sa.Column("sync_frequency_minutes", sa.Integer, nullable=False, server_default="60"), + sa.Column("sync_enabled", sa.Boolean, nullable=False, server_default=sa.text("true")), + sa.Column("last_synced_at", sa.DateTime, nullable=True), + sa.Column("last_sync_status", sa.String(20), nullable=True), + sa.Column("last_sync_error", sa.String, nullable=True), + sa.Column("auth_mode", sa.String(20), nullable=False, server_default="service_account"), + sa.Column("service_account_client_id", sa.String, nullable=True), + sa.Column("service_account_client_secret", sa.String, nullable=True), + sa.Column("created_by", sa.Integer, sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True), + sa.Column("created_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + sa.UniqueConstraint("partition_name", "drive_folder_id", name="uix_partition_drive_folder"), + ) + + # Drive file mappings + op.create_table( + "drive_file_mappings", + 
sa.Column("id", sa.Integer, primary_key=True), + sa.Column("drive_source_id", sa.Integer, sa.ForeignKey("drive_sources.id", ondelete="CASCADE"), nullable=False, index=True), + sa.Column("drive_item_id", sa.String, nullable=False), + sa.Column("drive_item_title", sa.String, nullable=True), + sa.Column("drive_item_updated_at", sa.DateTime, nullable=True), + sa.Column("file_id", sa.String, nullable=False), + sa.Column("partition_name", sa.String, nullable=False), + sa.Column("last_synced_at", sa.DateTime, nullable=False, server_default=sa.func.now()), + sa.UniqueConstraint("drive_source_id", "drive_item_id", name="uix_drive_source_item"), + ) + + +def downgrade() -> None: + op.drop_table("drive_file_mappings") + op.drop_table("drive_sources") + op.drop_table("poll_responses") + op.drop_table("poll_options") + op.drop_table("announcements") + op.drop_table("notification_channels") + op.drop_table("user_feedback") + op.drop_table("qa_eval_runs") + op.drop_table("qa_entries") + op.drop_table("partition_indexing_config") + op.drop_table("indexing_profiles") diff --git a/prompts_integration.md b/prompts_integration.md new file mode 100644 index 00000000..0d11f305 --- /dev/null +++ b/prompts_integration.md @@ -0,0 +1,374 @@ +# Prompts d'implementation — Integration OpenRAG + Open WebUI + Keycloak + Scaleway + +> Ce document contient les prompts et notes intermediaires pour le deploiement +> local d'OpenRAG integre a l'ecosysteme owuicore-main. + +--- + +## Prompt 0 / Intermediaire — Etat des lieux et configuration + +### 0.1 — Allocation des ports Docker (sans collision) + +Ports du socle owuicore-main : + +| Port | Service socle | +|-------|--------------------| +| 3000 | openwebui | +| 5432 | postgres | +| 8082 | keycloak | +| 8083 | searxng | +| 9098 | tika | +| 9099 | pipelines | + +Ports OpenRAG choisis : + +| Variable | Port | Service | +|------------------------|--------|------------------------| +| `APP_PORT` | 8180 | OpenRAG API (FastAPI) | +| `CHAINLIT_PORT` | 8190 | Chainlit Chat UI | +| `RAY_DASHBOARD_PORT` | 8265 | Ray Dashboard | +| `INDEXERUI_PORT` | 3042 | Indexer UI (SvelteKit) | + +Services internes OpenRAG (pas de mapping host) : +rdb (5432), milvus (19530), etcd (2379), minio (9000/9001), reranker (7997). + +### 0.2 — Remplacement VLLM par API Scaleway + +Au lieu de deployer vllm-cpu/vllm-gpu localement, utiliser les APIs +Scaleway Generative (compatibles OpenAI). + +```bash +# .env OpenRAG — section Embedder via Scaleway +EMBEDDER_BASE_URL=https://api.scaleway.ai//v1 +EMBEDDER_API_KEY= +EMBEDDER_MODEL_NAME=bge-multilingual-gemma2 + +# .env OpenRAG — section LLM via Scaleway +BASE_URL=https://api.scaleway.ai//v1 +API_KEY= +MODEL=mistral-small-3.2-24b-instruct-2506 +# Alternatives : gpt-oss-120b, llama-3.3-70b-instruct, qwen3-235b-a22b-instruct-2507 + +# VLM (si besoin de captioning d'images) +VLM_BASE_URL=https://api.scaleway.ai//v1 +VLM_API_KEY= +VLM_MODEL=pixtral-12b-2409 + +# Desactiver le reranker (pas dispo sur Scaleway) +RERANKER_ENABLED=false +``` + +Note : Scaleway a des limites de rate sur les embeddings (erreur 429). +Si probleme, reduire `RETRIEVER_TOP_K` ou utiliser un embedder local. + +Impact docker-compose : supprimer les `depends_on` vers `vllm-gpu`/`vllm-cpu` +et `reranker`/`reranker-cpu`. Ne plus builder ces services. 
+
+### 0.3 — OAuth2 Proxy architecture (Option 1)
+
+```
+[Browser]
+    |
+    v
+[oauth2-proxy :4180] <--> [Keycloak :8082/realms/openrag]
+    |
+    +---> [OpenRAG API :8180] (AUTH_MODE=oidc)
+    +---> [Indexer UI :3042] (passes the JWT in Authorization)
+```
+
+The OpenRAG backend already ships a complete `AUTH_MODE=oidc` mode:
+- JWT validation via JWKS (`openrag/auth/oidc.py`)
+- User auto-provisioning (`OIDC_AUTO_PROVISION=true`)
+- Keycloak group -> PartitionMembership sync
+- Two sync modes: `additive` and `authoritative`
+- Unit tests: `openrag/auth/test_oidc.py`, `openrag/auth/test_group_sync.py`
+
+Backend configuration:
+```bash
+AUTH_MODE=oidc
+OIDC_ISSUER_URL=http://keycloak:8080/realms/openrag
+OIDC_AUDIENCE=openrag
+OIDC_AUTO_PROVISION=true
+OIDC_GROUP_CLAIM=groups
+OIDC_GROUP_PREFIX_VIEWER=rag-query/
+OIDC_GROUP_PREFIX_EDITOR=rag-edit/
+OIDC_GROUP_PREFIX_OWNER=rag-admin/
+OIDC_GROUP_SYNC_MODE=additive
+```
+
+### 0.4 — Test inventory
+
+| Category         | Command                                                   | Files   |
+|------------------|-----------------------------------------------------------|---------|
+| Unit tests       | `uv run pytest openrag/`                                  | 16      |
+| Auth/OIDC        | `uv run pytest openrag/auth/`                             | 2       |
+| API tests (mock) | `tests/api_tests/api_run/scripts/run_api_tests_local.sh`  | 10      |
+| Robot Framework  | `robot tests/api/`                                        | 10      |
+| Smoke tests      | `tests/smoke_test_data/run_smoke_test.sh`                 | scripts |
+| Linting          | `uv run ruff check openrag/ tests/`                       | -       |
+
+### 0.5 — User Management API (existing backend)
+
+| Method | Path                                | Auth  | Description              |
+|--------|-------------------------------------|-------|--------------------------|
+| GET    | `/users/`                           | Admin | List all users           |
+| GET    | `/users/info`                       | Any   | Current user info        |
+| POST   | `/users/`                           | Admin | Create a user            |
+| DELETE | `/users/{user_id}`                  | Admin | Delete a user            |
+| POST   | `/users/{user_id}/regenerate_token` | Admin | Regenerate the API token |
+| PATCH  | `/users/{user_id}/quota`            | Admin | Update the file quota    |
+| GET    | `/partition/{p}/users`              | Owner | List members             |
+| POST   | `/partition/{p}/users`              | Owner | Add a member             |
+| DELETE | `/partition/{p}/users/{uid}`        | Owner | Remove a member          |
+| PATCH  | `/partition/{p}/users/{uid}`        | Owner | Change the role          |
+
+Token format: `or-{32 hex chars}` (stored as SHA-256 in the database).
+
+### 0.6 — Matching Keycloak accounts <-> OpenRAG accounts
+
+| Keycloak field       | OpenRAG field                           |
+|----------------------|-----------------------------------------|
+| `sub` (JWT claim)    | `users.external_user_id`                |
+| `preferred_username` | `users.display_name`                    |
+| JWT groups           | `partition_memberships` (source="oidc") |
+
+With `AUTH_MODE=oidc` and `OIDC_AUTO_PROVISION=true`, the OpenRAG account is
+created automatically on the first valid JWT request. Keycloak groups of the
+form `rag-admin/`, `rag-edit/`, `rag-query/` are synced automatically to
+PartitionMembership rows.
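+
+To make the prefix convention above concrete, here is an illustrative sketch of
+how a full Keycloak group path maps to an OpenRAG role and partition. The real
+logic lives in `openrag/auth/oidc.py`; the prefixes below are the defaults from
+the `OIDC_GROUP_PREFIX_*` variables, and the helper name is hypothetical:
+
+```python
+# Default prefixes, mirroring OIDC_GROUP_PREFIX_OWNER / _EDITOR / _VIEWER.
+PREFIX_TO_ROLE = {
+    "rag-admin/": "owner",
+    "rag-edit/": "editor",
+    "rag-query/": "viewer",
+}
+
+
+def parse_group(group_path: str) -> tuple[str, str] | None:
+    """Map a full group path like '/rag-query/sales' to (role, partition)."""
+    name = group_path.lstrip("/")
+    for prefix, role in PREFIX_TO_ROLE.items():
+        if name.startswith(prefix):
+            partition = name[len(prefix):]
+            return (role, partition) if partition else None
+    return None  # group unrelated to OpenRAG, ignored
+
+
+assert parse_group("/rag-query/sales") == ("viewer", "sales")
+assert parse_group("/rag-admin/support") == ("owner", "support")
+assert parse_group("/employees") is None
+```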
+
+---
+
+## Prompt 1 / 5 — Keycloak OIDC authentication
+
+```markdown
+# Context
+
+You are working on OpenRAG, a RAG framework (FastAPI + Ray + Milvus) whose code lives in this repo.
+The current auth uses static SHA-256 tokens stored in the database (see `openrag/api.py` class `AuthMiddleware` and `openrag/components/indexer/vectordb/utils.py` class `PartitionFileManager`).
+
+I want to add an OIDC/JWT authentication mode compatible with Keycloak, **alongside** the existing token mode.
+
+# Goal
+
+Implement three workstreams:
+1. An OIDC middleware in OpenRAG
+2. Automatic synchronization of Keycloak groups to PartitionMembership
+3. Documentation of the Open WebUI configuration for forwarding tokens
+
+# STATUS: ALREADY IMPLEMENTED
+
+This prompt has been completed. The following files exist:
+- `openrag/auth/oidc.py` — JWT validation + group parsing + membership sync
+- `openrag/auth/test_oidc.py` — unit tests for JWT validation
+- `openrag/auth/test_group_sync.py` — tests for additive/authoritative sync
+- `openrag/api.py` — AuthMiddleware supports AUTH_MODE=oidc
+- `.env.example` — OIDC_* variables documented
+- Alembic migration for `external_user_id` and `source` on PartitionMembership
+```
+
+---
+
+## Prompt 2 / 5 — Indexing profiles + Q&A base (evaluation and overrides)
+
+```markdown
+# Context
+
+You are working on OpenRAG (FastAPI + Ray + Milvus). Keycloak OIDC authentication has been implemented (previous prompt).
+
+Today the chunking/retrieval configuration is **global**, via Hydra YAML (`.hydra_config/`). The chunker factory lives in `openrag/components/indexer/chunker/chunker.py` (`ChunkerFactory.create_chunker`) and is invoked by `Indexer.chunk()` in `openrag/components/indexer/indexer.py`.
+
+I want:
+1. **Indexing profiles** editable by the admin and assignable per partition
+2. A **Q&A base** to evaluate the RAG pipeline and override selected answers
+
+# Detailed specifications
+
+## 2.1 — Per-partition indexing profiles
+
+### New tables (Alembic migration)
+
+```sql
+indexing_profiles (
+  id SERIAL PRIMARY KEY,
+  name VARCHAR(100) UNIQUE NOT NULL,
+  description TEXT,
+  chunker_name VARCHAR(50) DEFAULT 'recursive_splitter',
+  chunk_size INT DEFAULT 512,
+  chunk_overlap_rate FLOAT DEFAULT 0.2,
+  contextual_retrieval BOOLEAN DEFAULT true,
+  contextualization_timeout INT DEFAULT 120,
+  max_concurrent_contextualization INT DEFAULT 10,
+  retriever_type VARCHAR(50) DEFAULT 'single',
+  retriever_top_k INT DEFAULT 50,
+  similarity_threshold FLOAT DEFAULT 0.6,
+  extra_params JSONB DEFAULT '{}',
+  created_at TIMESTAMP DEFAULT now(),
+  updated_at TIMESTAMP DEFAULT now()
+)
+
+partition_indexing_config (
+  partition_name VARCHAR PRIMARY KEY REFERENCES partitions(partition_name) ON DELETE CASCADE,
+  indexing_profile_id INT NOT NULL REFERENCES indexing_profiles(id),
+  overrides JSONB DEFAULT '{}'
+)
+```
+
+### SQLAlchemy models
+
+Add the `IndexingProfile` and `PartitionIndexingConfig` models to `openrag/components/indexer/vectordb/utils.py` (next to the existing models).
+
+### Endpoints (new router `openrag/routers/admin.py`)
+
+| Endpoint | Method | Auth | Description |
+|----------|--------|------|-------------|
+| `GET /admin/indexing-profiles` | GET | admin | List all profiles |
+| `POST /admin/indexing-profiles` | POST | admin | Create a profile |
+| `GET /admin/indexing-profiles/{id}` | GET | admin | Profile details |
+| `PUT /admin/indexing-profiles/{id}` | PUT | admin | Update a profile |
+| `DELETE /admin/indexing-profiles/{id}` | DELETE | admin | Delete (error if used by a partition) |
+| `GET /admin/partitions/{name}/indexing` | GET | owner | Indexing config of the partition |
+| `PUT /admin/partitions/{name}/indexing` | PUT | owner | Assign profile + overrides |
+
+### Impact on the indexing pipeline
+
+1. Change `ChunkerFactory.create_chunker(config, profile=None)`: if `profile` is provided, it overrides the matching Hydra values (a sketch of this merge follows after this prompt block)
+2. Change `Indexer.chunk()`: before chunking, load the partition's profile via `vectordb.get_partition_indexing_config.remote(partition_name)`. If no profile is assigned, use the global Hydra config (current behaviour).
+3. Add a `get_partition_indexing_config(partition_name)` method on the `Vectordb`/`MilvusDB` Ray actor
+4. At **boot**, create a `"default"` profile initialized from the current Hydra values (idempotent, do not overwrite it if it already exists)
+
+### Re-indexing
+
+Do NOT re-index automatically when the profile changes. Add an endpoint:
+`POST /admin/partitions/{name}/reindex` (auth: owner) — creates an async task that deletes the existing chunks and re-indexes every file of the partition with the new profile.
+
+## 2.2 — Q&A base
+
+(see the original prompt 2 content — unchanged)
+```
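+
+For the profile merge described in section 2.1 above ("Impact on the indexing
+pipeline", item 1), one possible shape is sketched below. The field names come
+from the `indexing_profiles` table; `create_chunker` itself and the Hydra config
+object are only referenced by name in the prompt, so the merge is shown as a
+standalone helper rather than as the final factory code:
+
+```python
+from typing import Any
+
+# Profile columns that may override the global Hydra chunking settings.
+PROFILE_FIELDS = (
+    "chunker_name",
+    "chunk_size",
+    "chunk_overlap_rate",
+    "contextual_retrieval",
+    "contextualization_timeout",
+    "max_concurrent_contextualization",
+)
+
+
+def merge_chunking_config(
+    hydra_defaults: dict[str, Any],
+    profile: dict[str, Any] | None = None,
+    overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Precedence: Hydra defaults < indexing profile < per-partition overrides."""
+    merged = dict(hydra_defaults)
+    if profile:
+        merged.update({k: profile[k] for k in PROFILE_FIELDS if k in profile})
+    if overrides:
+        merged.update(overrides)  # partition_indexing_config.overrides (JSON)
+    return merged
+
+
+# Example: a profile lowers chunk_size, a partition override raises it again.
+base = {"chunker_name": "recursive_splitter", "chunk_size": 512, "chunk_overlap_rate": 0.2}
+profile = {"chunker_name": "recursive_splitter", "chunk_size": 256}
+print(merge_chunking_config(base, profile, {"chunk_size": 1024}))
+```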
+
+---
+
+## Prompt 3 / 5 — Drive connector (Suite Numerique)
+
+(unchanged — see the original version)
+
+---
+
+## Prompt 4 / 5 — Feedback loop (Open WebUI -> OpenRAG -> Q&A)
+
+(unchanged — see the original version)
+
+---
+
+## Prompt 5 / 5 — Announcements, polls and notification channels
+
+(unchanged — see the original version)
+
+---
+
+## Prompt 6 (new) — Keycloak -> OpenRAG sync script (no code changes)
+
+```markdown
+# Context
+
+OpenRAG already provides:
+- `AUTH_MODE=oidc` with auto-provisioning (creates the user on their first JWT request)
+- An admin REST API to manage users and partitions
+
+The problem: auto-provisioning only fires on a user's first request. We want to
+be able to pre-provision accounts BEFORE the user ever logs in, for example to:
+- Pre-create partitions and assign rights
+- Synchronize a batch of Keycloak users
+- Keep an up-to-date directory
+
+# Solution: an external sync script (zero code changes)
+
+## How it works
+
+A standalone Python script that:
+1. Queries the Keycloak admin API to list users and groups
+2. For each Keycloak user, calls the OpenRAG admin API to:
+   a. Create the account if it does not exist
+   b. Create the required partitions
+   c. Assign memberships according to the Keycloak groups
+
+(A standalone sketch of this script is included at the end of this document.)
+
+## Prerequisites
+
+- An OpenRAG admin token (the AUTH_TOKEN from .env, or an admin user)
+- A Keycloak service account with the `realm-management/view-users` role
+- Network access to both APIs
+
+## Keycloak APIs used
+
+- `GET /admin/realms/{realm}/users` — list users
+- `GET /admin/realms/{realm}/users/{id}/groups` — a user's groups
+- `POST /realms/{realm}/protocol/openid-connect/token` — obtain an admin token
+
+## OpenRAG APIs used
+
+- `GET /users/` — list existing users
+- `POST /users/` — create a user
+- `POST /partition/{name}` — create a partition
+- `POST /partition/{name}/users` — add a member with a role
+- `PATCH /partition/{name}/users/{uid}` — change the role
+
+## Sync logic
+
+1. Obtain a Keycloak admin token (client_credentials)
+2. List all Keycloak users
+3. For each user:
+   a. Fetch their groups
+   b. Parse the groups with the same format as the OIDC backend:
+      - `rag-admin/` -> owner
+      - `rag-edit/` -> editor
+      - `rag-query/` -> viewer
+   c. Check whether the user already exists in OpenRAG (by display_name,
+      or external_user_id if exposed by the API)
+   d. If not: create it via POST /users/
+   e. For each partition:
+      - Create the partition if it does not exist
+      - Add the membership with the right role
+
+## Execution
+
+- One-shot: `python sync_keycloak_openrag.py`
+- Cron: `*/30 * * * * python sync_keycloak_openrag.py`
+- Docker: a sidecar container running a cron
+
+## Environment variables
+
+| Variable | Description |
+|----------|-------------|
+| `KEYCLOAK_URL` | Keycloak URL (e.g. http://localhost:8082) |
+| `KEYCLOAK_REALM` | Realm name |
+| `KEYCLOAK_CLIENT_ID` | Client ID of the service account |
+| `KEYCLOAK_CLIENT_SECRET` | Client secret |
+| `OPENRAG_URL` | OpenRAG URL (e.g. http://localhost:8180) |
+| `OPENRAG_ADMIN_TOKEN` | OpenRAG admin token |
+| `DRY_RUN` | true/false — simulation mode |
+```
+
+---
+
+## Usage notes
+
+### Execution order
+
+The prompts must be executed **in order** (1->2->3->4->5->6).
+Prompt 1 is already done. Prompt 6 is independent.
+
+### Before each prompt
+
+> First read the `CLAUDE.md` file at the repo root to understand the
+> architecture, conventions and commands. Then read the files listed in the
+> "Files to modify" section before starting.
+
+### After each prompt
+
+```bash
+uv run ruff check openrag/ tests/ integrations/
+uv run ruff format openrag/ tests/ integrations/
+uv run pytest
+```
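+
+---
+
+## Appendix: sketch of the prompt 6 sync script
+
+A minimal sketch of the Keycloak -> OpenRAG sync described in prompt 6. It only
+uses the Keycloak endpoints listed above; the OpenRAG calls are left as printed
+actions (dry-run style) because the exact request bodies of `POST /users/` and
+`POST /partition/{name}/users` are not specified in this document. The `requests`
+library and the environment variables from the prompt 6 table are assumed:
+
+```python
+import os
+
+import requests
+
+KC = os.environ["KEYCLOAK_URL"]
+REALM = os.environ["KEYCLOAK_REALM"]
+ROLE_BY_PREFIX = {"rag-admin/": "owner", "rag-edit/": "editor", "rag-query/": "viewer"}
+
+
+def keycloak_admin_token() -> str:
+    """client_credentials flow with the sync service account."""
+    resp = requests.post(
+        f"{KC}/realms/{REALM}/protocol/openid-connect/token",
+        data={
+            "grant_type": "client_credentials",
+            "client_id": os.environ["KEYCLOAK_CLIENT_ID"],
+            "client_secret": os.environ["KEYCLOAK_CLIENT_SECRET"],
+        },
+        timeout=10,
+    )
+    resp.raise_for_status()
+    return resp.json()["access_token"]
+
+
+def main() -> None:
+    headers = {"Authorization": f"Bearer {keycloak_admin_token()}"}
+    users = requests.get(f"{KC}/admin/realms/{REALM}/users", headers=headers, timeout=10).json()
+    for user in users:
+        groups = requests.get(
+            f"{KC}/admin/realms/{REALM}/users/{user['id']}/groups", headers=headers, timeout=10
+        ).json()
+        for group in groups:
+            path = group["path"].lstrip("/")
+            for prefix, role in ROLE_BY_PREFIX.items():
+                if path.startswith(prefix):
+                    partition = path[len(prefix):]
+                    # A real implementation would ensure the user, the partition
+                    # and the membership exist via the OpenRAG endpoints listed
+                    # in prompt 6 (respecting DRY_RUN).
+                    print(f"{user['username']} -> {role} on partition '{partition}'")
+
+
+if __name__ == "__main__":
+    main()
+```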