Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Subclass `BenchmarkService` and implement its abstract methods. On instantiation
- `validate_task_ids(task_ids, dataset)` — raise `ValueError` if any ID is not in the dataset
- `list_tasks(dataset)` — return `list[V1Task]` (id, question, timeout) for the lab-facing `GET /v1/datasets/{dataset}/tasks` endpoint. Must be overridden before exposing task listing; the base implementation fails closed to avoid leaking evaluator-only data.
- `check_auth(headers)` — legacy boolean auth hook. Override for custom auth that does not need tenant or dataset awareness.
- `resolve_tenant(headers)` — validate request authorization and return a tenant ID, `"_legacy"` for legacy auth, or `None` to reject.
- `resolve_tenant(headers)` — validate request authorization and return an `AuthResult`. Use `AuthResult(tenant=...)` on success or `AuthResult(failure=...)` to reject with a specific reason.
- `check_dataset_access(tenant, dataset)` — return whether a resolved tenant may access a dataset.
- `get_service_version()` — optional benchmark-owned service version override. If it returns `None`, `/version` falls back to the installed benchmark package version.
- `get_dataset_version(dataset)` — optional dataset release/version hook. The value is returned on the dataset task-list response after auth and dataset access checks.
Expand Down Expand Up @@ -253,17 +253,17 @@ class MyBenchmarkService(BenchmarkService):
# ... other abstract methods
```

For tenant-aware custom authentication, override `resolve_tenant()` directly and return a tenant ID. Override `check_dataset_access()` if your service needs dataset rules that differ from the configured allowlist:
For tenant-aware custom authentication, override `resolve_tenant()` directly and return an `AuthResult`. Override `check_dataset_access()` if your service needs dataset rules that differ from the configured allowlist:

```python
from benchmark_service import BenchmarkService
from benchmark_service import AuthFailure, AuthResult, BenchmarkService

class MyBenchmarkService(BenchmarkService):
async def resolve_tenant(self, headers: dict[str, str]) -> str | None:
async def resolve_tenant(self, headers: dict[str, str]) -> AuthResult:
token = headers.get("authorization")
if token == "Bearer internal-token":
return "internal"
return None
return AuthResult(tenant="internal")
return AuthResult(failure=AuthFailure.INVALID_KEY)

async def check_dataset_access(self, tenant: str, dataset: str | None) -> bool:
return tenant == "internal" and (dataset or "default") in {"default", "validation"}
Expand Down
3 changes: 3 additions & 0 deletions src/benchmark_service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from benchmark_service._version import __version__
from benchmark_service.app import BenchmarkServiceApp
from benchmark_service.auth import AuthFailure, AuthResult
from benchmark_service.base import BenchmarkService
from benchmark_service.client import BenchmarkServiceClient, BenchmarkServiceError, BenchmarkServiceUnauthenticatedError
from benchmark_service.inflight import InflightMiddleware
Expand All @@ -26,6 +27,8 @@
)

__all__ = [
"AuthFailure",
"AuthResult",
"BenchmarkServiceApp",
"BenchmarkService",
"BenchmarkServiceUnauthenticatedError",
Expand Down
65 changes: 50 additions & 15 deletions src/benchmark_service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
from websockets.exceptions import ConnectionClosed

from benchmark_service._version import __version__ as _framework_version
from benchmark_service.auth import LEGACY_TENANT_SENTINEL, get_auth_settings, get_tenant_config, load_allowlist
from benchmark_service.auth import (
AuthFailure,
AuthResult,
LEGACY_TENANT_SENTINEL,
get_auth_settings,
get_tenant_config,
load_allowlist,
)
from benchmark_service.base import BenchmarkService
from benchmark_service.inflight import InflightMiddleware
from benchmark_service.schemas import (
Expand Down Expand Up @@ -53,6 +60,22 @@

logger = logging.getLogger(__name__)

# Maps each AuthFailure reason to the (HTTP status code, detail message) pair sent to the caller.
# Credential problems map to 401; NOT_ALLOWLISTED is the only 403 because the caller is
# authenticated but not permitted to use this service.
# The detail strings are also passed as the WebSocket close reason by _authorize_websocket.
# NOTE(review): WebSocket close reasons are limited to 123 bytes of UTF-8 (RFC 6455 control-frame
# limit) — keep new messages under that; confirm when adding entries.
_AUTH_FAILURE_RESPONSES: dict[AuthFailure, tuple[int, str]] = {
AuthFailure.NO_KEY: (401, "Missing x-descope-api-key header"),
AuthFailure.INVALID_KEY: (401, "Invalid or expired access key"),
AuthFailure.MULTI_TENANT: (401, "Access key must be scoped to exactly one tenant"),
AuthFailure.LEGACY_TENANT: (401, "Legacy bearer auth is not accepted on this endpoint"),
AuthFailure.NOT_ALLOWLISTED: (
403,
"Your tenant is authenticated but not allowlisted for this service; contact the service operator to request access",
),
AuthFailure.REJECTED: (401, "Unauthorized"),
}


def _failure_response(result: AuthResult) -> tuple[int, str]:
    """Map a failed AuthResult to an ``(HTTP status, detail message)`` pair.

    Args:
        result: The authentication outcome; expected to carry a ``failure`` reason.

    Returns:
        The status code and detail registered for ``result.failure`` in
        ``_AUTH_FAILURE_RESPONSES``. If the result has no failure reason, or the
        reason has no registered response, a generic ``(401, "Unauthorized")`` is
        returned so that a malformed result is rejected cleanly instead of
        surfacing as a ``KeyError`` (and therefore a 500).
    """
    fallback: tuple[int, str] = (401, "Unauthorized")
    if result.failure is None:
        # Callers only reach here for rejected requests; fail closed with a generic 401.
        return fallback
    return _AUTH_FAILURE_RESPONSES.get(result.failure, fallback)


async def send_json_if_connected(websocket: WebSocket, payload: dict[str, Any]) -> bool:
try:
Expand Down Expand Up @@ -144,7 +167,13 @@ def __init__(self, service_cls: type[BenchmarkService]) -> None:

@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
if get_auth_settings().auth_required:
settings = get_auth_settings()
if settings.auth_required:
if not settings.descope_project_id:
raise RuntimeError(
"AUTH_REQUIRED is true but DESCOPE_PROJECT_ID is not set; "
"configure DESCOPE_PROJECT_ID before starting the server"
)
load_allowlist()
self.service = await service_cls.create()
yield
Expand All @@ -159,11 +188,12 @@ def _register_routes(self) -> None:
async def _check_auth(request: Request, call_next): # type: ignore[no-untyped-def]
if request.url.path in _PUBLIC_PATHS:
return await call_next(request) # type: ignore[reportUnknownVariableType]
tenant = await self.service.resolve_tenant(dict(request.headers))
if tenant is None:
return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
request.state.tenant = tenant
if _is_trial_tenant(tenant) and not _trial_tenant_may_access_path(request.url.path):
result = await self.service.resolve_tenant(dict(request.headers))
if not result.ok:
status, detail = _failure_response(result)
return JSONResponse(status_code=status, content={"detail": detail})
request.state.tenant = result.tenant
if _is_trial_tenant(result.tenant) and not _trial_tenant_may_access_path(request.url.path):
return JSONResponse(
status_code=403,
content={"detail": "Trial tenants may only access approved /v1 endpoints (/v1/*)"},
Expand Down Expand Up @@ -193,10 +223,14 @@ async def _check_auth(request: Request, call_next): # type: ignore[no-untyped-d
async def _value_error_handler(self, _request: Request, exc: Exception) -> Response:
raise HTTPException(status_code=400, detail=str(exc)) from exc

async def _exception_handler(self, _request: Request, exc: Exception) -> Response:
async def _exception_handler(self, request: Request, exc: Exception) -> Response:
logger.error(f"Error: {exc}")
logger.error(traceback.format_exc())
return JSONResponse(status_code=500, content={"detail": "Internal server error"})
content: dict[str, Any] = {"detail": "Evaluation failed"}
tenant = getattr(request.state, "tenant", None)
if tenant is not None and not _is_trial_tenant(tenant):
content["errors"] = [str(exc)]
return JSONResponse(status_code=500, content=content)

async def _health_check(self) -> HealthCheckResponse:
return HealthCheckResponse(status="ok")
Expand All @@ -214,15 +248,16 @@ async def _version(self) -> VersionResponse:

async def _authorize_websocket(self, websocket: WebSocket) -> str | None:
"""Authenticate a WebSocket caller. Returns tenant id, or None after closing 1008."""
tenant = await self.service.resolve_tenant(dict(websocket.headers))
if tenant is None:
await websocket.close(code=1008, reason="Unauthorized")
result = await self.service.resolve_tenant(dict(websocket.headers))
if not result.ok:
_status, detail = _failure_response(result)
await websocket.close(code=1008, reason=detail)
return None
if _is_trial_tenant(tenant):
if _is_trial_tenant(result.tenant):
await websocket.close(code=1008, reason="Trial tenants may only access /v1/*")
return None
websocket.state.tenant = tenant
return tenant
websocket.state.tenant = result.tenant
return result.tenant

async def _verify_task_ids(
self,
Expand Down
51 changes: 36 additions & 15 deletions src/benchmark_service/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
from collections.abc import Mapping
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import Any
Expand All @@ -26,6 +27,25 @@
LEGACY_TENANT_SENTINEL = "_legacy"


class AuthFailure(Enum):
    """Reasons an authentication attempt can be rejected."""

    # No access key header was present on the request.
    NO_KEY = "no_key"
    # Exchanging the access key failed (any error during the exchange).
    INVALID_KEY = "invalid_key"
    # The key did not resolve to exactly one tenant (this covers zero tenants as well as several).
    MULTI_TENANT = "multi_tenant"
    # The key resolved to the tenant id reserved for legacy auth compatibility.
    LEGACY_TENANT = "legacy_tenant"
    # The tenant authenticated successfully but is not in the service allowlist.
    NOT_ALLOWLISTED = "not_allowlisted"
    # Generic rejection without a more specific reason (used by the legacy auth paths).
    REJECTED = "rejected"


@dataclass(frozen=True)
class AuthResult:
    """Outcome of an authentication attempt.

    Build with ``AuthResult(tenant=...)`` on success or
    ``AuthResult(failure=...)`` to reject with a specific reason.
    """

    # Resolved tenant id on success; None when the caller was rejected.
    tenant: str | None = None
    # Reason for rejection; None when the caller was not rejected.
    failure: AuthFailure | None = None

    @property
    def ok(self) -> bool:
        """Return True only for a successful result that actually carries a tenant.

        A result with neither ``tenant`` nor ``failure`` set is treated as not ok,
        so a malformed result can never be mistaken for an authenticated caller
        whose tenant is ``None``.
        """
        return self.failure is None and self.tenant is not None


class TenantConfig(BaseModel):
"""Per-tenant access rules within a benchmark service."""

Expand Down Expand Up @@ -189,55 +209,56 @@ def _check_legacy_benchmark_api_key(headers: Mapping[str, str], settings: AuthSe
return hmac.compare_digest(authorization, expected)


async def resolve_descope_tenant(headers: Mapping[str, str]) -> str | None:
async def resolve_descope_tenant(headers: Mapping[str, str]) -> AuthResult:
"""Validate a Descope access key and resolve a single allowlisted tenant."""
settings = get_auth_settings()
if not settings.descope_project_id:
logger.warning("AUTH_REQUIRED is true but DESCOPE_PROJECT_ID is not configured")
return None

access_key = headers.get(DESCOPE_API_KEY_HEADER)
if not access_key:
return None
return AuthResult(failure=AuthFailure.NO_KEY)

cache_key = (settings.descope_project_id, access_key)
cached = _auth_cache.get(cache_key)
if cached is not None:
return cached
return AuthResult(tenant=cached)

try:
jwt_response = await _exchange_descope_access_key(settings.descope_project_id, access_key)
except Exception:
logger.warning("Failed to exchange Descope access key", exc_info=True)
return None
return AuthResult(failure=AuthFailure.INVALID_KEY)

tenants = list(jwt_response.get("tenants", {}).keys())
if len(tenants) != 1:
logger.warning("Descope access key must be scoped to exactly one tenant, got %s", len(tenants))
return None
return AuthResult(failure=AuthFailure.MULTI_TENANT)

tenant = tenants[0]
if tenant == LEGACY_TENANT_SENTINEL:
logger.info("Descope tenant %s is reserved for legacy auth compatibility", tenant)
return None
return AuthResult(failure=AuthFailure.LEGACY_TENANT)

allowlist = load_allowlist()
if tenant not in allowlist.tenants:
logger.info("Descope tenant %s is not in the service allowlist", tenant)
return None
return AuthResult(failure=AuthFailure.NOT_ALLOWLISTED)

_auth_cache[cache_key] = tenant
return tenant
return AuthResult(tenant=tenant)


async def resolve_caller_tenant(headers: Mapping[str, str]) -> str | None:
"""Return the caller tenant id, "_legacy" sentinel, or None to reject."""
async def resolve_caller_tenant(headers: Mapping[str, str]) -> AuthResult:
"""Return the caller AuthResult: tenant id, "_legacy" sentinel, or a failure reason."""
settings = get_auth_settings()
if settings.auth_required:
return await resolve_descope_tenant(headers)
return LEGACY_TENANT_SENTINEL if _check_legacy_benchmark_api_key(headers, settings) else None
return (
AuthResult(tenant=LEGACY_TENANT_SENTINEL)
if _check_legacy_benchmark_api_key(headers, settings)
else AuthResult(failure=AuthFailure.REJECTED)
)


async def check_benchmark_service_auth(headers: Mapping[str, str]) -> bool:
"""Validate benchmark-service auth headers using the configured auth mode."""
return await resolve_caller_tenant(headers) is not None
return (await resolve_caller_tenant(headers)).ok
12 changes: 7 additions & 5 deletions src/benchmark_service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import Any, Self

from benchmark_service.auth import (
AuthFailure,
AuthResult,
LEGACY_TENANT_SENTINEL,
check_benchmark_service_auth,
load_allowlist,
Expand Down Expand Up @@ -72,16 +74,16 @@ async def check_auth(self, headers: dict[str, str]) -> bool:
"""
return await check_benchmark_service_auth(headers)

async def resolve_tenant(self, headers: dict[str, str]) -> str | None:
"""Authenticate the caller and return their tenant id, or None to reject.
async def resolve_tenant(self, headers: dict[str, str]) -> AuthResult:
"""Authenticate the caller and return an AuthResult.

Subclasses with a legacy `check_auth` override keep their existing boolean
behavior. A successful legacy check returns the "_legacy" sentinel, which
skips dataset-level allowlist enforcement.
behavior: True maps to the "_legacy" sentinel (skips allowlist enforcement);
False maps to the generic AuthFailure.REJECTED failure.
"""
if type(self).check_auth is not BenchmarkService.check_auth:
ok = await self.check_auth(headers)
return LEGACY_TENANT_SENTINEL if ok else None
return AuthResult(tenant=LEGACY_TENANT_SENTINEL) if ok else AuthResult(failure=AuthFailure.REJECTED)
return await resolve_caller_tenant(headers)

async def check_dataset_access(self, tenant: str, dataset: str | None) -> bool:
Expand Down
Loading
Loading