diff --git a/README.md b/README.md index e821611..0081507 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ Subclass `BenchmarkService` and implement its abstract methods. On instantiation - `validate_task_ids(task_ids, dataset)` — raise `ValueError` if any ID is not in the dataset - `list_tasks(dataset)` — return `list[V1Task]` (id, question, timeout) for the lab-facing `GET /v1/datasets/{dataset}/tasks` endpoint. Must be overridden before exposing task listing; the base implementation fails closed to avoid leaking evaluator-only data. - `check_auth(headers)` — legacy boolean auth hook. Override for custom auth that does not need tenant or dataset awareness. -- `resolve_tenant(headers)` — validate request authorization and return a tenant ID, `"_legacy"` for legacy auth, or `None` to reject. +- `resolve_tenant(headers)` — validate request authorization and return an `AuthResult`. Use `AuthResult(tenant=...)` on success or `AuthResult(failure=...)` to reject with a specific reason. - `check_dataset_access(tenant, dataset)` — return whether a resolved tenant may access a dataset. - `get_service_version()` — optional benchmark-owned service version override. If it returns `None`, `/version` falls back to the installed benchmark package version. - `get_dataset_version(dataset)` — optional dataset release/version hook. The value is returned on the dataset task-list response after auth and dataset access checks. @@ -253,17 +253,17 @@ class MyBenchmarkService(BenchmarkService): # ... other abstract methods ``` -For tenant-aware custom authentication, override `resolve_tenant()` directly and return a tenant ID. Override `check_dataset_access()` if your service needs dataset rules that differ from the configured allowlist: +For tenant-aware custom authentication, override `resolve_tenant()` directly and return an `AuthResult`. Override `check_dataset_access()` if your service needs dataset rules that differ from the configured allowlist: ```python -from benchmark_service import BenchmarkService +from benchmark_service import AuthFailure, AuthResult, BenchmarkService class MyBenchmarkService(BenchmarkService): - async def resolve_tenant(self, headers: dict[str, str]) -> str | None: + async def resolve_tenant(self, headers: dict[str, str]) -> AuthResult: token = headers.get("authorization") if token == "Bearer internal-token": - return "internal" - return None + return AuthResult(tenant="internal") + return AuthResult(failure=AuthFailure.INVALID_KEY) async def check_dataset_access(self, tenant: str, dataset: str | None) -> bool: return tenant == "internal" and (dataset or "default") in {"default", "validation"} diff --git a/src/benchmark_service/__init__.py b/src/benchmark_service/__init__.py index 1a4af7a..bbc1693 100644 --- a/src/benchmark_service/__init__.py +++ b/src/benchmark_service/__init__.py @@ -2,6 +2,7 @@ from benchmark_service._version import __version__ from benchmark_service.app import BenchmarkServiceApp +from benchmark_service.auth import AuthFailure, AuthResult from benchmark_service.base import BenchmarkService from benchmark_service.client import BenchmarkServiceClient, BenchmarkServiceError, BenchmarkServiceUnauthenticatedError from benchmark_service.inflight import InflightMiddleware @@ -26,6 +27,8 @@ ) __all__ = [ + "AuthFailure", + "AuthResult", "BenchmarkServiceApp", "BenchmarkService", "BenchmarkServiceUnauthenticatedError", diff --git a/src/benchmark_service/app.py b/src/benchmark_service/app.py index 84ea413..785b63a 100644 --- a/src/benchmark_service/app.py +++ b/src/benchmark_service/app.py @@ -17,7 +17,14 @@ from websockets.exceptions import ConnectionClosed from benchmark_service._version import __version__ as _framework_version -from benchmark_service.auth import LEGACY_TENANT_SENTINEL, get_auth_settings, get_tenant_config, load_allowlist +from benchmark_service.auth import ( + AuthFailure, + AuthResult, + LEGACY_TENANT_SENTINEL, + get_auth_settings, + get_tenant_config, + load_allowlist, +) from benchmark_service.base import BenchmarkService from benchmark_service.inflight import InflightMiddleware from benchmark_service.schemas import ( @@ -53,6 +60,22 @@ logger = logging.getLogger(__name__) +_AUTH_FAILURE_RESPONSES: dict[AuthFailure, tuple[int, str]] = { + AuthFailure.NO_KEY: (401, "Missing x-descope-api-key header"), + AuthFailure.INVALID_KEY: (401, "Invalid or expired access key"), + AuthFailure.MULTI_TENANT: (401, "Access key must be scoped to exactly one tenant"), + AuthFailure.LEGACY_TENANT: (401, "Legacy bearer auth is not accepted on this endpoint"), + AuthFailure.NOT_ALLOWLISTED: ( + 403, + "Your tenant is authenticated but not allowlisted for this service; contact the service operator to request access", + ), + AuthFailure.REJECTED: (401, "Unauthorized"), +} + + +def _failure_response(result: AuthResult) -> tuple[int, str]: + return _AUTH_FAILURE_RESPONSES[result.failure] # type: ignore[index] + async def send_json_if_connected(websocket: WebSocket, payload: dict[str, Any]) -> bool: try: @@ -144,7 +167,13 @@ def __init__(self, service_cls: type[BenchmarkService]) -> None: @asynccontextmanager async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]: - if get_auth_settings().auth_required: + settings = get_auth_settings() + if settings.auth_required: + if not settings.descope_project_id: + raise RuntimeError( + "AUTH_REQUIRED is true but DESCOPE_PROJECT_ID is not set; " + "configure DESCOPE_PROJECT_ID before starting the server" + ) load_allowlist() self.service = await service_cls.create() yield @@ -159,11 +188,12 @@ def _register_routes(self) -> None: async def _check_auth(request: Request, call_next): # type: ignore[no-untyped-def] if request.url.path in _PUBLIC_PATHS: return await call_next(request) # type: ignore[reportUnknownVariableType] - tenant = await self.service.resolve_tenant(dict(request.headers)) - if tenant is None: - return JSONResponse(status_code=401, content={"detail": "Unauthorized"}) - request.state.tenant = tenant - if _is_trial_tenant(tenant) and not _trial_tenant_may_access_path(request.url.path): + result = await self.service.resolve_tenant(dict(request.headers)) + if not result.ok: + status, detail = _failure_response(result) + return JSONResponse(status_code=status, content={"detail": detail}) + request.state.tenant = result.tenant + if _is_trial_tenant(result.tenant) and not _trial_tenant_may_access_path(request.url.path): return JSONResponse( status_code=403, content={"detail": "Trial tenants may only access approved /v1 endpoints (/v1/*)"}, @@ -193,10 +223,14 @@ async def _check_auth(request: Request, call_next): # type: ignore[no-untyped-d async def _value_error_handler(self, _request: Request, exc: Exception) -> Response: raise HTTPException(status_code=400, detail=str(exc)) from exc - async def _exception_handler(self, _request: Request, exc: Exception) -> Response: + async def _exception_handler(self, request: Request, exc: Exception) -> Response: logger.error(f"Error: {exc}") logger.error(traceback.format_exc()) - return JSONResponse(status_code=500, content={"detail": "Internal server error"}) + content: dict[str, Any] = {"detail": "Evaluation failed"} + tenant = getattr(request.state, "tenant", None) + if tenant is not None and not _is_trial_tenant(tenant): + content["errors"] = [str(exc)] + return JSONResponse(status_code=500, content=content) async def _health_check(self) -> HealthCheckResponse: return HealthCheckResponse(status="ok") @@ -214,15 +248,16 @@ async def _version(self) -> VersionResponse: async def _authorize_websocket(self, websocket: WebSocket) -> str | None: """Authenticate a WebSocket caller. Returns tenant id, or None after closing 1008.""" - tenant = await self.service.resolve_tenant(dict(websocket.headers)) - if tenant is None: - await websocket.close(code=1008, reason="Unauthorized") + result = await self.service.resolve_tenant(dict(websocket.headers)) + if not result.ok: + _status, detail = _failure_response(result) + await websocket.close(code=1008, reason=detail) return None - if _is_trial_tenant(tenant): + if _is_trial_tenant(result.tenant): await websocket.close(code=1008, reason="Trial tenants may only access /v1/*") return None - websocket.state.tenant = tenant - return tenant + websocket.state.tenant = result.tenant + return result.tenant async def _verify_task_ids( self, diff --git a/src/benchmark_service/auth.py b/src/benchmark_service/auth.py index 14f1913..ce2469d 100644 --- a/src/benchmark_service/auth.py +++ b/src/benchmark_service/auth.py @@ -9,6 +9,7 @@ import os from collections.abc import Mapping from dataclasses import dataclass +from enum import Enum from functools import lru_cache from pathlib import Path from typing import Any @@ -26,6 +27,25 @@ LEGACY_TENANT_SENTINEL = "_legacy" +class AuthFailure(Enum): + NO_KEY = "no_key" + INVALID_KEY = "invalid_key" + MULTI_TENANT = "multi_tenant" + LEGACY_TENANT = "legacy_tenant" + NOT_ALLOWLISTED = "not_allowlisted" + REJECTED = "rejected" + + +@dataclass(frozen=True) +class AuthResult: + tenant: str | None = None + failure: AuthFailure | None = None + + @property + def ok(self) -> bool: + return self.failure is None + + class TenantConfig(BaseModel): """Per-tenant access rules within a benchmark service.""" @@ -189,55 +209,56 @@ def _check_legacy_benchmark_api_key(headers: Mapping[str, str], settings: AuthSe return hmac.compare_digest(authorization, expected) -async def resolve_descope_tenant(headers: Mapping[str, str]) -> str | None: +async def resolve_descope_tenant(headers: Mapping[str, str]) -> AuthResult: """Validate a Descope access key and resolve a single allowlisted tenant.""" settings = get_auth_settings() - if not settings.descope_project_id: - logger.warning("AUTH_REQUIRED is true but DESCOPE_PROJECT_ID is not configured") - return None access_key = headers.get(DESCOPE_API_KEY_HEADER) if not access_key: - return None + return AuthResult(failure=AuthFailure.NO_KEY) cache_key = (settings.descope_project_id, access_key) cached = _auth_cache.get(cache_key) if cached is not None: - return cached + return AuthResult(tenant=cached) try: jwt_response = await _exchange_descope_access_key(settings.descope_project_id, access_key) except Exception: logger.warning("Failed to exchange Descope access key", exc_info=True) - return None + return AuthResult(failure=AuthFailure.INVALID_KEY) tenants = list(jwt_response.get("tenants", {}).keys()) if len(tenants) != 1: logger.warning("Descope access key must be scoped to exactly one tenant, got %s", len(tenants)) - return None + return AuthResult(failure=AuthFailure.MULTI_TENANT) tenant = tenants[0] if tenant == LEGACY_TENANT_SENTINEL: logger.info("Descope tenant %s is reserved for legacy auth compatibility", tenant) - return None + return AuthResult(failure=AuthFailure.LEGACY_TENANT) allowlist = load_allowlist() if tenant not in allowlist.tenants: logger.info("Descope tenant %s is not in the service allowlist", tenant) - return None + return AuthResult(failure=AuthFailure.NOT_ALLOWLISTED) _auth_cache[cache_key] = tenant - return tenant + return AuthResult(tenant=tenant) -async def resolve_caller_tenant(headers: Mapping[str, str]) -> str | None: - """Return the caller tenant id, "_legacy" sentinel, or None to reject.""" +async def resolve_caller_tenant(headers: Mapping[str, str]) -> AuthResult: + """Return the caller AuthResult: tenant id, "_legacy" sentinel, or a failure reason.""" settings = get_auth_settings() if settings.auth_required: return await resolve_descope_tenant(headers) - return LEGACY_TENANT_SENTINEL if _check_legacy_benchmark_api_key(headers, settings) else None + return ( + AuthResult(tenant=LEGACY_TENANT_SENTINEL) + if _check_legacy_benchmark_api_key(headers, settings) + else AuthResult(failure=AuthFailure.REJECTED) + ) async def check_benchmark_service_auth(headers: Mapping[str, str]) -> bool: """Validate benchmark-service auth headers using the configured auth mode.""" - return await resolve_caller_tenant(headers) is not None + return (await resolve_caller_tenant(headers)).ok diff --git a/src/benchmark_service/base.py b/src/benchmark_service/base.py index 296efea..3145933 100644 --- a/src/benchmark_service/base.py +++ b/src/benchmark_service/base.py @@ -10,6 +10,8 @@ from typing import Any, Self from benchmark_service.auth import ( + AuthFailure, + AuthResult, LEGACY_TENANT_SENTINEL, check_benchmark_service_auth, load_allowlist, @@ -72,16 +74,16 @@ async def check_auth(self, headers: dict[str, str]) -> bool: """ return await check_benchmark_service_auth(headers) - async def resolve_tenant(self, headers: dict[str, str]) -> str | None: - """Authenticate the caller and return their tenant id, or None to reject. + async def resolve_tenant(self, headers: dict[str, str]) -> AuthResult: + """Authenticate the caller and return an AuthResult. Subclasses with a legacy `check_auth` override keep their existing boolean - behavior. A successful legacy check returns the "_legacy" sentinel, which - skips dataset-level allowlist enforcement. + behavior: True maps to the "_legacy" sentinel (skips allowlist enforcement); + False maps to a generic failure with no specific AuthFailure reason. """ if type(self).check_auth is not BenchmarkService.check_auth: ok = await self.check_auth(headers) - return LEGACY_TENANT_SENTINEL if ok else None + return AuthResult(tenant=LEGACY_TENANT_SENTINEL) if ok else AuthResult(failure=AuthFailure.REJECTED) return await resolve_caller_tenant(headers) async def check_dataset_access(self, tenant: str, dataset: str | None) -> bool: diff --git a/tests/test_app.py b/tests/test_app.py index e4e2b04..2ca76ea 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -6,7 +6,7 @@ from unittest.mock import patch import pytest -from fastapi import WebSocket +from fastapi import Request, WebSocket from fastapi.testclient import TestClient from starlette.websockets import WebSocketDisconnect @@ -155,7 +155,34 @@ def test_evaluate_response_invalid_task() -> None: with TestClient(BenchmarkServiceApp(StubBenchmark), raise_server_exceptions=False) as c: response = c.post("/evaluate-response/", json={"task_id": "nonexistent", "response": "2"}) assert response.status_code == 500 - assert response.json() == {"detail": "Internal server error"} + body = response.json() + assert body["detail"] == "Evaluation failed" + assert body["errors"] and isinstance(body["errors"], list) + + +async def test_exception_handler_withholds_errors_when_tenant_unset() -> None: + """Fail-closed: request.state.tenant unset must not produce errors[]. + + When the exception handler fires and no tenant has been resolved (e.g. on a + public path or before auth runs), errors[] must be withheld — the caller is + unknown, so internal detail must not leak. + """ + from unittest.mock import MagicMock + + from starlette.datastructures import State + + app = BenchmarkServiceApp(StubBenchmark) + + request = MagicMock(spec=Request) + request.state = State() # no tenant attribute set + + exc = RuntimeError("internal detail that must not leak") + response = await app._exception_handler(request, exc) # type: ignore[attr-defined] + body = json.loads(response.body) # type: ignore[attr-defined] + + assert response.status_code == 500 + assert body["detail"] == "Evaluation failed" + assert "errors" not in body @pytest.mark.parametrize( @@ -518,3 +545,99 @@ def test_setup_task_ws_close_for_disallowed_dataset(auth_client: TestClient) -> ws.receive_json() assert exc_info.value.code == 1008 assert exc_info.value.reason == "Dataset not allowed" + + +class TestAuthFailureMessages: + """HTTP-level assertions for distinct auth-failure status codes and detail messages.""" + + PROJECT_ID = "descope-project" + + @pytest.fixture + def descope_client(self, monkeypatch: pytest.MonkeyPatch) -> Generator[TestClient, None, None]: + monkeypatch.setenv("AUTH_REQUIRED", "true") + monkeypatch.setenv("DESCOPE_PROJECT_ID", self.PROJECT_ID) + monkeypatch.setenv( + "DESCOPE_TENANT_ALLOWLIST_JSON", + json.dumps({"tenants": {"tenant-a": {"datasets": ["default"]}}}), + ) + monkeypatch.delenv("BENCHMARK_API_KEY", raising=False) + auth_module.clear_auth_cache() + auth_module.clear_allowlist_cache() + with TestClient(BenchmarkServiceApp(StubBenchmark)) as c: + yield c + auth_module.clear_auth_cache() + auth_module.clear_allowlist_cache() + + def _patch_exchange(self, tenants: list[str]) -> Any: + return patch.object( + auth_module, + "_exchange_descope_access_key", + return_value={"tenants": {t: {} for t in tenants}}, + ) + + def _patch_exchange_raises(self) -> Any: + return patch.object( + auth_module, + "_exchange_descope_access_key", + side_effect=RuntimeError("rejected"), + ) + + def test_missing_key_returns_401_with_message(self, descope_client: TestClient) -> None: + response = descope_client.get("/verify-task-ids") + assert response.status_code == 401 + assert response.json()["detail"] == "Missing x-descope-api-key header" + + def test_invalid_key_returns_401_with_message(self, descope_client: TestClient) -> None: + with self._patch_exchange_raises(): + response = descope_client.get("/verify-task-ids", headers={"x-descope-api-key": "bad"}) + assert response.status_code == 401 + assert response.json()["detail"] == "Invalid or expired access key" + + def test_multi_tenant_key_returns_401_with_message(self, descope_client: TestClient) -> None: + with self._patch_exchange(["tenant-a", "tenant-b"]): + response = descope_client.get("/verify-task-ids", headers={"x-descope-api-key": "multi"}) + assert response.status_code == 401 + assert response.json()["detail"] == "Access key must be scoped to exactly one tenant" + + def test_not_allowlisted_returns_403_with_message(self, descope_client: TestClient) -> None: + with self._patch_exchange(["unknown-org"]): + response = descope_client.get("/verify-task-ids", headers={"x-descope-api-key": "rogue"}) + assert response.status_code == 403 + detail = response.json()["detail"] + assert "allowlist" in detail.lower() + assert "service operator" in detail + + def test_valid_key_returns_200(self, descope_client: TestClient) -> None: + with self._patch_exchange(["tenant-a"]): + response = descope_client.get("/verify-task-ids", headers={"x-descope-api-key": "valid"}) + assert response.status_code == 200 + + +class TestDescopeStartupCheck: + """Server refuses to start when auth_required=true and DESCOPE_PROJECT_ID is missing.""" + + def test_missing_project_id_raises_at_startup(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("AUTH_REQUIRED", "true") + monkeypatch.delenv("DESCOPE_PROJECT_ID", raising=False) + auth_module.clear_allowlist_cache() + with pytest.raises(RuntimeError, match="DESCOPE_PROJECT_ID"): + with TestClient(BenchmarkServiceApp(StubBenchmark)): + pass + + def test_startup_succeeds_when_project_id_set(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("AUTH_REQUIRED", "true") + monkeypatch.setenv("DESCOPE_PROJECT_ID", "P_test") + monkeypatch.setenv( + "DESCOPE_TENANT_ALLOWLIST_JSON", + json.dumps({"tenants": {}}), + ) + auth_module.clear_allowlist_cache() + with TestClient(BenchmarkServiceApp(StubBenchmark)): + pass + + def test_startup_succeeds_when_auth_not_required(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("AUTH_REQUIRED", "false") + monkeypatch.delenv("DESCOPE_PROJECT_ID", raising=False) + auth_module.clear_allowlist_cache() + with TestClient(BenchmarkServiceApp(StubBenchmark)): + pass diff --git a/tests/test_auth.py b/tests/test_auth.py index 971ae70..f156c5b 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -12,6 +12,8 @@ from benchmark_service import Sandbox from benchmark_service import auth as auth_module from benchmark_service.auth import ( + AuthFailure, + AuthResult, clear_allowlist_cache, clear_auth_cache, resolve_caller_tenant, @@ -57,58 +59,51 @@ def _mock_jwt_response(tenants: list[str]) -> dict[str, Any]: return {"tenants": {t: {} for t in tenants}} +# Table-driven: every AuthFailure reason including REJECTED -> expected AuthResult +@pytest.mark.parametrize( + ("headers", "exchange_tenants", "exchange_raises", "expected_failure"), + [ + ({}, None, False, AuthFailure.NO_KEY), + ({"x-descope-api-key": "k"}, None, True, AuthFailure.INVALID_KEY), + ({"x-descope-api-key": "k"}, ["t1", "t2"], False, AuthFailure.MULTI_TENANT), + ({"x-descope-api-key": "k"}, ["_legacy"], False, AuthFailure.LEGACY_TENANT), + ({"x-descope-api-key": "k"}, ["unlisted-org"], False, AuthFailure.NOT_ALLOWLISTED), + ], +) @pytest.mark.usefixtures("descope_env") -async def test_resolve_descope_tenant_returns_tenant_when_in_allowlist() -> None: - headers = {"x-descope-api-key": "key-acme"} - with patch.object( - auth_module, - "_exchange_descope_access_key", - return_value=_mock_jwt_response(["acme-corp"]), - ): - tenant = await resolve_descope_tenant(headers) - assert tenant == "acme-corp" - - -@pytest.mark.usefixtures("descope_env") -async def test_resolve_descope_tenant_returns_none_when_tenant_not_in_allowlist() -> None: - headers = {"x-descope-api-key": "key-rogue"} - with patch.object( - auth_module, - "_exchange_descope_access_key", - return_value=_mock_jwt_response(["unknown-org"]), - ): - tenant = await resolve_descope_tenant(headers) - assert tenant is None - - -@pytest.mark.usefixtures("descope_env") -async def test_resolve_descope_tenant_rejects_multi_tenant_jwt() -> None: - headers = {"x-descope-api-key": "key-multi"} - with patch.object( - auth_module, - "_exchange_descope_access_key", - return_value=_mock_jwt_response(["acme-corp", "vals-internal"]), - ): - tenant = await resolve_descope_tenant(headers) - assert tenant is None +async def test_resolve_descope_tenant_failure_table( + headers: dict[str, str], + exchange_tenants: list[str] | None, + exchange_raises: bool, + expected_failure: AuthFailure, +) -> None: + if exchange_tenants is not None or exchange_raises: + side_effect = RuntimeError("bad") if exchange_raises else None + return_value = _mock_jwt_response(exchange_tenants or []) if not exchange_raises else None + with patch.object( + auth_module, + "_exchange_descope_access_key", + side_effect=side_effect, + return_value=return_value, + ): + result = await resolve_descope_tenant(headers) + else: + result = await resolve_descope_tenant(headers) + assert result == AuthResult(failure=expected_failure) + assert not result.ok @pytest.mark.usefixtures("descope_env") -async def test_resolve_descope_tenant_rejects_reserved_legacy_tenant( - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.setenv( - "DESCOPE_TENANT_ALLOWLIST_JSON", - _allowlist_env({"tenants": {"_legacy": {"datasets": ["secret"]}}}), - ) - headers = {"x-descope-api-key": "key-reserved"} +async def test_resolve_descope_tenant_success() -> None: + headers = {"x-descope-api-key": "key-acme"} with patch.object( auth_module, "_exchange_descope_access_key", - return_value=_mock_jwt_response(["_legacy"]), + return_value=_mock_jwt_response(["acme-corp"]), ): - tenant = await resolve_descope_tenant(headers) - assert tenant is None + result = await resolve_descope_tenant(headers) + assert result == AuthResult(tenant="acme-corp") + assert result.ok async def test_resolve_caller_tenant_legacy_no_api_key_required( @@ -116,8 +111,9 @@ async def test_resolve_caller_tenant_legacy_no_api_key_required( ) -> None: monkeypatch.setenv("AUTH_REQUIRED", "false") monkeypatch.delenv("BENCHMARK_API_KEY", raising=False) - tenant = await resolve_caller_tenant({}) - assert tenant == "_legacy" + result = await resolve_caller_tenant({}) + assert result == AuthResult(tenant="_legacy") + assert result.ok async def test_resolve_caller_tenant_legacy_correct_api_key( @@ -125,8 +121,9 @@ async def test_resolve_caller_tenant_legacy_correct_api_key( ) -> None: monkeypatch.setenv("AUTH_REQUIRED", "false") monkeypatch.setenv("BENCHMARK_API_KEY", "secret123") - tenant = await resolve_caller_tenant({"authorization": "Bearer secret123"}) - assert tenant == "_legacy" + result = await resolve_caller_tenant({"authorization": "Bearer secret123"}) + assert result == AuthResult(tenant="_legacy") + assert result.ok async def test_resolve_caller_tenant_legacy_wrong_api_key( @@ -134,8 +131,9 @@ async def test_resolve_caller_tenant_legacy_wrong_api_key( ) -> None: monkeypatch.setenv("AUTH_REQUIRED", "false") monkeypatch.setenv("BENCHMARK_API_KEY", "secret123") - tenant = await resolve_caller_tenant({"authorization": "Bearer wrong"}) - assert tenant is None + result = await resolve_caller_tenant({"authorization": "Bearer wrong"}) + assert result == AuthResult(failure=AuthFailure.REJECTED) + assert not result.ok class _BareBenchmark(BenchmarkService): @@ -175,17 +173,28 @@ async def check_auth(self, headers: dict[str, str]) -> bool: async def test_resolve_tenant_legacy_override_returns_sentinel_on_true() -> None: service = _LegacyOverrideBenchmark(allow=True) - tenant = await service.resolve_tenant({}) - assert tenant == "_legacy" + result = await service.resolve_tenant({}) + assert result == AuthResult(tenant="_legacy") + assert result.ok async def test_resolve_tenant_legacy_override_returns_none_on_false() -> None: service = _LegacyOverrideBenchmark(allow=False) - tenant = await service.resolve_tenant({}) - assert tenant is None + result = await service.resolve_tenant({}) + assert result == AuthResult(failure=AuthFailure.REJECTED) + assert not result.ok async def test_check_dataset_access_legacy_sentinel_always_allowed() -> None: service = _BareBenchmark() assert await service.check_dataset_access("_legacy", "anything") is True assert await service.check_dataset_access("_legacy", None) is True + + +def test_auth_types_exported_from_package() -> None: + import benchmark_service + + assert hasattr(benchmark_service, "AuthResult") + assert hasattr(benchmark_service, "AuthFailure") + assert benchmark_service.AuthResult is AuthResult + assert benchmark_service.AuthFailure is AuthFailure diff --git a/tests/test_trial.py b/tests/test_trial.py index e2b0842..1c8790e 100644 --- a/tests/test_trial.py +++ b/tests/test_trial.py @@ -278,8 +278,11 @@ def test_internal_error_does_not_leak_traceback(trial_client: TestClient) -> Non headers={"x-descope-api-key": "trial-key"}, ) assert resp.status_code == 500 - assert resp.json()["detail"] == "Internal server error" - assert "Traceback" not in resp.text and "rubric.py" not in resp.text + body = resp.json() + assert body == {"detail": "Evaluation failed"} + assert "errors" not in body + assert "rubric.py" not in resp.text + assert "Traceback" not in resp.text def test_trial_tenant_blocked_on_internal_endpoint(trial_client: TestClient) -> None: