diff --git a/api/ee/databases/postgres/migrations/core/data_migrations/applications_workflow.py b/api/ee/databases/postgres/migrations/core/data_migrations/applications_workflow.py index 0c078ec15c..2f8216603f 100644 --- a/api/ee/databases/postgres/migrations/core/data_migrations/applications_workflow.py +++ b/api/ee/databases/postgres/migrations/core/data_migrations/applications_workflow.py @@ -254,9 +254,11 @@ def check_url_safety(cls, v: Any) -> Any: # noqa: N805 return v from oss.src.dbs.postgres.git.mappings import map_dto_to_dbe - from oss.src.dbs.postgres.shared.engine import engine as db_engine + from oss.src.dbs.postgres.shared.engine import get_transactions_engine from datetime import datetime, timezone + db_engine = get_transactions_engine() + workflow_create = WorkflowCreate( **application_create.model_dump(mode="json"), ) @@ -267,7 +269,7 @@ def check_url_safety(cls, v: Any) -> Any: # noqa: N805 # Avoid slug collision with existing workflow artifacts (e.g. evaluators) artifact_slug = git_artifact_create.slug - async with db_engine.core_session() as session: + async with db_engine.session() as session: existing = ( await session.execute( select(WorkflowArtifactDBE).filter( @@ -298,7 +300,7 @@ def check_url_safety(cls, v: Any) -> Any: # noqa: N805 dto=artifact_dto, ) - async with db_engine.core_session() as session: + async with db_engine.session() as session: session.add(artifact_dbe) await session.commit() @@ -364,7 +366,7 @@ def check_url_safety(cls, v: Any) -> Any: # noqa: N805 dto=variant_dto, ) - async with db_engine.core_session() as session: + async with db_engine.session() as session: session.add(variant_dbe) await session.commit() @@ -415,7 +417,7 @@ def check_url_safety(cls, v: Any) -> Any: # noqa: N805 dto=revision_dto, ) - async with db_engine.core_session() as session: + async with db_engine.session() as session: session.add(revision_dbe) await session.commit() diff --git a/api/ee/databases/postgres/migrations/core/env.py b/api/ee/databases/postgres/migrations/core/env.py index 492b4f9a69..1faed28fe5 100644 --- a/api/ee/databases/postgres/migrations/core/env.py +++ b/api/ee/databases/postgres/migrations/core/env.py @@ -6,7 +6,7 @@ from alembic import context -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.utils.env import env from oss.src.dbs.postgres.shared.base import Base # Side-effect imports: register SQLAlchemy models with Base.metadata @@ -29,7 +29,7 @@ # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config -config.set_main_option("sqlalchemy.url", engine.postgres_uri_core) # type: ignore +config.set_main_option("sqlalchemy.url", env.postgres.uri_core) # Interpret the config file for Python logging. diff --git a/api/ee/databases/postgres/migrations/tracing/env.py b/api/ee/databases/postgres/migrations/tracing/env.py index 2dadd6892e..83ae21ab6b 100644 --- a/api/ee/databases/postgres/migrations/tracing/env.py +++ b/api/ee/databases/postgres/migrations/tracing/env.py @@ -7,7 +7,7 @@ from alembic import context -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.utils.env import env from oss.src.dbs.postgres.shared.base import Base # Side-effect import: register SQLAlchemy model with Base.metadata @@ -19,7 +19,7 @@ # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config -config.set_main_option("sqlalchemy.url", engine.postgres_uri_tracing) # type: ignore +config.set_main_option("sqlalchemy.url", env.postgres.uri_tracing) # Interpret the config file for Python logging. diff --git a/api/ee/src/core/meters/service.py b/api/ee/src/core/meters/service.py index 3a92107e4c..c3a96cd811 100644 --- a/api/ee/src/core/meters/service.py +++ b/api/ee/src/core/meters/service.py @@ -1,10 +1,9 @@ from typing import Awaitable, Tuple, Callable, List, Optional from uuid import uuid4 -import stripe - from oss.src.utils.logging import get_module_logger from oss.src.utils.env import env +from oss.src.utils.lazy import _load_stripe from ee.src.core.entitlements.types import Quota from ee.src.core.entitlements.types import Counter, Gauge, REPORTS @@ -13,13 +12,6 @@ log = get_module_logger(__name__) -# Initialize Stripe only if enabled -if env.stripe.enabled: - stripe.api_key = env.stripe.api_key - log.info("✓ Stripe enabled:", target=env.stripe.webhook_target) -else: - log.info("✗ Stripe disabled") - class MetersService: def __init__( @@ -82,6 +74,11 @@ async def report( log.warn("✗ Stripe disabled") return + stripe = _load_stripe() + if stripe is None: + log.error("[report] Failed to load Stripe module") + return + log.info("[report] ============================================") log.info("[report] Starting meter report job") log.info("[report] ============================================") diff --git a/api/ee/src/core/subscriptions/service.py b/api/ee/src/core/subscriptions/service.py index 3badcc4c58..c26ea1c88a 100644 --- a/api/ee/src/core/subscriptions/service.py +++ b/api/ee/src/core/subscriptions/service.py @@ -2,12 +2,10 @@ from uuid import getnode from datetime import datetime, timezone, timedelta - -import stripe - from oss.src.utils.logging import get_module_logger from oss.src.utils.env import env from oss.src.utils.caching import invalidate_cache +from oss.src.utils.lazy import _load_stripe from ee.src.core.subscriptions.types import ( SubscriptionDTO, @@ -24,13 +22,6 @@ log = get_module_logger(__name__) -# Initialize Stripe only if enabled -if env.stripe.enabled: - stripe.api_key = env.stripe.api_key - log.info("✓ Stripe enabled:", target=env.stripe.webhook_target) -else: - log.info("✗ Stripe disabled") - MAC_ADDRESS = ":".join(f"{(getnode() >> ele) & 0xFF:02x}" for ele in range(40, -1, -8)) @@ -83,6 +74,10 @@ async def start_reverse_trial( if not env.stripe.enabled: raise EventException("Reverse trial requires Stripe to be enabled") + stripe = _load_stripe() + if stripe is None: + raise EventException("Failed to load Stripe module") + now = datetime.now(tz=timezone.utc) anchor = now + timedelta(days=REVERSE_TRIAL_DAYS) @@ -262,6 +257,11 @@ async def process_event( log.warn("✗ Stripe disabled") return None + stripe = _load_stripe() + if stripe is None: + log.error("Failed to load Stripe module") + raise EventException("Stripe is not available for plan switching") + if subscription.plan == plan: log.warn("Subscription already on the plan: %s", plan) diff --git a/api/ee/src/dbs/postgres/meters/dao.py b/api/ee/src/dbs/postgres/meters/dao.py index 670cabaa59..0d8e490cbe 100644 --- a/api/ee/src/dbs/postgres/meters/dao.py +++ b/api/ee/src/dbs/postgres/meters/dao.py @@ -8,7 +8,10 @@ from oss.src.utils.logging import get_module_logger -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) from ee.src.core.entitlements.types import Quota from ee.src.core.meters.types import MeterDTO @@ -22,8 +25,10 @@ class MetersDAO(MetersDAOInterface): - def __init__(self): - pass + def __init__(self, engine: TransactionsEngine = None): + if engine is None: + engine = get_transactions_engine() + self.engine = engine async def dump( self, @@ -31,7 +36,7 @@ async def dump( ) -> list[MeterDTO]: log.info(f"[report] [dump] Starting (limit={limit or 'none'})") - async with engine.core_session() as session: + async with self.engine.session() as session: try: stmt = ( select(MeterDBE) @@ -203,7 +208,7 @@ async def _bump_commit_chunk( missing_count = 0 missing_samples: list[str] = [] - async with engine.core_session() as session: + async with self.engine.session() as session: for meter in meters: stmt = ( update(MeterDBE) @@ -249,7 +254,7 @@ async def fetch( year: Optional[int] = None, month: Optional[int] = None, ) -> list[MeterDTO]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(MeterDBE).filter_by( organization_id=organization_id, ) # NO RISK OF DEADLOCK @@ -288,7 +293,7 @@ async def check( year, month = compute_billing_period(anchor=anchor) meter.year, meter.month = year, month - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(MeterDBE).filter_by( organization_id=meter.organization_id, key=meter.key, @@ -376,7 +381,7 @@ async def adjust( where = where | where_clause # 4. Build SQL statement (atomic upsert with RETURNING) - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = ( insert(MeterDBE) .values( diff --git a/api/ee/src/dbs/postgres/organizations/dao.py b/api/ee/src/dbs/postgres/organizations/dao.py index 00658837db..fd2e02a44d 100644 --- a/api/ee/src/dbs/postgres/organizations/dao.py +++ b/api/ee/src/dbs/postgres/organizations/dao.py @@ -3,7 +3,10 @@ from sqlalchemy import select, and_ from sqlalchemy.ext.asyncio import AsyncSession -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) from ee.src.dbs.postgres.organizations.dbes import ( OrganizationDomainDBE, OrganizationProviderDBE, @@ -18,8 +21,15 @@ class OrganizationDomainsDAO: 2. Without a session (creates own sessions): OrganizationDomainsDAO() """ - def __init__(self, session: Optional[AsyncSession] = None): + def __init__( + self, + session: Optional[AsyncSession] = None, + engine: Optional[TransactionsEngine] = None, + ): self.session = session + if engine is None: + engine = get_transactions_engine() + self.engine = engine async def create( self, @@ -54,7 +64,7 @@ async def create( return domain else: - async with engine.core_session() as session: + async with self.engine.session() as session: domain = OrganizationDomainDBE( organization_id=organization_id, slug=slug, @@ -92,7 +102,7 @@ async def get_by_id( return result.scalars().first() else: - async with engine.core_session() as session: + async with self.engine.session() as session: result = await session.execute( select(OrganizationDomainDBE).where( and_( @@ -125,7 +135,7 @@ async def get_by_slug( return result.scalars().first() else: - async with engine.core_session() as session: + async with self.engine.session() as session: result = await session.execute( select(OrganizationDomainDBE).where( and_( @@ -158,7 +168,7 @@ async def get_verified_by_slug( return result.scalars().first() else: - async with engine.core_session() as session: + async with self.engine.session() as session: result = await session.execute( select(OrganizationDomainDBE).where( and_( @@ -186,7 +196,7 @@ async def list_by_organization( return list(result.scalars().all()) else: - async with engine.core_session() as session: + async with self.engine.session() as session: result = await session.execute( select(OrganizationDomainDBE).where( OrganizationDomainDBE.organization_id == organization_id @@ -218,7 +228,7 @@ async def update_flags( return domain else: - async with engine.core_session() as session: + async with self.engine.session() as session: if domain: # Re-attach to new session domain = await session.get(OrganizationDomainDBE, domain_id) @@ -252,7 +262,7 @@ async def delete( return False else: - async with engine.core_session() as session: + async with self.engine.session() as session: domain = await session.get(OrganizationDomainDBE, domain_id) if domain: @@ -272,8 +282,15 @@ class OrganizationProvidersDAO: 2. Without a session (creates own sessions): OrganizationProvidersDAO() """ - def __init__(self, session: Optional[AsyncSession] = None): + def __init__( + self, + session: Optional[AsyncSession] = None, + engine: Optional[TransactionsEngine] = None, + ): self.session = session + if engine is None: + engine = get_transactions_engine() + self.engine = engine async def create( self, @@ -311,7 +328,7 @@ async def create( return provider else: - async with engine.core_session() as session: + async with self.engine.session() as session: provider = OrganizationProviderDBE( organization_id=organization_id, slug=slug, @@ -350,7 +367,7 @@ async def get_by_id( return result.scalars().first() else: - async with engine.core_session() as session: + async with self.engine.session() as session: result = await session.execute( select(OrganizationProviderDBE).where( and_( @@ -378,7 +395,7 @@ async def get_by_id_any( return result.scalars().first() else: - async with engine.core_session() as session: + async with self.engine.session() as session: result = await session.execute( select(OrganizationProviderDBE).where( OrganizationProviderDBE.id == provider_id @@ -408,7 +425,7 @@ async def get_by_slug( return result.scalars().first() else: - async with engine.core_session() as session: + async with self.engine.session() as session: result = await session.execute( select(OrganizationProviderDBE).where( and_( @@ -436,7 +453,7 @@ async def list_by_organization( return list(result.scalars().all()) else: - async with engine.core_session() as session: + async with self.engine.session() as session: result = await session.execute( select(OrganizationProviderDBE).where( OrganizationProviderDBE.organization_id == organization_id @@ -475,7 +492,7 @@ async def update( return provider else: - async with engine.core_session() as session: + async with self.engine.session() as session: provider = await session.get(OrganizationProviderDBE, provider_id) if provider: @@ -511,7 +528,7 @@ async def delete( return False else: - async with engine.core_session() as session: + async with self.engine.session() as session: provider = await session.get(OrganizationProviderDBE, provider_id) if provider: diff --git a/api/ee/src/dbs/postgres/subscriptions/dao.py b/api/ee/src/dbs/postgres/subscriptions/dao.py index 93d67dcac4..58f7f26f03 100644 --- a/api/ee/src/dbs/postgres/subscriptions/dao.py +++ b/api/ee/src/dbs/postgres/subscriptions/dao.py @@ -5,7 +5,10 @@ from ee.src.core.subscriptions.types import SubscriptionDTO from ee.src.core.subscriptions.interfaces import SubscriptionsDAOInterface -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) from ee.src.dbs.postgres.subscriptions.dbes import SubscriptionDBE from ee.src.dbs.postgres.subscriptions.mappings import ( map_dbe_to_dto, @@ -14,15 +17,17 @@ class SubscriptionsDAO(SubscriptionsDAOInterface): - def __init__(self): - pass + def __init__(self, engine: TransactionsEngine = None): + if engine is None: + engine = get_transactions_engine() + self.engine = engine async def create( self, *, subscription: SubscriptionDTO, ) -> SubscriptionDTO: - async with engine.core_session() as session: + async with self.engine.session() as session: subscription_dbe = map_dto_to_dbe(subscription) session.add(subscription_dbe) @@ -38,7 +43,7 @@ async def read( *, organization_id: str, ) -> Optional[SubscriptionDTO]: - async with engine.core_session() as session: + async with self.engine.session() as session: result = await session.execute( select(SubscriptionDBE).where( SubscriptionDBE.organization_id == organization_id, @@ -59,7 +64,7 @@ async def update( *, subscription: SubscriptionDTO, ) -> Optional[SubscriptionDTO]: - async with engine.core_session() as session: + async with self.engine.session() as session: result = await session.execute( select(SubscriptionDBE).where( SubscriptionDBE.organization_id == subscription.organization_id, diff --git a/api/ee/src/dbs/postgres/tracing/dao.py b/api/ee/src/dbs/postgres/tracing/dao.py index a3c3b0f91b..2d7578a5e7 100644 --- a/api/ee/src/dbs/postgres/tracing/dao.py +++ b/api/ee/src/dbs/postgres/tracing/dao.py @@ -10,7 +10,12 @@ from oss.src.models.db_models import ProjectDB -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + AnalyticsEngine, + get_transactions_engine, + get_analytics_engine, +) from oss.src.dbs.postgres.tracing.dbes import SpanDBE from ee.src.dbs.postgres.subscriptions.dbes import SubscriptionDBE @@ -70,6 +75,18 @@ class TracingDAO: + def __init__( + self, + transactions_engine: TransactionsEngine = None, + analytics_engine: AnalyticsEngine = None, + ): + if transactions_engine is None: + transactions_engine = get_transactions_engine() + if analytics_engine is None: + analytics_engine = get_analytics_engine() + self.transactions_engine = transactions_engine + self.analytics_engine = analytics_engine + # ---------------- # # Raw-SQL versions # ---------------- # @@ -81,7 +98,7 @@ async def _fetch_projects_with_plan( project_id: Optional[UUID], max_projects: int, ) -> List[UUID]: - async with engine.core_session() as session: + async with self.transactions_engine.session() as session: result = await session.execute( CORE_PROJECTS_PAGE_SQL, { @@ -105,7 +122,7 @@ async def _delete_traces_before_cutoff( if not project_ids: return (0, 0) - async with engine.tracing_session() as session: + async with self.analytics_engine.session() as session: result = await session.execute( TRACING_DELETE_SQL, { @@ -135,7 +152,7 @@ async def fetch_projects_with_plan( project_id: Optional[UUID], max_projects: int, ) -> List[UUID]: - async with engine.core_session() as session: + async with self.transactions_engine.session() as session: stmt = ( select(ProjectDB.id) .select_from( @@ -167,7 +184,7 @@ async def delete_traces_before_cutoff( if not project_ids: return (0, 0) - async with engine.tracing_session() as session: + async with self.analytics_engine.session() as session: project_ids_param = bindparam( "project_ids", value=project_ids, diff --git a/api/ee/src/main.py b/api/ee/src/main.py index ebe123f276..14aace23c5 100644 --- a/api/ee/src/main.py +++ b/api/ee/src/main.py @@ -4,6 +4,11 @@ from oss.src.utils.env import env from oss.src.utils.logging import get_module_logger +from oss.src.dbs.postgres.shared.engine import ( + get_transactions_engine, + get_analytics_engine, +) + from ee.src.routers import ( workspace_router, organization_router as _organization_router, @@ -12,6 +17,7 @@ from ee.src.dbs.postgres.meters.dao import MetersDAO from ee.src.dbs.postgres.tracing.dao import TracingDAO from ee.src.dbs.postgres.subscriptions.dao import SubscriptionsDAO +from ee.src.dbs.postgres.organizations.dao import OrganizationDomainsDAO from ee.src.core.meters.service import MetersService from ee.src.core.tracing.service import TracingService @@ -24,11 +30,20 @@ # DBS -------------------------------------------------------------------------- -meters_dao = MetersDAO() +# Get engines from shared initialization (instantiated in routers.py) +_transactions_engine = get_transactions_engine() +_analytics_engine = get_analytics_engine() + +meters_dao = MetersDAO(engine=_transactions_engine) + +tracing_dao = TracingDAO( + transactions_engine=_transactions_engine, + analytics_engine=_analytics_engine, +) -tracing_dao = TracingDAO() +subscriptions_dao = SubscriptionsDAO(engine=_transactions_engine) -subscriptions_dao = SubscriptionsDAO() +organization_domains_dao = OrganizationDomainsDAO(engine=_transactions_engine) # CORE ------------------------------------------------------------------------- diff --git a/api/ee/src/services/admin_manager.py b/api/ee/src/services/admin_manager.py index b345f2471e..7e5ae87a2a 100644 --- a/api/ee/src/services/admin_manager.py +++ b/api/ee/src/services/admin_manager.py @@ -7,7 +7,7 @@ from oss.src.utils.logging import get_module_logger -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import get_transactions_engine from oss.src.models.db_models import UserDB from oss.src.services.api_key_service import create_api_key @@ -142,7 +142,9 @@ class ProjectMembershipRequest(BaseModel): async def check_user( request: UserRequest, ) -> Optional[UserRequest]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(UserDB).filter_by( email=request.email, @@ -159,7 +161,9 @@ async def check_user( async def create_user( request: UserRequest, ) -> Reference: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: user_db = UserDB( # id=uuid7() # use default # @@ -185,7 +189,9 @@ async def create_user( async def create_organization( request: OrganizationRequest, ) -> Reference: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: organization_db = OrganizationDB( name=request.name, description=request.description, @@ -213,7 +219,9 @@ async def create_organization( async def create_workspace( request: WorkspaceRequest, ) -> Reference: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: workspace_db = WorkspaceDB( # id=uuid7() # use default # @@ -242,7 +250,9 @@ async def create_workspace( async def create_project( request: ProjectRequest, ) -> Reference: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: project_db = ProjectDB( # id=uuid7() # use default # @@ -273,7 +283,9 @@ async def create_project( async def create_organization_membership( request: OrganizationMembershipRequest, ) -> Reference: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: membership_db = OrganizationMembershipDB( # id=uuid7() # use default # @@ -316,7 +328,9 @@ async def create_organization_membership( async def create_workspace_membership( request: WorkspaceMembershipRequest, ) -> Reference: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: workspace = await session.execute( select(WorkspaceDB).filter_by( id=request.workspace_ref.id, @@ -357,7 +371,9 @@ async def create_workspace_membership( async def create_project_membership( request: ProjectMembershipRequest, ) -> Reference: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: project = await session.execute( select(ProjectDB).filter_by( id=request.project_ref.id, diff --git a/api/ee/src/services/db_manager.py b/api/ee/src/services/db_manager.py index a97b52af6d..5e58b3fb7c 100644 --- a/api/ee/src/services/db_manager.py +++ b/api/ee/src/services/db_manager.py @@ -1,6 +1,6 @@ import uuid -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import get_transactions_engine from oss.src.models.db_models import DeploymentDB @@ -18,7 +18,9 @@ async def create_deployment( DeploymentDB: The created deployment. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: try: deployment = DeploymentDB( app_id=uuid.UUID(app_id), diff --git a/api/ee/src/services/db_manager_ee.py b/api/ee/src/services/db_manager_ee.py index 8e0854676c..d91ea0346e 100644 --- a/api/ee/src/services/db_manager_ee.py +++ b/api/ee/src/services/db_manager_ee.py @@ -14,7 +14,9 @@ from oss.src.utils.logging import get_module_logger -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + get_transactions_engine, +) from oss.src.services import db_manager from ee.src.core.workspaces.types import ( UserRole, @@ -76,7 +78,9 @@ async def get_organization(organization_id: str) -> OrganizationDB: OrganizationDB: The fetched organization. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(OrganizationDB).filter_by(id=uuid.UUID(organization_id)) ) @@ -95,7 +99,9 @@ async def get_organizations_by_list_ids(organization_ids: List) -> List[Organiza List: A list of dictionaries representing the retrieved organizations. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: organization_uuids = [ uuid.UUID(organization_id) for organization_id in organization_ids ] @@ -115,7 +121,9 @@ async def count_organizations_by_owner(owner_id: str) -> int: Returns: int: The count of organizations owned by the user. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(func.count(OrganizationDB.id)).where( OrganizationDB.owner_id == uuid.UUID(owner_id) @@ -135,7 +143,9 @@ async def get_default_workspace_id(user_id: str) -> str: str: The default workspace ID. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(WorkspaceMemberDB) .filter_by(user_id=uuid.UUID(user_id)) @@ -180,7 +190,9 @@ async def get_organization_workspaces(organization_id: str): organization_id (str): The ID of the organization """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(WorkspaceDB) .filter_by(organization_id=uuid.UUID(organization_id)) @@ -198,7 +210,9 @@ async def get_workspace_members(workspace_id: str) -> List[WorkspaceMemberDB]: Used by RBAC / admin helpers to derive roles and permissions. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(WorkspaceMemberDB).where( WorkspaceMemberDB.workspace_id == workspace_id @@ -384,7 +398,8 @@ async def _sync(db_session: AsyncSession) -> None: await _sync(session) return - async with engine.core_session() as new_session: + engine = get_transactions_engine() + async with engine.session() as new_session: await _sync(new_session) @@ -401,7 +416,9 @@ async def get_default_workspace_id_from_organization( str: The default (first) workspace ID. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: workspace_query = await session.execute( select(WorkspaceDB) .where( @@ -432,7 +449,9 @@ async def get_project_by_workspace( """ assert workspace_id is not None, "Workspace ID is required to retrieve project" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: stmt = select(ProjectDB).where( ProjectDB.workspace_id == uuid.UUID(workspace_id), ) @@ -493,7 +512,9 @@ async def create_project_member( async def fetch_project_memberships_by_user_id( user_id: str, ) -> List[ProjectMemberDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(ProjectMemberDB) .filter_by(user_id=uuid.UUID(user_id)) @@ -621,7 +642,9 @@ async def create_workspace( user = await db_manager.get_user(user_uid) organization = await get_organization(organization_id) - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: user_result = await session.execute(select(UserDB).filter_by(uid=user_uid)) user = user_result.scalars().first() @@ -651,7 +674,9 @@ async def update_workspace( payload (UpdateWorkspace): The data to update the workspace with. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(WorkspaceDB).filter_by(id=workspace.id)) workspace = result.scalars().first() @@ -680,7 +705,9 @@ async def check_user_in_workspace_with_email(email: str, workspace_id: str) -> b Exception: If there is an error checking if the user belongs to the workspace. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(WorkspaceMemberDB) .join(UserDB, UserDB.id == WorkspaceMemberDB.user_id) @@ -720,7 +747,9 @@ async def update_user_roles( f"No projects found for the provided workspace_id {workspace_id}" ) - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: workspace_member_result = await session.execute( select(WorkspaceMemberDB).filter_by( workspace_id=uuid.UUID(workspace_id), user_id=user.id @@ -803,7 +832,9 @@ async def add_user_to_workspace_and_org( if project and str(project.workspace_id) != str(workspace.id): raise ValueError("Project does not belong to the provided workspace") - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: # create joined organization for user user_organization = OrganizationMemberDB( user_id=user.id, organization_id=organization.id @@ -912,7 +943,9 @@ async def remove_user_from_workspace( ) project_ids = [project.id for project in projects] - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: if not user: # User is an invited user who has not yet created an account and therefore does not have a user object pass else: @@ -1041,7 +1074,9 @@ async def create_organization( Exception: If there is an error creating the organization. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: create_org_data = payload.model_dump(exclude_unset=True) is_demo = create_org_data.pop("is_demo", False) @@ -1130,7 +1165,9 @@ async def update_organization( Exception: If there is an error updating the organization. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(OrganizationDB).filter_by(id=uuid.UUID(organization_id)) ) @@ -1292,7 +1329,9 @@ async def delete_organization(organization_id: str) -> bool: Raises: NoResultFound: If organization not found. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(OrganizationDB).filter_by(id=uuid.UUID(organization_id)) ) @@ -1317,7 +1356,9 @@ async def delete_invitation(invitation_id: str) -> bool: bool: True if the invitation was successfully deleted, False otherwise. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(InvitationDB).filter_by(id=uuid.UUID(invitation_id)) ) @@ -1379,7 +1420,9 @@ async def mark_invitation_as_used( HTTPException: If there is an error marking the invitation as used. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(InvitationDB).filter_by( project_id=uuid.UUID(project_id), token=invitation.token @@ -1470,7 +1513,9 @@ async def get_project_invitations(project_id: str, **kwargs): project_id (str): The ID of the project """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: stmt = select(InvitationDB).filter( InvitationDB.project_id == uuid.UUID(project_id) ) @@ -1490,7 +1535,9 @@ async def get_all_pending_invitations(email: str): email (str): The email address of the user. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(InvitationDB).filter( InvitationDB.email == email, @@ -1515,7 +1562,9 @@ async def get_project_invitation( InvitationDB: invitation object """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(InvitationDB).filter_by( project_id=uuid.UUID(project_id), token=token, email=email @@ -1532,7 +1581,9 @@ async def get_project_members(project_id: str): project_id (str): The ID of the project """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: members_query = await session.execute( select(ProjectMemberDB) .filter(ProjectMemberDB.project_id == uuid.UUID(project_id)) @@ -1560,7 +1611,9 @@ async def project_member_exists( True if the user belongs to the project, False otherwise. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: stmt = select( select(ProjectMemberDB.id) .filter( @@ -1591,7 +1644,9 @@ async def workspace_member_exists( True if the user belongs to the workspace, False otherwise. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: stmt = select( select(WorkspaceMemberDB.id) .filter( @@ -1638,7 +1693,9 @@ async def create_org_workspace_invitation( if not project: raise Exception(f"No project found with ID {project_id}") - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: invitation = InvitationDB( token=token, email=email, @@ -1681,7 +1738,9 @@ async def add_user_to_organization( role: str = "member", # is_demo: bool = False, ) -> None: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: organization_member = OrganizationMemberDB( user_id=user_id, organization_id=organization_id, @@ -1707,7 +1766,9 @@ async def add_user_to_workspace( role: str, # is_demo: bool = False, ) -> None: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: # fetch workspace by workspace_id (SQL) stmt = select(WorkspaceDB).filter_by(id=workspace_id) workspace = await session.execute(stmt) @@ -1748,7 +1809,9 @@ async def add_user_to_project( if not project: raise Exception(f"No project found with ID {project_id}") - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: project_member = ProjectMemberDB( user_id=user_id, project_id=project_id, @@ -1788,7 +1851,9 @@ async def transfer_organization_ownership( Raises: ValueError: If new owner is not a member of the organization """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: # Verify organization exists org_result = await session.execute( select(OrganizationDB).filter_by(id=uuid.UUID(organization_id)) @@ -1920,7 +1985,9 @@ async def transfer_organization_ownership( async def admin_delete_org_membership(membership_id: uuid.UUID) -> bool: """Delete an org membership by ID. Returns False if not found.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(OrganizationMemberDB).filter_by(id=membership_id) ) @@ -1934,7 +2001,9 @@ async def admin_delete_org_membership(membership_id: uuid.UUID) -> bool: async def admin_delete_workspace_membership(membership_id: uuid.UUID) -> bool: """Delete a workspace membership by ID. Returns False if not found.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(WorkspaceMemberDB).filter_by(id=membership_id) ) @@ -1948,7 +2017,9 @@ async def admin_delete_workspace_membership(membership_id: uuid.UUID) -> bool: async def admin_delete_project_membership(membership_id: uuid.UUID) -> bool: """Delete a project membership by ID. Returns False if not found.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(ProjectMemberDB).filter_by(id=membership_id) ) @@ -1965,7 +2036,9 @@ async def admin_get_member_org_ids( org_ids: List[uuid.UUID], ) -> Set[uuid.UUID]: """Return the subset of org_ids where the user has a membership row.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: rows = ( ( await session.execute( @@ -1992,7 +2065,9 @@ async def admin_swap_org_memberships( a membership row. For each qualifying org, target gets source's role and source gets target's prior role. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: source_rows = ( ( await session.execute( @@ -2054,7 +2129,9 @@ async def admin_swap_workspace_memberships( have a membership row. For each qualifying workspace, target gets source's role and source gets target's prior role. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: source_rows = ( ( await session.execute( @@ -2116,7 +2193,9 @@ async def admin_swap_project_memberships( have a membership row. For each qualifying project, target gets source's role and source gets target's prior role. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: source_rows = ( ( await session.execute( @@ -2172,7 +2251,9 @@ async def admin_delete_user_memberships(user_id: uuid.UUID) -> None: Called before hard-deleting a user so FK constraints are not violated. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: await session.execute( delete(OrganizationMemberDB).where(OrganizationMemberDB.user_id == user_id) ) diff --git a/api/ee/src/services/organization_service.py b/api/ee/src/services/organization_service.py index 5378008ae3..d11a6c8871 100644 --- a/api/ee/src/services/organization_service.py +++ b/api/ee/src/services/organization_service.py @@ -12,7 +12,7 @@ from oss.src.utils.env import env from oss.src.utils.logging import get_module_logger -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import get_transactions_engine from oss.src.core.secrets.dtos import ( CreateSecretDTO, UpdateSecretDTO, @@ -280,7 +280,9 @@ async def create_domain( Token expires after 48 hours and can be refreshed. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationDomainsDAO(session) # Block if a verified domain already exists anywhere @@ -337,7 +339,9 @@ async def verify_domain( self, organization_id: str, domain_id: str, user_id: str ) -> OrganizationDomain: """Verify a domain via DNS check.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationDomainsDAO(session) domain = await dao.get_by_id( @@ -403,7 +407,9 @@ async def list_domains(self, organization_id: str) -> List[OrganizationDomain]: Tokens are returned for unverified domains (within expiry period). Verified domains have token=None (cleared after verification). """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationDomainsDAO(session) domains = await dao.list_by_organization(organization_id=organization_id) @@ -430,7 +436,9 @@ async def refresh_token( Generates a new token and resets the 48-hour expiry window. For verified domains, this marks them as unverified for re-verification. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationDomainsDAO(session) domain = await dao.get_by_id( @@ -470,7 +478,9 @@ async def reset_domain( Generates a new token and marks the domain as unverified. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationDomainsDAO(session) domain = await dao.get_by_id( @@ -506,7 +516,9 @@ async def delete_domain( self, organization_id: str, domain_id: str, user_id: str ) -> bool: """Delete a domain.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationDomainsDAO(session) domain = await dao.get_by_id( @@ -573,7 +585,9 @@ async def create_provider( user_id: str, ) -> OrganizationProvider: """Create a new SSO provider.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationProvidersDAO(session) # Use the slug from payload (already validated to be lowercase letters and hyphens) @@ -648,7 +662,9 @@ async def update_provider( user_id: str, ) -> OrganizationProvider: """Update an SSO provider.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationProvidersDAO(session) provider = await dao.get_by_id( @@ -735,7 +751,9 @@ async def update_provider( async def list_providers(self, organization_id: str) -> List[OrganizationProvider]: """List all SSO providers for an organization.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationProvidersDAO(session) providers = await dao.list_by_organization(organization_id=organization_id) @@ -748,7 +766,9 @@ async def get_provider( self, organization_id: str, provider_id: str ) -> OrganizationProvider: """Get a single SSO provider by ID.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationProvidersDAO(session) provider = await dao.get_by_id( provider_id=provider_id, organization_id=organization_id @@ -761,7 +781,9 @@ async def test_provider( self, organization_id: str, provider_id: str, user_id: str ) -> OrganizationProvider: """Test SSO provider connection and mark as valid if successful.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationProvidersDAO(session) provider = await dao.get_by_id( @@ -806,7 +828,9 @@ async def delete_provider( self, organization_id: str, provider_id: str, user_id: str ) -> bool: """Delete an SSO provider.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: dao = OrganizationProvidersDAO(session) provider = await dao.get_by_id( diff --git a/api/ee/src/services/selectors.py b/api/ee/src/services/selectors.py index afc27b7e6d..b71ea56d6b 100644 --- a/api/ee/src/services/selectors.py +++ b/api/ee/src/services/selectors.py @@ -7,7 +7,7 @@ from oss.src.services import db_manager from oss.src.utils.logging import get_module_logger -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import get_transactions_engine from ee.src.core.organizations.types import Organization from oss.src.models.db_models import ( @@ -39,7 +39,9 @@ async def get_user_org_and_workspace_id(user_uid) -> Dict[str, Union[str, List[s { "id": "123", "uid": "user123", "organization_ids": [], "workspace_ids": []} """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: user = await db_manager.get_user_with_id(user_id=user_uid) if not user: raise NoResultFound(f"User with uid {user_uid} not found") @@ -94,7 +96,9 @@ async def get_org_default_workspace(organization: Organization) -> WorkspaceDB: WorkspaceDB: Instance of WorkspaceDB """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(WorkspaceDB).filter_by( organization_id=organization.id, diff --git a/api/ee/tests/pytest/unit/services/test_db_manager_ee.py b/api/ee/tests/pytest/unit/services/test_db_manager_ee.py index c98c0078c3..24343b22b1 100644 --- a/api/ee/tests/pytest/unit/services/test_db_manager_ee.py +++ b/api/ee/tests/pytest/unit/services/test_db_manager_ee.py @@ -45,10 +45,14 @@ async def __aexit__(self, exc_type, exc, tb): def _patch_core_session(monkeypatch, memberships): + # db_manager_ee calls get_transactions_engine() — patch where it's called + mock_engine = type( + "MockEngine", (), {"session": lambda self: _SessionContext(memberships)} + )() monkeypatch.setattr( - db_manager_ee.engine, - "core_session", - lambda: _SessionContext(memberships), + db_manager_ee, + "get_transactions_engine", + lambda: mock_engine, ) diff --git a/api/entrypoints/routers.py b/api/entrypoints/routers.py index c827af1402..b7cd929973 100644 --- a/api/entrypoints/routers.py +++ b/api/entrypoints/routers.py @@ -1,5 +1,7 @@ from contextlib import asynccontextmanager +import time +import agenta as ag from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -12,6 +14,16 @@ from oss.src.utils.logging import get_module_logger from oss.src.utils.helpers import warn_deprecated_env_vars, validate_required_env_vars +# Engines +from oss.src.dbs.postgres.shared.engine import ( + get_transactions_engine, + get_analytics_engine, +) +from oss.src.dbs.redis.shared.engine import ( + get_cache_engine, + get_streams_engine, +) + from oss.databases.postgres.migrations.core.utils import ( check_for_new_migrations as check_for_new_core_migrations, ) @@ -141,24 +153,40 @@ from oss.src.utils.env import env from entrypoints.worker_evaluations import evaluations_worker -import oss.src.core.evaluations.tasks.live # noqa: F401 -import oss.src.core.evaluations.tasks.legacy # noqa: F401 -import oss.src.core.evaluations.tasks.batch # noqa: F401 +import oss.src.core.evaluations.tasks.query # noqa: F401 +import oss.src.core.evaluations.tasks.run # noqa: F401 +import oss.src.core.evaluations.tasks.source_slice # noqa: F401 -import agenta as ag +print("[STARTUP] About to import agenta SDK") +_t_ag_import = time.perf_counter() +# ag already imported at top +print(f"[STARTUP] agenta SDK imported (+{time.perf_counter() - _t_ag_import:.3f}s)") + +_startup_t0 = time.perf_counter() +print("[STARTUP] imports completed, beginning initialization") +print("[STARTUP] ag.init() starting") +_t_ag_init = time.perf_counter() ag.init( api_url=env.agenta.api_url, ) +print(f"[STARTUP] ag.init() completed (+{time.perf_counter() - _t_ag_init:.3f}s)") ee = None +_t_before_ee = time.perf_counter() if is_ee(): + print("[STARTUP] EE module import starting (Stripe init happens here)") import ee.src.main as ee # type: ignore + _ee_elapsed = time.perf_counter() - _t_before_ee + print(f"[STARTUP] EE module import completed (+{_ee_elapsed:.3f}s)") + log = get_module_logger(__name__) init_supertokens() +_st_elapsed = time.perf_counter() - _startup_t0 +print(f"[STARTUP] init_supertokens completed (+{_st_elapsed:.3f}s)") @asynccontextmanager @@ -182,6 +210,10 @@ async def lifespan(*args, **kwargs): for adapter in _composio_adapters.values(): await adapter.close() + await _transactions_engine.close() + await _analytics_engine.close() + await _streams_engine.close() + _OPENAPI_TAGS = [ { @@ -351,47 +383,65 @@ async def lifespan(*args, **kwargs): # DAOS ------------------------------------------------------------------------- -secrets_dao = SecretsDAO() -webhooks_dao = WebhooksDAO() +_t_daos = time.perf_counter() +print("[STARTUP] DAO initialization starting") -tracing_dao = TracingDAO() -events_dao = EventsDAO() +# Instantiate engines at startup (lazy — they don't connect until first use) +_transactions_engine = get_transactions_engine() +_analytics_engine = get_analytics_engine() +_streams_engine = get_streams_engine() +_cache_engine = get_cache_engine() + +secrets_dao = SecretsDAO(engine=_transactions_engine) +webhooks_dao = WebhooksDAO(engine=_transactions_engine) + +tracing_dao = TracingDAO(engine=_analytics_engine) +events_dao = EventsDAO(engine=_analytics_engine) testcases_dao = BlobsDAO( + engine=_transactions_engine, BlobDBE=TestcaseBlobDBE, ) testsets_dao = GitDAO( + engine=_transactions_engine, ArtifactDBE=TestsetArtifactDBE, VariantDBE=TestsetVariantDBE, RevisionDBE=TestsetRevisionDBE, ) queries_dao = GitDAO( + engine=_transactions_engine, ArtifactDBE=QueryArtifactDBE, VariantDBE=QueryVariantDBE, RevisionDBE=QueryRevisionDBE, ) workflows_dao = GitDAO( + engine=_transactions_engine, ArtifactDBE=WorkflowArtifactDBE, VariantDBE=WorkflowVariantDBE, RevisionDBE=WorkflowRevisionDBE, ) environments_dao = GitDAO( + engine=_transactions_engine, ArtifactDBE=EnvironmentArtifactDBE, VariantDBE=EnvironmentVariantDBE, RevisionDBE=EnvironmentRevisionDBE, ) -evaluations_dao = EvaluationsDAO() -folders_dao = FoldersDAO() +evaluations_dao = EvaluationsDAO(engine=_transactions_engine) +folders_dao = FoldersDAO(engine=_transactions_engine) -tools_dao = ToolsDAO() +tools_dao = ToolsDAO(engine=_transactions_engine) # SERVICES --------------------------------------------------------------------- +_t_daos_done = time.perf_counter() - _t_daos +print(f"[STARTUP] DAO initialization completed (+{_t_daos_done:.3f}s)") +_t_services = time.perf_counter() + vault_service = VaultService( secrets_dao=secrets_dao, ) @@ -533,6 +583,10 @@ async def lifespan(*args, **kwargs): adapter_registry=tools_adapter_registry, ) +_t_services_done = time.perf_counter() - _t_services +print(f"[STARTUP] Service initialization completed (+{_t_services_done:.3f}s)") +_t_routers = time.perf_counter() + # ROUTERS ---------------------------------------------------------------------- secrets = VaultRouter( @@ -685,6 +739,8 @@ async def lifespan(*args, **kwargs): # MOUNTING ROUTERS TO APP ROUTES ----------------------------------------------- +_t_mount_routers = time.perf_counter() + app.include_router( router=secrets.router, tags=["Secrets"], @@ -1113,5 +1169,14 @@ async def lifespan(*args, **kwargs): ) # ------------------------------------------------------------------------------ +_t_routers_done = time.perf_counter() - _t_routers +print(f"[STARTUP] Router initialization completed (+{_t_routers_done:.3f}s)") + +_t_mount_routers_done = time.perf_counter() - _t_mount_routers +print(f"[STARTUP] Router mounting completed (+{_t_mount_routers_done:.3f}s)") + if ee and is_ee(): app = ee.extend_app_schema(app) + +_total_startup = time.perf_counter() - _startup_t0 +print(f"[STARTUP] module initialization completed in {_total_startup:.3f}s") diff --git a/api/oss/databases/postgres/migrations/core/versions/a1b2c3d4e5f6_add_default_evaluation_queues.py b/api/oss/databases/postgres/migrations/core/versions/a1b2c3d4e5f6_add_default_evaluation_queues.py new file mode 100644 index 0000000000..a479416e2c --- /dev/null +++ b/api/oss/databases/postgres/migrations/core/versions/a1b2c3d4e5f6_add_default_evaluation_queues.py @@ -0,0 +1,27 @@ +"""add default evaluation queues + +Revision ID: a1b2c3d4e5f6 +Revises: e9f0a1b2c3d4 +Create Date: 2026-05-15 00:00:00 +""" + +from typing import Sequence, Union + +from alembic import op + +revision: str = "a1b2c3d4e5f6" +down_revision: Union[str, None] = "e9f0a1b2c3d4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute(""" + CREATE UNIQUE INDEX ux_evaluation_queues_default_per_run + ON evaluation_queues (project_id, run_id) + WHERE (flags ->> 'is_default')::boolean = true + """) + + +def downgrade() -> None: + op.execute("DROP INDEX IF EXISTS ux_evaluation_queues_default_per_run") diff --git a/api/oss/databases/postgres/migrations/core/versions/a2b3c4d5e6f8_backfill_default_evaluation_queues.py b/api/oss/databases/postgres/migrations/core/versions/a2b3c4d5e6f8_backfill_default_evaluation_queues.py new file mode 100644 index 0000000000..0ba52faf28 --- /dev/null +++ b/api/oss/databases/postgres/migrations/core/versions/a2b3c4d5e6f8_backfill_default_evaluation_queues.py @@ -0,0 +1,99 @@ +"""backfill default evaluation queues + +Revision ID: a2b3c4d5e6f8 +Revises: a1b2c3d4e5f6 +Create Date: 2026-05-15 00:10:00 +""" + +from typing import Sequence, Union + +from alembic import op + +revision: str = "a2b3c4d5e6f8" +down_revision: Union[str, None] = "a1b2c3d4e5f6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Backfill the newly inferred direct-source flags for existing runs. + op.execute(""" + UPDATE evaluation_runs + SET flags = COALESCE(flags, '{}'::jsonb) + || jsonb_build_object( + 'has_traces', EXISTS ( + SELECT 1 + FROM jsonb_array_elements(COALESCE(data -> 'steps', '[]'::jsonb)) AS step + WHERE step ->> 'type' = 'input' + AND COALESCE(step -> 'references', '{}'::jsonb) = '{}'::jsonb + AND lower(COALESCE(step ->> 'key', '')) IN ('traces', 'query-direct') + ), + 'has_testcases', EXISTS ( + SELECT 1 + FROM jsonb_array_elements(COALESCE(data -> 'steps', '[]'::jsonb)) AS step + WHERE step ->> 'type' = 'input' + AND COALESCE(step -> 'references', '{}'::jsonb) = '{}'::jsonb + AND lower(COALESCE(step ->> 'key', '')) IN ('testcases', 'testset-direct') + ) + ) + """) + + # Create the canonical open/default view for every existing run. Existing + # default queues, active or archived, are preserved and block duplicates. + op.execute(""" + INSERT INTO evaluation_queues ( + project_id, + id, + created_at, + created_by_id, + flags, + data, + status, + run_id + ) + SELECT + r.project_id, + gen_random_uuid(), + CURRENT_TIMESTAMP, + r.created_by_id, + jsonb_build_object('is_default', true, 'is_sequential', false), + '{}'::jsonb, + 'running', + r.id + FROM evaluation_runs r + WHERE NOT EXISTS ( + SELECT 1 + FROM evaluation_queues q + WHERE q.project_id = r.project_id + AND q.run_id = r.id + AND (q.flags ->> 'is_default')::boolean = true + ) + """) + + # Recompute simple-queue eligibility under the new meaning. An already + # existing active default queue is as valid as one inserted above. + op.execute(""" + UPDATE evaluation_runs r + SET flags = COALESCE(r.flags, '{}'::jsonb) + || jsonb_build_object( + 'is_queue', + COALESCE((r.flags ->> 'has_human')::boolean, false) + AND EXISTS ( + SELECT 1 + FROM evaluation_queues q + WHERE q.project_id = r.project_id + AND q.run_id = r.id + AND (q.flags ->> 'is_default')::boolean = true + AND q.deleted_at IS NULL + ) + ) + """) + + +def downgrade() -> None: + # Keep generated queues/results intact on downgrade. Remove only the newly + # inferred flags; old is_queue semantics cannot be reconstructed safely. + op.execute(""" + UPDATE evaluation_runs + SET flags = COALESCE(flags, '{}'::jsonb) - 'has_traces' - 'has_testcases' + """) diff --git a/api/oss/src/apis/fastapi/evaluations/router.py b/api/oss/src/apis/fastapi/evaluations/router.py index 8afa388275..0ce6aef435 100644 --- a/api/oss/src/apis/fastapi/evaluations/router.py +++ b/api/oss/src/apis/fastapi/evaluations/router.py @@ -253,6 +253,16 @@ def __init__( operation_id="open_run", ) + # GET /api/evaluations/runs/{run_id}/default-queue + self.router.add_api_route( + path="/runs/{run_id}/default-queue", + methods=["GET"], + endpoint=self.fetch_default_queue, + response_model=EvaluationQueueResponse, + response_model_exclude_none=True, + operation_id="fetch_default_queue", + ) + # EVALUATION SCENARIOS ------------------------------------------------- # POST /api/evaluations/scenarios/ @@ -521,6 +531,26 @@ def __init__( operation_id="delete_queue", ) + # POST /api/evaluations/queues/{queue_id}/archive + self.router.add_api_route( + path="/queues/{queue_id}/archive", + methods=["POST"], + endpoint=self.archive_queue, + response_model=EvaluationQueueResponse, + response_model_exclude_none=True, + operation_id="archive_queue", + ) + + # POST /api/evaluations/queues/{queue_id}/unarchive + self.router.add_api_route( + path="/queues/{queue_id}/unarchive", + methods=["POST"], + endpoint=self.unarchive_queue, + response_model=EvaluationQueueResponse, + response_model_exclude_none=True, + operation_id="unarchive_queue", + ) + # POST /api/evaluations/queues/{queue_id}/scenarios/query self.router.add_api_route( path="/queues/{queue_id}/scenarios/query", @@ -780,6 +810,29 @@ async def fetch_run( return run_response + # GET /evaluations/runs/{run_id}/default-queue + @intercept_exceptions() + @suppress_exceptions(default=EvaluationQueueResponse(), exclude=[HTTPException]) + async def fetch_default_queue( + self, + request: Request, + *, + run_id: UUID, + ) -> EvaluationQueueResponse: + if is_ee(): + if not await check_action_access( # type: ignore + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.VIEW_EVALUATION_QUEUES, # type: ignore + ): + raise FORBIDDEN_EXCEPTION # type: ignore + + queue = await self.evaluations_service.fetch_default_queue( + project_id=UUID(request.state.project_id), + run_id=run_id, + ) + return EvaluationQueueResponse(count=1 if queue else 0, queue=queue) + # PATCH /evaluations/runs/{run_id} @intercept_exceptions() async def edit_run( @@ -1701,6 +1754,54 @@ async def edit_queue( return queue_response + # POST /evaluations/queues/{queue_id}/archive + @intercept_exceptions() + @handle_evaluation_closed_exception() + async def archive_queue( + self, + request: Request, + *, + queue_id: UUID, + ) -> EvaluationQueueResponse: + if is_ee(): + if not await check_action_access( # type: ignore + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.EDIT_EVALUATION_QUEUES, # type: ignore + ): + raise FORBIDDEN_EXCEPTION # type: ignore + + queue = await self.evaluations_service.archive_queue( + project_id=UUID(request.state.project_id), + user_id=UUID(request.state.user_id), + queue_id=queue_id, + ) + return EvaluationQueueResponse(count=1 if queue else 0, queue=queue) + + # POST /evaluations/queues/{queue_id}/unarchive + @intercept_exceptions() + @handle_evaluation_closed_exception() + async def unarchive_queue( + self, + request: Request, + *, + queue_id: UUID, + ) -> EvaluationQueueResponse: + if is_ee(): + if not await check_action_access( # type: ignore + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.EDIT_EVALUATION_QUEUES, # type: ignore + ): + raise FORBIDDEN_EXCEPTION # type: ignore + + queue = await self.evaluations_service.unarchive_queue( + project_id=UUID(request.state.project_id), + user_id=UUID(request.state.user_id), + queue_id=queue_id, + ) + return EvaluationQueueResponse(count=1 if queue else 0, queue=queue) + # DELETE /evaluations/queues/{queue_id} @intercept_exceptions() @handle_evaluation_closed_exception() diff --git a/api/oss/src/core/auth/helper.py b/api/oss/src/core/auth/helper.py index 503323370f..a57b0f6949 100644 --- a/api/oss/src/core/auth/helper.py +++ b/api/oss/src/core/auth/helper.py @@ -1,12 +1,11 @@ from dataclasses import dataclass from typing import Any, Optional, Set -import posthog - from oss.src.services.exceptions import UnauthorizedException from oss.src.utils.caching import get_cache, set_cache from oss.src.utils.common import is_ee from oss.src.utils.env import env +from oss.src.utils.lazy import _load_posthog from oss.src.utils.logging import get_module_logger @@ -61,10 +60,27 @@ async def _get_posthog_string_entries(feature_flag: str) -> Set[str]: if cached_entries is not None: return _normalize_string_set(cached_entries) - flag_entries = posthog.get_feature_flag_payload( - feature_flag, - "user distinct id", - ) + posthog = _load_posthog() + if posthog is None: + log.warning( + "[AUTH] PostHog feature flag lookup skipped", + feature_flag=feature_flag, + reason="unavailable", + ) + return set() + + try: + flag_entries = posthog.get_feature_flag_payload( + feature_flag, + "user distinct id", + ) + except Exception as exc: + log.warning( + "[AUTH] PostHog feature flag lookup skipped", + feature_flag=feature_flag, + reason=str(exc), + ) + return set() normalized_entries = _normalize_string_set(flag_entries) diff --git a/api/oss/src/core/auth/service.py b/api/oss/src/core/auth/service.py index 8fdf80409b..6ed4027b89 100644 --- a/api/oss/src/core/auth/service.py +++ b/api/oss/src/core/auth/service.py @@ -11,7 +11,7 @@ from oss.src.models.db_models import InvitationDB, ProjectDB, OrganizationDB from oss.src.services import db_manager -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import get_transactions_engine from oss.src.dbs.postgres.users.dao import IdentitiesDAO if is_ee(): @@ -130,7 +130,9 @@ async def discover(self, email: str) -> Dict[str, Any]: # 2. Organizations with pending project invitations if email: try: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: # Query project_invitations for this email, join with projects to get organization_id stmt = ( select(ProjectDB.organization_id) @@ -480,7 +482,9 @@ async def enforce_domain_policies(self, email: str, user_id: UUID) -> None: "Auto-join requires organization, user, and at least one workspace" ) - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: existing_org_member = await session.execute( select(OrganizationMemberDB).filter_by( user_id=user.id, organization_id=organization.id @@ -791,7 +795,9 @@ async def _get_organization_flags( if not is_ee(): return None - async with db_manager.engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: stmt = select(OrganizationDB.flags).where( OrganizationDB.id == organization_id ) @@ -803,7 +809,9 @@ async def _get_organization_slug(self, organization_id: UUID) -> Optional[str]: if not is_ee(): return None - async with db_manager.engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: stmt = select(OrganizationDB.slug).where( OrganizationDB.id == organization_id ) @@ -820,7 +828,9 @@ async def _is_organization_member( if not is_ee(): return False - async with db_manager.engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: stmt = select(OrganizationMemberDB).where( OrganizationMemberDB.user_id == user_id, OrganizationMemberDB.organization_id == organization_id, @@ -837,7 +847,9 @@ async def _is_organization_owner( if not is_ee(): return False - async with db_manager.engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: stmt = select(OrganizationMemberDB.role).where( OrganizationMemberDB.user_id == user_id, OrganizationMemberDB.organization_id == organization_id, diff --git a/api/oss/src/core/auth/supertokens/overrides.py b/api/oss/src/core/auth/supertokens/overrides.py index 0fc3ab1bda..38704b5f98 100644 --- a/api/oss/src/core/auth/supertokens/overrides.py +++ b/api/oss/src/core/auth/supertokens/overrides.py @@ -1,9 +1,8 @@ from typing import Dict, Any, List, Optional, Union from urllib.parse import urlparse -import posthog - from oss.src.utils.logging import get_module_logger +from oss.src.utils.lazy import _load_posthog from supertokens_python.recipe.thirdparty.provider import ( ProviderInput, @@ -252,6 +251,9 @@ async def _create_account(email: str, uid: str) -> bool: if env.posthog.enabled and env.posthog.api_key: try: + posthog = _load_posthog() + if posthog is None: + return True posthog.capture( distinct_id=auth_info.email, event="user_signed_up_v1", @@ -263,8 +265,11 @@ async def _create_account(email: str, uid: str) -> bool: "$set": {"email": auth_info.email}, }, ) - except Exception: - log.error("[AUTH] Failed to capture PostHog signup event", exc_info=True) + except Exception as exc: + log.warning( + "[AUTH] PostHog signup event capture skipped", + reason=str(exc), + ) log.info("[AUTH] _create_account done", email=auth_info.email, uid=uid) return True diff --git a/api/oss/src/core/evaluations/interfaces.py b/api/oss/src/core/evaluations/interfaces.py index 5858682f64..23a2ec7156 100644 --- a/api/oss/src/core/evaluations/interfaces.py +++ b/api/oss/src/core/evaluations/interfaces.py @@ -514,6 +514,26 @@ async def edit_queues( ) -> List[EvaluationQueue]: raise NotImplementedError + @abstractmethod + async def archive_queue( + self, + *, + project_id: UUID, + user_id: UUID, + queue_id: UUID, + ) -> Optional[EvaluationQueue]: + raise NotImplementedError + + @abstractmethod + async def unarchive_queue( + self, + *, + project_id: UUID, + user_id: UUID, + queue_id: UUID, + ) -> Optional[EvaluationQueue]: + raise NotImplementedError + @abstractmethod async def delete_queue( self, diff --git a/api/oss/src/core/evaluations/runtime/adapters.py b/api/oss/src/core/evaluations/runtime/adapters.py new file mode 100644 index 0000000000..85e12df4e0 --- /dev/null +++ b/api/oss/src/core/evaluations/runtime/adapters.py @@ -0,0 +1,436 @@ +from asyncio import Semaphore, gather +from typing import Any, Callable, Dict, List, Optional +from uuid import UUID + +from agenta.sdk.evaluations.runtime.models import ( + ResultLogRequest, + WorkflowExecutionRequest, + WorkflowExecutionResult, +) +from agenta.sdk.models.evaluations import EvaluationStatus as SdkEvaluationStatus + +from oss.src.core.evaluations.runtime.cache import RunnableCacheResolver +from oss.src.core.evaluations.types import ( + EvaluationMetricsRefresh, + EvaluationResultCreate, + EvaluationScenarioCreate, + EvaluationStatus, +) +from oss.src.core.evaluations.utils import fetch_trace +from oss.src.core.workflows.dtos import ( + WorkflowServiceRequest, + WorkflowServiceRequestData, +) + + +def _status(status: Any) -> EvaluationStatus: + value = getattr(status, "value", status) + return EvaluationStatus(value) + + +def _read_field(source: Any, field: str) -> Any: + if isinstance(source, dict): + return source.get(field) + return getattr(source, field, None) + + +def _dump_model(source: Any, **kwargs: Any) -> Any: + if hasattr(source, "model_dump"): + return source.model_dump(**kwargs) + return source + + +def _dump_json(source: Any) -> Any: + if hasattr(source, "model_dump"): + return source.model_dump(mode="json", exclude_none=True) + if isinstance(source, dict): + return {key: _dump_json(value) for key, value in source.items()} + if isinstance(source, list): + return [_dump_json(value) for value in source] + return source + + +class BackendWorkflowServiceRunner: + """API adapter from SDK runtime requests to the backend workflow service.""" + + def __init__( + self, + *, + workflows_service: Any, + request_builder: Optional[ + Callable[[WorkflowExecutionRequest], Dict[str, Any]] + ] = None, + ): + self.workflows_service = workflows_service + self.request_builder = request_builder + + async def execute( + self, + request: WorkflowExecutionRequest, + ) -> WorkflowExecutionResult: + kwargs = ( + self.request_builder(request) + if self.request_builder + else request.model_dump(mode="python", exclude_none=True) + ) + response = await self.workflows_service.invoke_workflow(**kwargs) + status = getattr(response, "status", None) + status_code = getattr(status, "code", None) + has_error = status_code != 200 + error = None + + if has_error: + error = ( + status.model_dump(mode="json", exclude_none=True) + if hasattr(status, "model_dump") + else {"code": status_code} + ) + + return WorkflowExecutionResult( + status=( + SdkEvaluationStatus.FAILURE + if has_error + else SdkEvaluationStatus.SUCCESS + ), + trace_id=getattr(response, "trace_id", None), + span_id=getattr(response, "span_id", None), + error=error, + outputs=getattr(response, "outputs", None), + ) + + +class BackendScenarioFactory: + def __init__( + self, + *, + project_id: UUID, + user_id: UUID, + timestamp: Any, + interval: Optional[int], + evaluations_service: Any, + ): + self.project_id = project_id + self.user_id = user_id + self.timestamp = timestamp + self.interval = interval + self.evaluations_service = evaluations_service + + async def __call__(self, run_id: UUID) -> Any: + scenarios = await self.evaluations_service.create_scenarios( + project_id=self.project_id, + user_id=self.user_id, + scenarios=[ + EvaluationScenarioCreate( + run_id=run_id, + timestamp=self.timestamp, + interval=self.interval, + status=EvaluationStatus.RUNNING, + ) + ], + ) + if not scenarios: + raise ValueError(f"Failed to create scenario for run {run_id}") + return scenarios[0] + + +class BackendResultLogger: + def __init__( + self, + *, + project_id: UUID, + user_id: UUID, + timestamp: Any, + interval: Optional[int], + evaluations_service: Any, + ): + self.project_id = project_id + self.user_id = user_id + self.timestamp = timestamp + self.interval = interval + self.evaluations_service = evaluations_service + + async def log(self, request: ResultLogRequest) -> Any: + cell = request.cell + results = await self.evaluations_service.create_results( + project_id=self.project_id, + user_id=self.user_id, + results=[ + EvaluationResultCreate( + run_id=cell.run_id, + scenario_id=cell.scenario_id, + step_key=cell.step_key, + repeat_idx=cell.repeat_idx, + status=_status(cell.status), + trace_id=( + request.trace_id + if request.trace_id is not None + else cell.trace_id + ), + testcase_id=( + request.testcase_id + if request.testcase_id is not None + else cell.testcase_id + ), + error=request.error if request.error is not None else cell.error, + timestamp=self.timestamp, + interval=self.interval, + ) + ], + ) + return results[0] if results else None + + +class BackendMetricsRefresher: + def __init__( + self, + *, + project_id: UUID, + user_id: UUID, + timestamp: Any, + interval: Optional[int], + evaluations_service: Any, + ): + self.project_id = project_id + self.user_id = user_id + self.timestamp = timestamp + self.interval = interval + self.evaluations_service = evaluations_service + + async def __call__( + self, + run_id: UUID, + scenario_id: Optional[UUID], + ) -> Any: + return await self.evaluations_service.refresh_metrics( + project_id=self.project_id, + user_id=self.user_id, + metrics=EvaluationMetricsRefresh( + run_id=run_id, + scenario_id=scenario_id, + timestamp=self.timestamp, + interval=self.interval, + ), + ) + + +class BackendTraceLoader: + def __init__( + self, + *, + project_id: UUID, + tracing_service: Any, + ): + self.project_id = project_id + self.tracing_service = tracing_service + + async def load(self, trace_id: str) -> Any: + return await fetch_trace( + tracing_service=self.tracing_service, + project_id=self.project_id, + trace_id=trace_id, + ) + + +class BackendWorkflowRunner: + def __init__( + self, + *, + project_id: UUID, + user_id: UUID, + workflows_service: Any, + ): + self.project_id = project_id + self.user_id = user_id + self.workflows_service = workflows_service + + async def execute( + self, + request: WorkflowExecutionRequest, + ) -> WorkflowExecutionResult: + return (await self.execute_batch([request]))[0] + + async def execute_batch( + self, + requests: List[WorkflowExecutionRequest], + semaphore: Optional[Semaphore] = None, + ) -> List[WorkflowExecutionResult]: + async def _guarded( + request: WorkflowExecutionRequest, + ) -> WorkflowExecutionResult: + if semaphore is not None: + async with semaphore: + return await self._execute_one(request) + return await self._execute_one(request) + + return list(await gather(*(_guarded(r) for r in requests))) + + async def _execute_one( + self, + request: WorkflowExecutionRequest, + ) -> WorkflowExecutionResult: + revision = request.revision + data = _read_field(revision, "data") + if isinstance(revision, dict): + revision_dump = revision + elif hasattr(revision, "model_dump"): + revision_dump = revision.model_dump(mode="json", exclude_none=True) + else: + revision_dump = revision + + interface = ( + { + "uri": _read_field(data, "uri"), + "url": _read_field(data, "url"), + "headers": _read_field(data, "headers"), + "schemas": _read_field(data, "schemas"), + } + if data + else {} + ) + configuration = ( + { + "script": _read_field(data, "script"), + "parameters": _read_field(data, "parameters"), + } + if data + else {} + ) + flags = _read_field(revision, "flags") + flags = ( + _dump_model( + flags, + mode="json", + exclude_none=True, + exclude_unset=True, + ) + if flags + else None + ) + response = await self.workflows_service.invoke_workflow( + project_id=self.project_id, + user_id=self.user_id, + request=WorkflowServiceRequest( + version="2025.07.14", + flags=flags, + interface=interface, + configuration=configuration, + data=WorkflowServiceRequestData( + revision=revision_dump, + parameters=configuration.get("parameters"), + testcase=( + request.source.testcase.model_dump( + mode="json", + exclude_none=True, + ) + if hasattr(request.source.testcase, "model_dump") + else request.source.testcase + ), + inputs=request.source.inputs, + trace=( + request.upstream_trace.model_dump( + mode="json", + exclude_none=True, + ) + if hasattr(request.upstream_trace, "model_dump") + else request.upstream_trace + ), + outputs=request.upstream_outputs or request.source.outputs, + ), + references=_dump_json(request.references), + links=request.links or {}, + ), + ) + status = getattr(response, "status", None) + status_code = getattr(status, "code", None) + has_error = status_code != 200 + return WorkflowExecutionResult( + status=( + SdkEvaluationStatus.FAILURE + if has_error + else SdkEvaluationStatus.SUCCESS + ), + trace_id=getattr(response, "trace_id", None), + span_id=getattr(response, "span_id", None), + error=( + status.model_dump(mode="json", exclude_none=True) + if has_error and hasattr(status, "model_dump") + else {"code": status_code} + if has_error + else None + ), + outputs=getattr(response, "outputs", None), + ) + + +class BackendEvaluatorRunner(BackendWorkflowRunner): + def __init__( + self, + *, + project_id: UUID, + user_id: UUID, + workflows_service: Any, + ): + super().__init__( + project_id=project_id, + user_id=user_id, + workflows_service=workflows_service, + ) + + +class BackendCachedRunner: + def __init__( + self, + *, + runner: Any, + tracing_service: Any, + project_id: UUID, + enabled: bool, + ): + self.runner = runner + self.tracing_service = tracing_service + self.project_id = project_id + self.enabled = enabled + self.cache_resolver = RunnableCacheResolver() + + async def execute( + self, + request: WorkflowExecutionRequest, + ) -> WorkflowExecutionResult: + return (await self.execute_batch([request]))[0] + + async def execute_batch( + self, + requests: List[WorkflowExecutionRequest], + semaphore: Optional[Semaphore] = None, + ) -> List[WorkflowExecutionResult]: + results: List[Optional[WorkflowExecutionResult]] = [None] * len(requests) + missing: List[WorkflowExecutionRequest] = [] + missing_positions: List[int] = [] + + for idx, request in enumerate(requests): + cache = await self.cache_resolver.resolve( + tracing_service=self.tracing_service, + project_id=self.project_id, + enabled=self.enabled and self.tracing_service is not None, + references=request.references, + links=request.links, + required_count=1, + ) + reusable = cache.reusable_traces[0] if cache.reusable_traces else None + if reusable and getattr(reusable, "trace_id", None): + results[idx] = WorkflowExecutionResult( + status=SdkEvaluationStatus.SUCCESS, + trace_id=str(reusable.trace_id), + trace=reusable, + ) + continue + + missing.append(request) + missing_positions.append(idx) + + if missing: + executed = await self.runner.execute_batch(missing, semaphore=semaphore) + for idx, execution in zip(missing_positions, executed): + results[idx] = execution + + return [result for result in results if result is not None] diff --git a/api/oss/src/core/evaluations/runtime/cache.py b/api/oss/src/core/evaluations/runtime/cache.py new file mode 100644 index 0000000000..805e4564f5 --- /dev/null +++ b/api/oss/src/core/evaluations/runtime/cache.py @@ -0,0 +1,60 @@ +from typing import Any, Dict, List, Optional +from uuid import UUID + +from pydantic import BaseModel, ConfigDict + +from oss.src.core.evaluations.utils import ( + fetch_traces_by_hash, + make_hash, + plan_missing_traces, + select_traces_for_reuse, +) + + +class CacheResolution(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + hash_id: Optional[str] + reusable_traces: List[Any] + missing_count: int + + +class RunnableCacheResolver: + async def resolve( + self, + *, + tracing_service: Any, + project_id: UUID, + enabled: bool, + references: Optional[Dict[str, Any]] = None, + links: Optional[Dict[str, Any]] = None, + required_count: int = 1, + ) -> CacheResolution: + hash_id = make_hash(references=references, links=links) + + if not enabled or not hash_id or required_count <= 0: + return CacheResolution( + hash_id=hash_id, + reusable_traces=[], + missing_count=max(0, required_count), + ) + + cached_traces = await fetch_traces_by_hash( + tracing_service, + project_id, + hash_id=hash_id, + limit=required_count, + ) + reusable_traces = select_traces_for_reuse( + traces=cached_traces, + required_count=required_count, + ) + + return CacheResolution( + hash_id=hash_id, + reusable_traces=reusable_traces, + missing_count=plan_missing_traces( + required_count=required_count, + reusable_count=len(reusable_traces), + ), + ) diff --git a/api/oss/src/core/evaluations/runtime/executor.py b/api/oss/src/core/evaluations/runtime/executor.py new file mode 100644 index 0000000000..8db00feb4d --- /dev/null +++ b/api/oss/src/core/evaluations/runtime/executor.py @@ -0,0 +1,62 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict + +from oss.src.core.evaluations.types import EvaluationStatus + + +class StepExecutionResult(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + status: EvaluationStatus + trace_id: Optional[str] = None + span_id: Optional[str] = None + hash_id: Optional[str] = None + error: Optional[Dict[str, Any]] = None + outputs: Optional[Any] = None + + +class RunnableStepExecutor: + """Backend compatibility shell for runnable execution adapters. + + This public worker-facing class is kept for now. New orchestration should + target the SDK runtime WorkflowRunner protocol and keep backend workflow + service details in this module. + """ + + async def execute(self, **kwargs: Any) -> StepExecutionResult: + raise NotImplementedError + + +class WorkflowRunnableStepExecutor(RunnableStepExecutor): + def __init__(self, *, workflows_service: Any): + self.workflows_service = workflows_service + + async def execute(self, **kwargs: Any) -> StepExecutionResult: + response = await self.workflows_service.invoke_workflow(**kwargs) + status = getattr(response, "status", None) + status_code = getattr(status, "code", None) + has_error = status_code != 200 + error = None + + if has_error: + error = ( + status.model_dump(mode="json", exclude_none=True) + if hasattr(status, "model_dump") + else {"code": status_code} + ) + + return StepExecutionResult( + status=EvaluationStatus.FAILURE if has_error else EvaluationStatus.SUCCESS, + trace_id=getattr(response, "trace_id", None), + error=error, + outputs=getattr(response, "outputs", None), + ) + + +class ApplicationBatchRunnableStepExecutor(RunnableStepExecutor): + def __init__(self, *, batch_invoke: Any): + self.batch_invoke = batch_invoke + + async def execute_batch(self, **kwargs: Any) -> List[Any]: + return await self.batch_invoke(**kwargs) diff --git a/api/oss/src/core/evaluations/runtime/locks.py b/api/oss/src/core/evaluations/runtime/locks.py index 33cff2b4f0..f5e85b9e7e 100644 --- a/api/oss/src/core/evaluations/runtime/locks.py +++ b/api/oss/src/core/evaluations/runtime/locks.py @@ -121,7 +121,7 @@ async def _write_meta( payload: LockPayload, ttl: int, ) -> None: - await caching.r_lock.set( + await caching._cache_engine.get_r_lock().set( _actual_meta_name(lock_key), orjson.dumps(payload.model_dump(mode="json")), ex=ttl, @@ -134,7 +134,8 @@ async def _touch_meta( ttl: int, ) -> None: meta_key = _actual_meta_name(lock_key) - raw = await caching.r_lock.get(meta_key) + r_lock = caching._cache_engine.get_r_lock() + raw = await r_lock.get(meta_key) if not raw: return @@ -145,7 +146,7 @@ async def _touch_meta( return payload.updated_at = _now_iso() - await caching.r_lock.set( + await r_lock.set( meta_key, orjson.dumps(payload.model_dump(mode="json")), ex=ttl, @@ -157,11 +158,12 @@ async def _read_meta_if_lock_exists( lock_key: str, ) -> Optional[LockPayload]: actual_lock_key = _actual_lock_name(lock_key) - if not await caching.r_lock.exists(actual_lock_key): - await caching.r_lock.delete(_actual_meta_name(lock_key)) + r_lock = caching._cache_engine.get_r_lock() + if not await r_lock.exists(actual_lock_key): + await r_lock.delete(_actual_meta_name(lock_key)) return None - raw = await caching.r_lock.get(_actual_meta_name(lock_key)) + raw = await r_lock.get(_actual_meta_name(lock_key)) if not raw: return None @@ -263,7 +265,7 @@ async def _release_lock( return False try: - await caching.r_lock.delete(_actual_meta_name(lock_key)) + await caching._cache_engine.get_r_lock().delete(_actual_meta_name(lock_key)) except Exception: log.warning( "[LOCK] Released lock but failed to delete metadata", @@ -367,7 +369,8 @@ async def list_active_job_locks( Wildcard discovery must use SCAN, never KEYS. """ payloads: list[LockPayload] = [] - async for raw_lock_key in caching.r_lock.scan_iter( + r_lock = caching._cache_engine.get_r_lock() + async for raw_lock_key in r_lock.scan_iter( match=_actual_lock_name(job_lock_pattern(run_id)) ): meta_key = ( @@ -375,7 +378,7 @@ async def list_active_job_locks( if isinstance(raw_lock_key, bytes) else f"{raw_lock_key}:meta" ) - raw_payload = await caching.r_lock.get(meta_key) + raw_payload = await r_lock.get(meta_key) if not raw_payload: continue @@ -403,9 +406,8 @@ async def is_run_executing( *, run_id: str, ) -> bool: - async for _ in caching.r_lock.scan_iter( - match=_actual_lock_name(job_lock_pattern(run_id)) - ): + r_lock = caching._cache_engine.get_r_lock() + async for _ in r_lock.scan_iter(match=_actual_lock_name(job_lock_pattern(run_id))): return True return False @@ -414,7 +416,8 @@ async def has_mutation_lock( *, run_id: str, ) -> bool: - return bool(await caching.r_lock.exists(_actual_lock_name(run_lock_key(run_id)))) + r_lock = caching._cache_engine.get_r_lock() + return bool(await r_lock.exists(_actual_lock_name(run_lock_key(run_id)))) async def refresh_worker_heartbeat( @@ -424,7 +427,8 @@ async def refresh_worker_heartbeat( ) -> WorkerHeartbeatPayload: now = _now_iso() hb_key = _actual_lock_name(worker_heartbeat_key(worker_id)) - raw = await caching.r_lock.get(hb_key) + r_lock = caching._cache_engine.get_r_lock() + raw = await r_lock.get(hb_key) created_at = now if raw: @@ -442,7 +446,7 @@ async def refresh_worker_heartbeat( created_at=created_at, updated_at=now, ) - await caching.r_lock.set( + await r_lock.set( hb_key, orjson.dumps(payload.model_dump(mode="json")), ex=ttl, diff --git a/api/oss/src/core/evaluations/runtime/models.py b/api/oss/src/core/evaluations/runtime/models.py new file mode 100644 index 0000000000..fd6e8520f8 --- /dev/null +++ b/api/oss/src/core/evaluations/runtime/models.py @@ -0,0 +1,125 @@ +from typing import Any, Dict, List, Literal, Optional +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + +from oss.src.core.evaluations.types import EvaluationStatus, Origin, Type + +InputSourceKind = Literal["query", "testset", "trace", "testcase", "direct"] +SourceBatchKind = Literal["traces", "testcases"] +TopologyStatus = Literal["supported", "potential", "not_planned", "unsupported"] +DispatchKind = Literal[ + "batch_query", + "batch_testset", + "batch_invocation", + "queue_traces", + "queue_testcases", + "live_query", +] + + +class RuntimeModel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class InputSourceSpec(RuntimeModel): + kind: InputSourceKind + step_key: str + references: Dict[str, Any] = Field(default_factory=dict) + + +class ResolvedSourceItem(RuntimeModel): + kind: InputSourceKind + step_key: str + references: Dict[str, Any] = Field(default_factory=dict) + trace_id: Optional[str] = None + span_id: Optional[str] = None + testcase_id: Optional[UUID] = None + testcase: Optional[Any] = None + trace: Optional[Any] = None + inputs: Optional[Any] = None + outputs: Optional[Any] = None + + +class ResolvedSourceBatch(RuntimeModel): + kind: SourceBatchKind + step_key: str + trace_ids: List[str] = Field(default_factory=list) + testcase_ids: List[UUID] = Field(default_factory=list) + + +class ResolvedTestsetInputSpec(RuntimeModel): + step_key: str + testset: Any + testset_revision: Any + testcases: List[Any] = Field(default_factory=list) + testcases_data: List[Dict[str, Any]] = Field(default_factory=list) + + +class ScenarioBinding(RuntimeModel): + scenario_id: UUID + source: ResolvedSourceItem + interval: Optional[int] = None + timestamp: Optional[Any] = None + + +class EvaluationStep(RuntimeModel): + key: str + type: Type + origin: Origin + references: Dict[str, Any] = Field(default_factory=dict) + inputs: List[str] = Field(default_factory=list) + + +class TensorSlice(RuntimeModel): + run_id: UUID + scenario_ids: Optional[List[UUID]] = None + step_keys: Optional[List[str]] = None + repeat_idxs: Optional[List[int]] = None + + +class TensorProbeSummary(RuntimeModel): + existing_count: int = 0 + missing_count: int = 0 + success_count: int = 0 + failure_count: int = 0 + pending_count: int = 0 + any_count: int = 0 + + +class PlannedCell(RuntimeModel): + run_id: UUID + scenario_id: UUID + step_key: str + step_type: Type + origin: Origin + repeat_idx: int + status: EvaluationStatus + should_execute: bool = False + trace_id: Optional[str] = None + span_id: Optional[str] = None + testcase_id: Optional[UUID] = None + error: Optional[Dict[str, Any]] = None + + +class ExecutionPlan(RuntimeModel): + run_id: UUID + cells: List[PlannedCell] + + @property + def executable_cells(self) -> List[PlannedCell]: + return [cell for cell in self.cells if cell.should_execute] + + +class ProcessSummary(RuntimeModel): + created: int = 0 + reused: int = 0 + pending: int = 0 + failed: int = 0 + + +class TopologyDecision(RuntimeModel): + status: TopologyStatus + label: str + reason: str + dispatch: Optional[DispatchKind] = None diff --git a/api/oss/src/core/evaluations/runtime/planner.py b/api/oss/src/core/evaluations/runtime/planner.py new file mode 100644 index 0000000000..46de7bd0ca --- /dev/null +++ b/api/oss/src/core/evaluations/runtime/planner.py @@ -0,0 +1,146 @@ +from typing import Dict, Iterable, List, Optional +from uuid import UUID + +from oss.src.core.evaluations.runtime.models import ( + EvaluationStep, + ExecutionPlan, + PlannedCell, + ResolvedSourceItem, + ScenarioBinding, +) +from oss.src.core.evaluations.types import ( + EvaluationResultCreate, + EvaluationRun, + EvaluationRunDataStep, + EvaluationStatus, +) +from agenta.sdk.evaluations.runtime.planner import ( + EvaluationPlanner as SdkEvaluationPlanner, +) + + +def _step_inputs(step: EvaluationRunDataStep) -> List[str]: + return [step_input.key for step_input in (step.inputs or []) if step_input.key] + + +def normalize_steps( + steps: Optional[Iterable[EvaluationRunDataStep]], +) -> List[EvaluationStep]: + return [ + EvaluationStep( + key=step.key, + type=step.type, + origin=step.origin, + references=step.references or {}, + inputs=_step_inputs(step), + ) + for step in (steps or []) + ] + + +def make_scenario_bindings( + *, + scenario_ids: List[UUID], + source_items: List[ResolvedSourceItem], +) -> List[ScenarioBinding]: + if len(scenario_ids) != len(source_items): + raise ValueError("scenario_ids and source_items must have the same length") + + return [ + ScenarioBinding(scenario_id=scenario_id, source=source_item) + for scenario_id, source_item in zip(scenario_ids, source_items) + ] + + +class EvaluationPlanner: + """Backend DTO adapter around the SDK-owned runtime planner.""" + + def plan( + self, + *, + run: EvaluationRun, + bindings: List[ScenarioBinding], + ) -> ExecutionPlan: + if not run.id: + raise ValueError("run.id is required") + + steps = normalize_steps(run.data.steps if run.data else None) + flags = run.flags + + sdk_plan = SdkEvaluationPlanner().plan_bindings( + run_id=run.id, + bindings=bindings, # type: ignore[arg-type] + steps=steps, # type: ignore[arg-type] + repeats=run.data.repeats if run.data else None, + is_split=bool(flags and flags.is_split), + is_live=bool(flags and flags.is_live), + has_traces=bool(flags and flags.has_traces), + has_testcases=bool(flags and flags.has_testcases), + ) + + return ExecutionPlan( + run_id=sdk_plan.run_id, + cells=[ + PlannedCell( + run_id=cell.run_id, + scenario_id=cell.scenario_id, + step_key=cell.step_key, + step_type=cell.step_type, + origin=cell.origin, + repeat_idx=cell.repeat_idx, + status=EvaluationStatus(cell.status.value), + should_execute=cell.should_execute, + trace_id=cell.trace_id, + span_id=cell.span_id, + testcase_id=cell.testcase_id, + error=cell.error, + ) + for cell in sdk_plan.cells + ], + ) + + +def index_cells_by_slot( + plan: ExecutionPlan, +) -> Dict[tuple[UUID, str, int], PlannedCell]: + return { + (cell.scenario_id, cell.step_key, cell.repeat_idx): cell for cell in plan.cells + } + + +def planned_cells_to_result_creates( + cells: Iterable[PlannedCell], +) -> List[EvaluationResultCreate]: + return [ + EvaluationResultCreate( + run_id=cell.run_id, + scenario_id=cell.scenario_id, + step_key=cell.step_key, + repeat_idx=cell.repeat_idx, + status=cell.status, + trace_id=cell.trace_id, + testcase_id=cell.testcase_id, + error=cell.error, + ) + for cell in cells + ] + + +def plan_source_input_result_creates( + *, + run: EvaluationRun, + scenario_id: UUID, + source_item: ResolvedSourceItem, +) -> List[EvaluationResultCreate]: + plan = EvaluationPlanner().plan( + run=run, + bindings=make_scenario_bindings( + scenario_ids=[scenario_id], + source_items=[source_item], + ), + ) + return planned_cells_to_result_creates( + cell + for cell in plan.cells + if cell.step_type == "input" and cell.step_key == source_item.step_key + ) diff --git a/api/oss/src/core/evaluations/runtime/sources.py b/api/oss/src/core/evaluations/runtime/sources.py new file mode 100644 index 0000000000..e1247945d2 --- /dev/null +++ b/api/oss/src/core/evaluations/runtime/sources.py @@ -0,0 +1,448 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional +from uuid import UUID + +from oss.src.core.evaluations.runtime.models import ( + ResolvedSourceBatch, + ResolvedSourceItem, + ResolvedTestsetInputSpec, +) +from oss.src.core.evaluations.types import EvaluationRun, EvaluationRunDataStep +from oss.src.core.evaluations.utils import fetch_trace +from oss.src.core.shared.dtos import Reference +from oss.src.core.tracing.dtos import ( + Filtering, + Windowing, + Formatting, + Format, + Focus, + TracingQuery, + LogicalOperator, +) + + +def _extract_root_span(trace: Any) -> Optional[Any]: + spans = ( + trace.get("spans") if isinstance(trace, dict) else getattr(trace, "spans", None) + ) + if not isinstance(spans, dict) or not spans: + return None + + for span in spans.values(): + if isinstance(span, list): + continue + if _extract_span_id(span): + return span + + return None + + +def _extract_span_id(span: Any) -> Optional[str]: + span_id = ( + span.get("span_id") + if isinstance(span, dict) + else getattr(span, "span_id", None) + ) + return str(span_id) if span_id else None + + +def _extract_ag_data(trace: Any) -> Dict[str, Any]: + root_span = _extract_root_span(trace) + if root_span is None: + return {} + + attributes = ( + root_span.get("attributes", {}) + if isinstance(root_span, dict) + else getattr(root_span, "attributes", {}) + ) + if hasattr(attributes, "model_dump"): + attributes = attributes.model_dump(mode="json", exclude_none=True) + if not isinstance(attributes, dict): + return {} + + ag = attributes.get("ag") or {} + data = ag.get("data") if isinstance(ag, dict) else {} + return data if isinstance(data, dict) else {} + + +class SourceResolver: + async def resolve( + self, + *, + project_id: UUID, + step: EvaluationRunDataStep, + ) -> Optional[ResolvedSourceBatch]: + raise NotImplementedError + + +class QueryRevisionTraceResolver(SourceResolver): + def __init__(self, *, queries_service: Any): + self.queries_service = queries_service + + async def resolve( + self, + *, + project_id: UUID, + step: EvaluationRunDataStep, + ) -> Optional[ResolvedSourceBatch]: + refs = step.references or {} + query_revision_ref = refs.get("query_revision") + + if not step.key or not query_revision_ref or not query_revision_ref.id: + return None + + query_revision = await self.queries_service.fetch_query_revision( + project_id=project_id, + query_revision_ref=query_revision_ref, + include_trace_ids=True, + ) + trace_ids = ( + query_revision.data.trace_ids + if query_revision and query_revision.data and query_revision.data.trace_ids + else [] + ) + + if not trace_ids: + return None + + return ResolvedSourceBatch( + kind="traces", + step_key=step.key, + trace_ids=trace_ids, + ) + + +class TestsetRevisionTestcaseResolver(SourceResolver): + def __init__(self, *, testsets_service: Any): + self.testsets_service = testsets_service + + async def resolve( + self, + *, + project_id: UUID, + step: EvaluationRunDataStep, + ) -> Optional[ResolvedSourceBatch]: + refs = step.references or {} + testset_revision_ref = refs.get("testset_revision") + + if not step.key or not testset_revision_ref or not testset_revision_ref.id: + return None + + testset_revision = await self.testsets_service.fetch_testset_revision( + project_id=project_id, + testset_revision_ref=testset_revision_ref, + include_testcase_ids=True, + ) + testcase_ids = ( + testset_revision.data.testcase_ids + if testset_revision + and testset_revision.data + and testset_revision.data.testcase_ids + else [] + ) + + if not testcase_ids: + return None + + return ResolvedSourceBatch( + kind="testcases", + step_key=step.key, + testcase_ids=testcase_ids, + ) + + +class TestsetRevisionPayloadResolver: + def __init__(self, *, testsets_service: Any): + self.testsets_service = testsets_service + + async def resolve( + self, + *, + project_id: UUID, + step: EvaluationRunDataStep, + ) -> ResolvedTestsetInputSpec: + refs = step.references or {} + testset_revision_ref = refs.get("testset_revision") + + if not testset_revision_ref or not isinstance(testset_revision_ref.id, UUID): + raise ValueError( + f"Evaluation input step {step.key} missing testset_revision reference." + ) + + testset_revision = await self.testsets_service.fetch_testset_revision( + project_id=project_id, + testset_revision_ref=testset_revision_ref, + ) + if not testset_revision: + raise ValueError( + f"Testset revision with id {testset_revision_ref.id} not found!" + ) + if not testset_revision.data or not testset_revision.data.testcases: + raise ValueError( + f"Testset revision with id {testset_revision_ref.id} has no testcases!" + ) + + testset_variant = await self.testsets_service.fetch_testset_variant( + project_id=project_id, + testset_variant_ref=Reference(id=testset_revision.variant_id), + ) + if not testset_variant: + raise ValueError( + f"Testset variant with id {testset_revision.variant_id} not found!" + ) + + testset = await self.testsets_service.fetch_testset( + project_id=project_id, + testset_ref=Reference(id=testset_variant.testset_id), + ) + if not testset: + raise ValueError(f"Testset with id {testset_variant.testset_id} not found!") + + testcases = testset_revision.data.testcases + return ResolvedTestsetInputSpec( + step_key=step.key, + testset=testset, + testset_revision=testset_revision, + testcases=testcases, + testcases_data=[ + {**testcase.data, "testcase_id": str(testcase.id)} + for testcase in testcases + ], + ) + + +async def resolve_queue_source_batches( + *, + project_id: UUID, + run: EvaluationRun, + queries_service: Any, + testsets_service: Any, +) -> List[ResolvedSourceBatch]: + if not run.data or not run.data.steps: + return [] + + resolvers: List[SourceResolver] = [ + QueryRevisionTraceResolver(queries_service=queries_service), + TestsetRevisionTestcaseResolver(testsets_service=testsets_service), + ] + batches: List[ResolvedSourceBatch] = [] + + for step in run.data.steps: + if step.type != "input" or not step.key: + continue + + for resolver in resolvers: + batch = await resolver.resolve( + project_id=project_id, + step=step, + ) + if batch: + batches.append(batch) + break + + return batches + + +async def resolve_testset_input_specs( + *, + project_id: UUID, + input_steps: List[EvaluationRunDataStep], + testsets_service: Any, +) -> List[ResolvedTestsetInputSpec]: + resolver = TestsetRevisionPayloadResolver(testsets_service=testsets_service) + return [ + await resolver.resolve( + project_id=project_id, + step=input_step, + ) + for input_step in input_steps + ] + + +async def resolve_direct_source_items( + *, + project_id: UUID, + trace_ids: Optional[List[str]] = None, + testcase_ids: Optional[List[UUID]] = None, + testcases_service: Any = None, + tracing_service: Any = None, +) -> List[ResolvedSourceItem]: + source_items: List[ResolvedSourceItem] = [] + testcase_ids = testcase_ids or [] + trace_ids = trace_ids or [] + + testcases = ( + await testcases_service.fetch_testcases( + project_id=project_id, + testcase_ids=testcase_ids, + ) + if testcase_ids and testcases_service is not None + else [] + ) + testcases_by_id = { + testcase.id: testcase for testcase in testcases if getattr(testcase, "id", None) + } + traces_by_id: Dict[str, Any] = {} + + if trace_ids and tracing_service is not None: + for trace_id in trace_ids: + trace = await fetch_trace( + tracing_service=tracing_service, + project_id=project_id, + trace_id=trace_id, + max_retries=1, + delay=0, + ) + if trace is not None: + traces_by_id[trace_id] = trace + + source_items.extend( + ResolvedSourceItem( + kind="testcase", + step_key="", + testcase=testcases_by_id.get(testcase_id), + testcase_id=testcase_id, + ) + for testcase_id in testcase_ids + ) + for trace_id in trace_ids: + trace = traces_by_id.get(trace_id) + ag_data = _extract_ag_data(trace) if trace is not None else {} + root_span = _extract_root_span(trace) if trace is not None else None + source_items.append( + ResolvedSourceItem( + kind="trace", + step_key="", + trace_id=trace_id, + span_id=_extract_span_id(root_span), + trace=trace, + inputs=ag_data.get("inputs"), + outputs=ag_data.get("outputs"), + ) + ) + + return source_items + + +async def resolve_live_query_traces( + *, + project_id: UUID, + query_revisions: Dict[str, Any], + tracing_service: Any, + newest: Optional[datetime] = None, + oldest: Optional[datetime] = None, + use_windowing: bool = False, +) -> Dict[str, List[Any]]: + query_traces: Dict[str, List[Any]] = {} + + for query_step_key, query_revision in query_revisions.items(): + formatting = Formatting( + focus=Focus.TRACE, + format=Format.AGENTA, + ) + filtering = Filtering( + operator=LogicalOperator.AND, + conditions=[], + ) + windowing = Windowing( + oldest=oldest, + newest=newest, + next=None, + limit=None, + order="ascending", + interval=None, + rate=None, + ) + + query_revision_data = getattr(query_revision, "data", None) + if query_revision_data: + query_filtering = getattr(query_revision_data, "filtering", None) + query_windowing = getattr(query_revision_data, "windowing", None) + + if query_filtering: + filtering = query_filtering + + if query_windowing and use_windowing: + windowing = Windowing( + oldest=query_windowing.oldest, + newest=query_windowing.newest, + limit=query_windowing.limit, + order=query_windowing.order, + rate=query_windowing.rate, + ) + elif query_windowing: + windowing.rate = query_windowing.rate + + query_traces[query_step_key] = ( + await tracing_service.query_traces( + project_id=project_id, + query=TracingQuery( + formatting=formatting, + filtering=filtering, + windowing=windowing, + ), + ) + or [] + ) + + return query_traces + + +async def resolve_query_source_items( + *, + project_id: UUID, + run: EvaluationRun, + queries_service: Any, + tracing_service: Any, + newest: Optional[datetime] = None, + oldest: Optional[datetime] = None, + use_windowing: bool = False, +) -> Dict[str, List[ResolvedSourceItem]]: + if not run.data or not run.data.steps: + return {} + + query_revisions: Dict[str, Any] = {} + for step in run.data.steps: + if step.type != "input" or not step.key: + continue + + query_revision_ref = (step.references or {}).get("query_revision") + if not query_revision_ref: + continue + + query_revision = await queries_service.fetch_query_revision( + project_id=project_id, + query_revision_ref=query_revision_ref, + ) + if ( + not query_revision + or not getattr(query_revision, "id", None) + or not getattr(query_revision, "slug", None) + ): + continue + + query_revisions[step.key] = query_revision + + query_traces = await resolve_live_query_traces( + project_id=project_id, + query_revisions=query_revisions, + tracing_service=tracing_service, + newest=newest, + oldest=oldest, + use_windowing=use_windowing, + ) + + return { + query_step_key: [ + ResolvedSourceItem( + kind="trace", + step_key=query_step_key, + trace_id=trace.trace_id, + trace=trace, + ) + for trace in traces + if trace and trace.trace_id + ] + for query_step_key, traces in query_traces.items() + } diff --git a/api/oss/src/core/evaluations/runtime/task_runner.py b/api/oss/src/core/evaluations/runtime/task_runner.py new file mode 100644 index 0000000000..baab75b7e9 --- /dev/null +++ b/api/oss/src/core/evaluations/runtime/task_runner.py @@ -0,0 +1,59 @@ +from datetime import datetime +from typing import Any, List, Optional +from uuid import UUID + +from agenta.sdk.evaluations.runtime.execution import EvaluationTaskRunner + + +class TaskiqEvaluationTaskRunner(EvaluationTaskRunner): + """API adapter from generic evaluation dispatch to Taskiq tasks.""" + + def __init__(self, *, worker: Any): + self.worker = worker + + async def process_run( + self, + *, + project_id: UUID, + user_id: UUID, + run_id: UUID, + newest: Optional[datetime] = None, + oldest: Optional[datetime] = None, + ) -> Any: + kwargs = dict( + project_id=project_id, + user_id=user_id, + run_id=run_id, + ) + if newest is not None: + kwargs["newest"] = newest + if oldest is not None: + kwargs["oldest"] = oldest + + return await self.worker.process_run.kiq(**kwargs) + + async def process_slice( + self, + *, + project_id: UUID, + user_id: UUID, + run_id: UUID, + source_kind: str, + trace_ids: Optional[List[str]] = None, + testcase_ids: Optional[List[UUID]] = None, + input_step_key: Optional[str] = None, + ) -> Any: + kwargs = dict( + project_id=project_id, + user_id=user_id, + run_id=run_id, + source_kind=source_kind, + ) + if trace_ids is not None: + kwargs["trace_ids"] = trace_ids + if testcase_ids is not None: + kwargs["testcase_ids"] = testcase_ids + if input_step_key is not None: + kwargs["input_step_key"] = input_step_key + + return await self.worker.process_slice.kiq(**kwargs) diff --git a/api/oss/src/core/evaluations/runtime/tensor.py b/api/oss/src/core/evaluations/runtime/tensor.py new file mode 100644 index 0000000000..0e3e35a23e --- /dev/null +++ b/api/oss/src/core/evaluations/runtime/tensor.py @@ -0,0 +1,192 @@ +from typing import List, Optional +from uuid import UUID + +from oss.src.core.evaluations.runtime.models import ( + ProcessSummary, + TensorProbeSummary, + TensorSlice, +) +from oss.src.core.evaluations.types import ( + EvaluationMetricsRefresh, + EvaluationResult, + EvaluationResultCreate, + EvaluationResultQuery, + EvaluationStatus, +) + + +def _empty_dimension(values: Optional[List[object]]) -> bool: + return values == [] + + +def _slice_is_empty(tensor_slice: TensorSlice) -> bool: + return any( + _empty_dimension(values) + for values in ( + tensor_slice.scenario_ids, + tensor_slice.step_keys, + tensor_slice.repeat_idxs, + ) + ) + + +def _query_from_slice(tensor_slice: TensorSlice) -> EvaluationResultQuery: + return EvaluationResultQuery( + run_id=tensor_slice.run_id, + scenario_ids=tensor_slice.scenario_ids, + step_keys=tensor_slice.step_keys, + repeat_idxs=tensor_slice.repeat_idxs, + ) + + +class TensorSliceOperations: + def __init__(self, *, evaluations_service): + self.evaluations_service = evaluations_service + + async def probe( + self, + *, + project_id: UUID, + tensor_slice: TensorSlice, + ) -> List[EvaluationResult]: + if _slice_is_empty(tensor_slice): + return [] + + return await self.evaluations_service.query_results( + project_id=project_id, + result=_query_from_slice(tensor_slice), + ) + + async def populate( + self, + *, + project_id: UUID, + user_id: UUID, + results: List[EvaluationResultCreate], + refresh_metrics: bool = True, + ) -> List[EvaluationResult]: + if not results: + return [] + + created = await self.evaluations_service.create_results( + project_id=project_id, + user_id=user_id, + results=results, + ) + + if refresh_metrics: + await self._refresh_results_metrics( + project_id=project_id, + user_id=user_id, + results=created, + ) + + return created + + async def probe_summary( + self, + *, + project_id: UUID, + tensor_slice: TensorSlice, + expected_count: Optional[int] = None, + ) -> TensorProbeSummary: + results = await self.probe( + project_id=project_id, + tensor_slice=tensor_slice, + ) + existing_count = len(results) + expected = expected_count if expected_count is not None else existing_count + + return TensorProbeSummary( + existing_count=existing_count, + missing_count=max(0, expected - existing_count), + success_count=sum( + 1 for result in results if result.status == EvaluationStatus.SUCCESS + ), + failure_count=sum( + 1 + for result in results + if result.status + in { + EvaluationStatus.FAILURE, + EvaluationStatus.ERRORS, + } + ), + pending_count=sum( + 1 for result in results if result.status == EvaluationStatus.PENDING + ), + any_count=existing_count, + ) + + async def prune( + self, + *, + project_id: UUID, + user_id: UUID, + tensor_slice: TensorSlice, + refresh_metrics: bool = True, + ) -> List[UUID]: + results = await self.probe( + project_id=project_id, + tensor_slice=tensor_slice, + ) + result_ids = [result.id for result in results if result.id] + if not result_ids: + return [] + + deleted = await self.evaluations_service.delete_results( + project_id=project_id, + result_ids=result_ids, + ) + + if refresh_metrics: + await self._refresh_results_metrics( + project_id=project_id, + user_id=user_id, + results=results, + ) + + return deleted + + async def process( + self, + *, + project_id: UUID, + user_id: UUID, + tensor_slice: TensorSlice, + ) -> ProcessSummary: + if _slice_is_empty(tensor_slice): + return ProcessSummary() + + await self.evaluations_service.refresh_metrics( + project_id=project_id, + user_id=user_id, + metrics=EvaluationMetricsRefresh( + run_id=tensor_slice.run_id, + scenario_ids=tensor_slice.scenario_ids, + ), + ) + return ProcessSummary() + + async def _refresh_results_metrics( + self, + *, + project_id: UUID, + user_id: UUID, + results: List[EvaluationResult], + ) -> None: + scenario_ids = sorted( + {result.scenario_id for result in results if result.scenario_id}, + key=str, + ) + if not results: + return + + await self.evaluations_service.refresh_metrics( + project_id=project_id, + user_id=user_id, + metrics=EvaluationMetricsRefresh( + run_id=results[0].run_id, + scenario_ids=scenario_ids or None, + ), + ) diff --git a/api/oss/src/core/evaluations/runtime/topology.py b/api/oss/src/core/evaluations/runtime/topology.py new file mode 100644 index 0000000000..f7950b6e0d --- /dev/null +++ b/api/oss/src/core/evaluations/runtime/topology.py @@ -0,0 +1,34 @@ +from agenta.sdk.evaluations.runtime.topology import classify_steps_topology + +from oss.src.core.evaluations.runtime.planner import normalize_steps +from oss.src.core.evaluations.runtime.models import TopologyDecision +from oss.src.core.evaluations.types import EvaluationRun + + +def classify_run_topology(run: EvaluationRun) -> TopologyDecision: + """Classify the current evaluation graph for worker dispatch. + + This is intentionally conservative. It mirrors the currently supported + worker-dispatched topologies while naming future-interest and not-planned + shapes explicitly. + """ + + steps = run.data.steps if run.data and run.data.steps else [] + flags = run.flags + + decision = classify_steps_topology( + steps=normalize_steps(steps), + is_live=bool(flags and flags.is_live), + has_queries=bool(flags and flags.has_queries), + has_testsets=bool(flags and flags.has_testsets), + has_traces=bool(flags and flags.has_traces), + has_testcases=bool(flags and flags.has_testcases), + has_evaluators=bool(flags and flags.has_evaluators), + ) + + return TopologyDecision( + status=decision.status, + label=decision.label, + reason=decision.reason, + dispatch=decision.dispatch, + ) diff --git a/api/oss/src/core/evaluations/service.py b/api/oss/src/core/evaluations/service.py index 6869e78e52..7a8ce80ddc 100644 --- a/api/oss/src/core/evaluations/service.py +++ b/api/oss/src/core/evaluations/service.py @@ -20,6 +20,7 @@ EvaluationRunDataMapping, EvaluationRunDataStepInput, EvaluationRunDataStep, + EvaluationRunDataConcurrency, EvaluationRunData, EvaluationRun, EvaluationRunCreate, @@ -100,10 +101,18 @@ ) from oss.src.core.evaluations.utils import get_metrics_keys_from_schema +from oss.src.core.evaluations.runtime.topology import classify_run_topology +from oss.src.core.evaluations.runtime.sources import resolve_queue_source_batches +from oss.src.core.evaluations.runtime.task_runner import TaskiqEvaluationTaskRunner log = get_module_logger(__name__) +# Product policy toggle: when True, every evaluation run keeps a default queue +# even when it has no human evaluators. Keep this as a global until the product +# decision is finalized. +EVALUATIONS_DEFAULT_QUEUES_FOR_ALL_RUNS = False + if TYPE_CHECKING: from oss.src.tasks.taskiq.evaluations.worker import EvaluationsWorker @@ -199,6 +208,11 @@ def __init__( self.testsets_service = testsets_service self.evaluators_service = evaluators_service self.evaluations_worker = evaluations_worker + self.evaluations_task_runner = ( + TaskiqEvaluationTaskRunner(worker=evaluations_worker) + if evaluations_worker is not None + else None + ) ### CRUD @@ -225,7 +239,7 @@ async def refresh_runs( log.error(e, exc_info=True) return False - if self.evaluations_worker is None: + if self.evaluations_task_runner is None: log.warning( "[LIVE] Taskiq client is not configured; skipping live run dispatch" ) @@ -266,12 +280,10 @@ async def refresh_runs( run=run, ) - await self.evaluations_worker.evaluate_live_query.kiq( + await self.evaluations_task_runner.process_run( project_id=project_id, user_id=user_id, - # run_id=run.id, - # newest=newest, oldest=oldest, ) @@ -359,34 +371,96 @@ async def _ensure_human_annotation_queue( user_id: UUID, run: EvaluationRun, ) -> None: - """Create an EvaluationQueue for human annotation steps if none exists for this run.""" - if not run.id or not run.data or not run.data.steps: - return + await self._reconcile_default_queue( + project_id=project_id, + user_id=user_id, + run=run, + ) - human_step_keys = [ - step.key - for step in run.data.steps - if step.type == "annotation" and step.origin == "human" and step.key - ] + async def fetch_default_queue( + self, + *, + project_id: UUID, + run_id: UUID, + include_archived: bool = False, + ) -> Optional[EvaluationQueue]: + queues = await self.query_queues( + project_id=project_id, + queue=EvaluationQueueQuery( + run_id=run_id, + flags=EvaluationQueueQueryFlags(is_default=True), + include_archived=include_archived, + ), + ) + return queues[0] if queues else None - if not human_step_keys: - return + async def _reconcile_default_queue( + self, + *, + project_id: UUID, + user_id: UUID, + run: EvaluationRun, + ) -> EvaluationRun: + if not run.id: + return run - existing_queues = await self.query_queues( + has_human = bool(run.flags and run.flags.has_human) + should_exist = EVALUATIONS_DEFAULT_QUEUES_FOR_ALL_RUNS or has_human + default_queue = await self.fetch_default_queue( project_id=project_id, - queue=EvaluationQueueQuery(run_id=run.id), + run_id=run.id, + include_archived=True, ) - if any(q.run_id == run.id for q in existing_queues): - return - await self.create_queue( - project_id=project_id, - user_id=user_id, - queue=EvaluationQueueCreate( - run_id=run.id, - status=EvaluationStatus.RUNNING, - data=EvaluationQueueData(step_keys=human_step_keys), - ), + if should_exist: + if default_queue is None: + default_queue = await self.create_queue( + project_id=project_id, + user_id=user_id, + queue=EvaluationQueueCreate( + run_id=run.id, + status=EvaluationStatus.RUNNING, + flags=EvaluationQueueFlags(is_default=True), + data=EvaluationQueueData(), + ), + ) + elif default_queue.deleted_at is not None: + default_queue = await self.unarchive_queue( + project_id=project_id, + user_id=user_id, + queue_id=default_queue.id, + ) + elif default_queue is not None and default_queue.deleted_at is None: + default_queue = await self.archive_queue( + project_id=project_id, + user_id=user_id, + queue_id=default_queue.id, + ) + + is_queue = bool( + has_human and default_queue is not None and default_queue.deleted_at is None + ) + if run.flags and run.flags.is_queue == is_queue: + return run + + flags = run.flags.model_copy() if run.flags else EvaluationRunFlags() + flags.is_queue = is_queue + return ( + await self.evaluations_dao.edit_run( + project_id=project_id, + user_id=user_id, + run=EvaluationRunEdit( + id=run.id, + name=run.name, + description=run.description, + flags=flags, + tags=run.tags, + meta=run.meta, + status=run.status, + data=run.data, + ), + ) + or run ) async def fetch_live_runs( @@ -444,12 +518,19 @@ async def create_run( ) -> Optional[EvaluationRun]: run.version = CURRENT_VERSION - return await self.evaluations_dao.create_run( + created_run = await self.evaluations_dao.create_run( project_id=project_id, user_id=user_id, # run=run, ) + if created_run: + created_run = await self._reconcile_default_queue( + project_id=project_id, + user_id=user_id, + run=created_run, + ) + return created_run async def create_runs( self, @@ -462,12 +543,20 @@ async def create_runs( for run in runs: run.version = CURRENT_VERSION - return await self.evaluations_dao.create_runs( + created_runs = await self.evaluations_dao.create_runs( project_id=project_id, user_id=user_id, # runs=runs, ) + return [ + await self._reconcile_default_queue( + project_id=project_id, + user_id=user_id, + run=created_run, + ) + for created_run in created_runs + ] async def fetch_run( self, @@ -505,12 +594,19 @@ async def edit_run( ) -> Optional[EvaluationRun]: run.version = CURRENT_VERSION - return await self.evaluations_dao.edit_run( + edited_run = await self.evaluations_dao.edit_run( project_id=project_id, user_id=user_id, # run=run, ) + if edited_run: + edited_run = await self._reconcile_default_queue( + project_id=project_id, + user_id=user_id, + run=edited_run, + ) + return edited_run async def edit_runs( self, @@ -523,12 +619,20 @@ async def edit_runs( for run in runs: run.version = CURRENT_VERSION - return await self.evaluations_dao.edit_runs( + edited_runs = await self.evaluations_dao.edit_runs( project_id=project_id, user_id=user_id, # runs=runs, ) + return [ + await self._reconcile_default_queue( + project_id=project_id, + user_id=user_id, + run=edited_run, + ) + for edited_run in edited_runs + ] async def delete_run( self, @@ -1538,6 +1642,26 @@ def mapping_key( # - EVALUATION QUEUE ------------------------------------------------------- + @staticmethod + def _validate_default_queue_data( + *, flags: Optional[EvaluationQueueFlags], data: Optional[EvaluationQueueData] + ) -> None: + if not flags or not flags.is_default or not data: + return + if any( + value is not None + for value in ( + data.user_ids, + data.scenario_ids, + data.step_keys, + data.batch_size, + data.batch_offset, + ) + ): + raise ValueError( + "default queues cannot filter scenarios, steps, assignments, or batches" + ) + async def create_queue( self, *, @@ -1547,13 +1671,21 @@ async def create_queue( queue: EvaluationQueueCreate, ) -> Optional[EvaluationQueue]: queue.version = CURRENT_VERSION + self._validate_default_queue_data(flags=queue.flags, data=queue.data) - return await self.evaluations_dao.create_queue( + created_queue = await self.evaluations_dao.create_queue( project_id=project_id, user_id=user_id, # queue=queue, ) + if created_queue: + await self._sync_run_queue_flag_for_default_queue( + project_id=project_id, + user_id=user_id, + queue=created_queue, + ) + return created_queue async def create_queues( self, @@ -1565,13 +1697,21 @@ async def create_queues( ) -> List[EvaluationQueue]: for queue in queues: queue.version = CURRENT_VERSION + self._validate_default_queue_data(flags=queue.flags, data=queue.data) - return await self.evaluations_dao.create_queues( + created_queues = await self.evaluations_dao.create_queues( project_id=project_id, user_id=user_id, # queues=queues, ) + for created_queue in created_queues: + await self._sync_run_queue_flag_for_default_queue( + project_id=project_id, + user_id=user_id, + queue=created_queue, + ) + return created_queues async def fetch_queue( self, @@ -1608,13 +1748,29 @@ async def edit_queue( queue: EvaluationQueueEdit, ) -> Optional[EvaluationQueue]: queue.version = CURRENT_VERSION + existing = await self.fetch_queue(project_id=project_id, queue_id=queue.id) + if existing and existing.flags and existing.flags.is_default: + if queue.flags and not queue.flags.is_default: + raise ValueError("default queues cannot be demoted") + effective_flags = existing.flags + else: + effective_flags = queue.flags or (existing.flags if existing else None) + effective_data = queue.data or (existing.data if existing else None) + self._validate_default_queue_data(flags=effective_flags, data=effective_data) - return await self.evaluations_dao.edit_queue( + edited_queue = await self.evaluations_dao.edit_queue( project_id=project_id, user_id=user_id, # queue=queue, ) + if edited_queue: + await self._sync_run_queue_flag_for_default_queue( + project_id=project_id, + user_id=user_id, + queue=edited_queue, + ) + return edited_queue async def edit_queues( self, @@ -1627,12 +1783,110 @@ async def edit_queues( for queue in queues: queue.version = CURRENT_VERSION - return await self.evaluations_dao.edit_queues( + existing_queues = await self.fetch_queues( + project_id=project_id, + queue_ids=[queue.id for queue in queues], + ) + existing_by_id = {queue.id: queue for queue in existing_queues} + for queue in queues: + existing = existing_by_id.get(queue.id) + if existing and existing.flags and existing.flags.is_default: + if queue.flags and not queue.flags.is_default: + raise ValueError("default queues cannot be demoted") + effective_flags = existing.flags + else: + effective_flags = queue.flags or (existing.flags if existing else None) + effective_data = queue.data or (existing.data if existing else None) + self._validate_default_queue_data( + flags=effective_flags, data=effective_data + ) + + edited_queues = await self.evaluations_dao.edit_queues( project_id=project_id, user_id=user_id, # queues=queues, ) + for edited_queue in edited_queues: + await self._sync_run_queue_flag_for_default_queue( + project_id=project_id, + user_id=user_id, + queue=edited_queue, + ) + return edited_queues + + async def _sync_run_queue_flag_for_default_queue( + self, + *, + project_id: UUID, + user_id: UUID, + queue: EvaluationQueue, + ) -> None: + if not queue.flags or not queue.flags.is_default: + return + run = await self.fetch_run(project_id=project_id, run_id=queue.run_id) + if not run: + return + has_human = bool(run.flags and run.flags.has_human) + is_queue = bool(has_human and queue.deleted_at is None) + if run.flags and run.flags.is_queue == is_queue: + return + flags = run.flags.model_copy() if run.flags else EvaluationRunFlags() + flags.is_queue = is_queue + await self.evaluations_dao.edit_run( + project_id=project_id, + user_id=user_id, + run=EvaluationRunEdit( + id=run.id, + name=run.name, + description=run.description, + flags=flags, + tags=run.tags, + meta=run.meta, + status=run.status, + data=run.data, + ), + ) + + async def archive_queue( + self, + *, + project_id: UUID, + user_id: UUID, + queue_id: UUID, + ) -> Optional[EvaluationQueue]: + queue = await self.evaluations_dao.archive_queue( + project_id=project_id, + user_id=user_id, + queue_id=queue_id, + ) + if queue: + await self._sync_run_queue_flag_for_default_queue( + project_id=project_id, + user_id=user_id, + queue=queue, + ) + return queue + + async def unarchive_queue( + self, + *, + project_id: UUID, + user_id: UUID, + queue_id: UUID, + ) -> Optional[EvaluationQueue]: + queue = await self.evaluations_dao.unarchive_queue( + project_id=project_id, + user_id=user_id, + queue_id=queue_id, + ) + if queue: + await self._sync_run_queue_flag_for_default_queue( + project_id=project_id, + user_id=user_id, + queue=queue, + ) + return queue async def delete_queue( self, @@ -1641,6 +1895,9 @@ async def delete_queue( # queue_id: UUID, ) -> Optional[UUID]: + existing = await self.fetch_queue(project_id=project_id, queue_id=queue_id) + if existing and existing.flags and existing.flags.is_default: + raise ValueError("default queues must be archived, not hard deleted") return await self.evaluations_dao.delete_queue( project_id=project_id, # @@ -1654,6 +1911,12 @@ async def delete_queues( # queue_ids: List[UUID], ) -> List[UUID]: + existing_queues = await self.fetch_queues( + project_id=project_id, + queue_ids=queue_ids, + ) + if any(queue.flags and queue.flags.is_default for queue in existing_queues): + raise ValueError("default queues must be archived, not hard deleted") return await self.evaluations_dao.delete_queues( project_id=project_id, # @@ -1800,6 +2063,11 @@ def __init__( self.evaluators_service = evaluators_service self.evaluations_service = evaluations_service self.evaluations_worker = evaluations_worker + self.evaluations_task_runner = ( + TaskiqEvaluationTaskRunner(worker=evaluations_worker) + if evaluations_worker is not None + else None + ) async def create( self, @@ -1856,6 +2124,7 @@ async def create( evaluator_steps=evaluation.data.evaluator_steps, # repeats=evaluation.data.repeats, + concurrency=evaluation.data.concurrency, # is_live=evaluation.flags.is_live, ) @@ -2232,50 +2501,26 @@ async def start( _evaluation = await self._parse_evaluation_run(run=run) return _evaluation - if self.evaluations_worker is None: + if self.evaluations_task_runner is None: log.warning( "[EVAL] Taskiq client missing; cannot dispatch evaluation run", ) return _evaluation - has_query_steps = bool(_evaluation.data.query_steps) - has_testset_steps = bool(_evaluation.data.testset_steps) - has_application_steps = bool(_evaluation.data.application_steps) - has_evaluator_steps = bool(_evaluation.data.evaluator_steps) - - if has_query_steps and has_evaluator_steps: - await self._ensure_human_annotation_queue( - project_id=project_id, - user_id=user_id, - run=run, - ) - await self.evaluations_worker.evaluate_batch_query.kiq( - project_id=project_id, - user_id=user_id, - # - run_id=run.id, - ) - - elif ( - has_testset_steps and has_application_steps and has_evaluator_steps - ): - await self.evaluations_worker.evaluate_batch_testset.kiq( - project_id=project_id, - user_id=user_id, - # - run_id=run.id, - ) + # Worker task names are API-internal, so dispatch through the + # unified run processor rather than topology-specific handlers. + topology = classify_run_topology(run) - elif ( - has_testset_steps - and has_application_steps - and not has_evaluator_steps - and not has_query_steps - ): - await self.evaluations_worker.evaluate_batch_invocation.kiq( + if topology.dispatch: + if topology.dispatch == "batch_query": + await self._ensure_human_annotation_queue( + project_id=project_id, + user_id=user_id, + run=run, + ) + await self.evaluations_task_runner.process_run( project_id=project_id, user_id=user_id, - # run_id=run.id, ) @@ -2283,10 +2528,9 @@ async def start( log.warning( "[EVAL] [start] [skip] unsupported non-live run topology", run_id=run.id, - has_query_steps=has_query_steps, - has_testset_steps=has_testset_steps, - has_application_steps=has_application_steps, - has_evaluator_steps=has_evaluator_steps, + topology=topology.label, + topology_status=topology.status, + reason=topology.reason, ) return _evaluation @@ -2335,7 +2579,7 @@ async def _ensure_human_annotation_queue( run=run, ) - async def evaluate_batch_traces( + async def dispatch_trace_slice( self, *, project_id: UUID, @@ -2347,7 +2591,7 @@ async def evaluate_batch_traces( ) -> bool: if not trace_ids: return False - if self.evaluations_worker is None: + if self.evaluations_task_runner is None: log.warning( "[EVAL] Taskiq client missing; cannot dispatch trace batch", run_id=run_id, @@ -2358,9 +2602,13 @@ async def evaluate_batch_traces( project_id=project_id, run_id=run_id, ) - if not run or not run.flags or not run.flags.is_queue: + if ( + not run + or not run.flags + or not (run.flags.has_traces or run.flags.has_queries) + ): log.warning( - "[EVAL] trace batch dispatch requires a queue evaluation run", + "[EVAL] trace batch dispatch requires a trace-capable evaluation run", run_id=run_id, ) return False @@ -2371,17 +2619,17 @@ async def evaluate_batch_traces( run=run, ) - await self.evaluations_worker.evaluate_batch_traces.kiq( + await self.evaluations_task_runner.process_slice( project_id=project_id, user_id=user_id, - # run_id=run_id, + source_kind="traces", trace_ids=trace_ids, input_step_key=input_step_key, ) return True - async def evaluate_batch_testcases( + async def dispatch_testcase_slice( self, *, project_id: UUID, @@ -2393,7 +2641,7 @@ async def evaluate_batch_testcases( ) -> bool: if not testcase_ids: return False - if self.evaluations_worker is None: + if self.evaluations_task_runner is None: log.warning( "[EVAL] Taskiq client missing; cannot dispatch testcase batch", run_id=run_id, @@ -2404,9 +2652,13 @@ async def evaluate_batch_testcases( project_id=project_id, run_id=run_id, ) - if not run or not run.flags or not run.flags.is_queue: + if ( + not run + or not run.flags + or not (run.flags.has_testcases or run.flags.has_testsets) + ): log.warning( - "[EVAL] testcase batch dispatch requires a queue evaluation run", + "[EVAL] testcase batch dispatch requires a testcase-capable evaluation run", run_id=run_id, ) return False @@ -2417,11 +2669,11 @@ async def evaluate_batch_testcases( run=run, ) - await self.evaluations_worker.evaluate_batch_testcases.kiq( + await self.evaluations_task_runner.process_slice( project_id=project_id, user_id=user_id, - # run_id=run_id, + source_kind="testcases", testcase_ids=testcase_ids, input_step_key=input_step_key, ) @@ -2485,6 +2737,7 @@ async def _make_evaluation_run_data( evaluator_steps: Optional[Target] = None, # repeats: Optional[int] = None, + concurrency: Optional[EvaluationRunDataConcurrency] = None, # is_live: Optional[bool] = None, ) -> Optional[EvaluationRunData]: @@ -3044,6 +3297,7 @@ async def _make_evaluation_run_data( steps=steps, mappings=mappings, repeats=repeats or 1, + concurrency=concurrency, ) except Exception: # pylint: disable=broad-exception-caught @@ -3409,7 +3663,7 @@ async def create( is_live=False, is_active=True, is_closed=False, - is_queue=True, + is_queue=False, ), tags=queue.tags, meta=queue.meta, @@ -3534,40 +3788,39 @@ async def query( run_ids_filter = list(dict.fromkeys(requested_run_ids)) + eligible_runs = await self.evaluations_service.query_runs( + project_id=project_id, + run=EvaluationRunQuery( + flags=EvaluationRunQueryFlags(is_queue=True), + ), + ) + eligible_run_ids = [run.id for run in eligible_runs if run and run.id] if query and query.kind is not None: - run_query = EvaluationRunQuery( - flags=EvaluationRunQueryFlags( - is_queue=True, - has_queries=query.kind == SimpleQueueKind.TRACES, - has_testsets=query.kind == SimpleQueueKind.TESTCASES, - ), - ) - runs = await self.evaluations_service.query_runs( - project_id=project_id, - run=run_query, - ) + eligible_run_ids = [ + run.id + for run in eligible_runs + if run and run.id and self._get_kind(run) == query.kind + ] + if not eligible_run_ids: + return [] - kind_run_ids = [run.id for run in runs if run and run.id] - if not kind_run_ids: + eligible_run_ids_set = set(eligible_run_ids) + if run_ids_filter is None: + run_ids_filter = eligible_run_ids + else: + run_ids_filter = [ + run_id for run_id in run_ids_filter if run_id in eligible_run_ids_set + ] + if not run_ids_filter: return [] - kind_run_ids_set = set(kind_run_ids) - if run_ids_filter is None: - run_ids_filter = kind_run_ids - else: - run_ids_filter = [ - run_id for run_id in run_ids_filter if run_id in kind_run_ids_set - ] - if not run_ids_filter: - return [] - queues = await self.evaluations_service.query_queues( project_id=project_id, queue=EvaluationQueueQuery( name=query.name if query else None, description=query.description if query else None, # - flags=EvaluationQueueQueryFlags(), + flags=EvaluationQueueQueryFlags(is_default=True), tags=query.tags if query else None, meta=query.meta if query else None, # @@ -3631,7 +3884,7 @@ async def add_traces( if self._get_kind(run) != SimpleQueueKind.TRACES: return None - ok = await self.simple_evaluations_service.evaluate_batch_traces( + ok = await self.simple_evaluations_service.dispatch_trace_slice( project_id=project_id, user_id=user_id, # @@ -3671,7 +3924,7 @@ async def add_testcases( if self._get_kind(run) != SimpleQueueKind.TESTCASES: return None - ok = await self.simple_evaluations_service.evaluate_batch_testcases( + ok = await self.simple_evaluations_service.dispatch_testcase_slice( project_id=project_id, user_id=user_id, # @@ -3799,63 +4052,33 @@ async def _dispatch_source_batches( if not run.id or not run.data or not run.data.steps: return False - dispatched = False - for step in run.data.steps: - if step.type != "input" or not step.key: - continue - - refs = step.references or {} - query_revision_ref = refs.get("query_revision") - testset_revision_ref = refs.get("testset_revision") - - if query_revision_ref and query_revision_ref.id: - query_revision = await self.simple_evaluations_service.queries_service.fetch_query_revision( - project_id=project_id, - query_revision_ref=query_revision_ref, - include_trace_ids=True, - ) - trace_ids = ( - query_revision.data.trace_ids - if query_revision - and query_revision.data - and query_revision.data.trace_ids - else [] - ) - if not trace_ids: - continue + batches = await resolve_queue_source_batches( + project_id=project_id, + run=run, + queries_service=self.simple_evaluations_service.queries_service, + testsets_service=self.simple_evaluations_service.testsets_service, + ) - ok = await self.simple_evaluations_service.evaluate_batch_traces( + dispatched = False + for batch in batches: + if batch.kind == "traces" and batch.trace_ids: + ok = await self.simple_evaluations_service.dispatch_trace_slice( project_id=project_id, user_id=user_id, run_id=run.id, - trace_ids=trace_ids, - input_step_key=step.key, + trace_ids=batch.trace_ids, + input_step_key=batch.step_key, ) dispatched = dispatched or ok continue - if testset_revision_ref and testset_revision_ref.id: - testset_revision = await self.simple_evaluations_service.testsets_service.fetch_testset_revision( - project_id=project_id, - testset_revision_ref=testset_revision_ref, - include_testcase_ids=True, - ) - testcase_ids = ( - testset_revision.data.testcase_ids - if testset_revision - and testset_revision.data - and testset_revision.data.testcase_ids - else [] - ) - if not testcase_ids: - continue - - ok = await self.simple_evaluations_service.evaluate_batch_testcases( + if batch.kind == "testcases" and batch.testcase_ids: + ok = await self.simple_evaluations_service.dispatch_testcase_slice( project_id=project_id, user_id=user_id, run_id=run.id, - testcase_ids=testcase_ids, - input_step_key=step.key, + testcase_ids=batch.testcase_ids, + input_step_key=batch.step_key, ) dispatched = dispatched or ok @@ -3882,9 +4105,7 @@ async def _make_run_data( annotation_mappings: List[EvaluationRunDataMapping] = [] annotation_step_keys: List[str] = [] - source_step_key = ( - "query-direct" if kind == SimpleQueueKind.TRACES else "testset-direct" - ) + source_step_key = "traces" if kind == SimpleQueueKind.TRACES else "testcases" source_step = EvaluationRunDataStep( key=source_step_key, type="input", @@ -4003,21 +4224,22 @@ def _get_kind(self, run: EvaluationRun) -> Optional[SimpleQueueKind]: if not run.flags or not run.flags.is_queue: return None - if run.flags.has_queries and not run.flags.has_testsets: - return SimpleQueueKind.TRACES - - if run.flags.has_testsets and not run.flags.has_queries: - return SimpleQueueKind.TESTCASES - - return None + families = [ + (run.flags.has_queries, SimpleQueueKind.QUERIES), + (run.flags.has_testsets, SimpleQueueKind.TESTSETS), + (run.flags.has_traces, SimpleQueueKind.TRACES), + (run.flags.has_testcases, SimpleQueueKind.TESTCASES), + ] + enabled = [kind for enabled, kind in families if enabled] + return enabled[0] if len(enabled) == 1 else None @staticmethod def _get_source_kind(*, queue_data: SimpleQueueData) -> Optional[SimpleQueueKind]: if queue_data.queries: - return SimpleQueueKind.TRACES + return SimpleQueueKind.QUERIES if queue_data.testsets: - return SimpleQueueKind.TESTCASES + return SimpleQueueKind.TESTSETS return None diff --git a/api/oss/src/core/evaluations/tasks/batch.py b/api/oss/src/core/evaluations/tasks/batch.py deleted file mode 100644 index 1416010344..0000000000 --- a/api/oss/src/core/evaluations/tasks/batch.py +++ /dev/null @@ -1,148 +0,0 @@ -from uuid import UUID - -from oss.src.utils.logging import get_module_logger -from oss.src.utils.common import is_ee - -if is_ee(): - pass - -from oss.src.dbs.postgres.queries.dbes import ( - QueryArtifactDBE, - QueryVariantDBE, - QueryRevisionDBE, -) -from oss.src.dbs.postgres.testcases.dbes import ( - TestcaseBlobDBE, -) -from oss.src.dbs.postgres.testsets.dbes import ( - TestsetArtifactDBE, - TestsetVariantDBE, - TestsetRevisionDBE, -) -from oss.src.dbs.postgres.workflows.dbes import ( - WorkflowArtifactDBE, - WorkflowVariantDBE, - WorkflowRevisionDBE, -) - -from oss.src.dbs.postgres.tracing.dao import TracingDAO -from oss.src.dbs.postgres.blobs.dao import BlobsDAO -from oss.src.dbs.postgres.git.dao import GitDAO -from oss.src.dbs.postgres.evaluations.dao import EvaluationsDAO - -from oss.src.core.tracing.service import TracingService -from oss.src.core.queries.service import QueriesService -from oss.src.core.testcases.service import TestcasesService -from oss.src.core.testsets.service import TestsetsService -from oss.src.core.testsets.service import SimpleTestsetsService -from oss.src.core.workflows.service import WorkflowsService -from oss.src.core.evaluators.service import EvaluatorsService -from oss.src.core.evaluators.service import SimpleEvaluatorsService -from oss.src.core.evaluations.service import EvaluationsService -from oss.src.core.annotations.service import AnnotationsService - - -log = get_module_logger(__name__) - - -# DBS -------------------------------------------------------------------------- - -tracing_dao = TracingDAO() - -testcases_dao = BlobsDAO( - BlobDBE=TestcaseBlobDBE, -) - -queries_dao = GitDAO( - ArtifactDBE=QueryArtifactDBE, - VariantDBE=QueryVariantDBE, - RevisionDBE=QueryRevisionDBE, -) - -testsets_dao = GitDAO( - ArtifactDBE=TestsetArtifactDBE, - VariantDBE=TestsetVariantDBE, - RevisionDBE=TestsetRevisionDBE, -) - -workflows_dao = GitDAO( - ArtifactDBE=WorkflowArtifactDBE, - VariantDBE=WorkflowVariantDBE, - RevisionDBE=WorkflowRevisionDBE, -) - -evaluations_dao = EvaluationsDAO() - -# CORE ------------------------------------------------------------------------- - -tracing_service = TracingService( - tracing_dao=tracing_dao, -) - -queries_service = QueriesService( - queries_dao=queries_dao, -) - -testcases_service = TestcasesService( - testcases_dao=testcases_dao, -) - -testsets_service = TestsetsService( - testsets_dao=testsets_dao, - testcases_service=testcases_service, -) - -simple_testsets_service = SimpleTestsetsService( - testsets_service=testsets_service, -) - -workflows_service = WorkflowsService( - workflows_dao=workflows_dao, -) - -evaluators_service = EvaluatorsService( - workflows_service=workflows_service, -) - -simple_evaluators_service = SimpleEvaluatorsService( - evaluators_service=evaluators_service, -) - -evaluations_service = EvaluationsService( - evaluations_dao=evaluations_dao, - tracing_service=tracing_service, - queries_service=queries_service, - testsets_service=testsets_service, - evaluators_service=evaluators_service, - # -) - -annotations_service = AnnotationsService( - tracing_service=tracing_service, - evaluators_service=evaluators_service, - simple_evaluators_service=simple_evaluators_service, -) - -# ------------------------------------------------------------------------------ - - -def evaluate_testsets( - self, - *, - project_id: UUID, - user_id: UUID, - # - run_id: UUID, -): - pass - - -def evaluate_queries( - self, - *, - project_id: UUID, - user_id: UUID, - # - run_id: UUID, -): - pass diff --git a/api/oss/src/core/evaluations/tasks/legacy.py b/api/oss/src/core/evaluations/tasks/legacy.py deleted file mode 100644 index 548a5b9dba..0000000000 --- a/api/oss/src/core/evaluations/tasks/legacy.py +++ /dev/null @@ -1,2260 +0,0 @@ -from typing import Dict, List, Optional, Any - -from uuid import UUID -from json import dumps - -from fastapi import Request - -from oss.src.utils.logging import get_module_logger -from oss.src.utils.common import is_ee -from oss.src.services import llm_apps_service -from oss.src.models.shared_models import InvokationResult -from oss.src.services.db_manager import get_project_by_id - -if is_ee(): - from ee.src.utils.entitlements import check_entitlements, Counter - - -from oss.src.core.queries.service import QueriesService -from oss.src.core.testcases.service import TestcasesService -from oss.src.core.testsets.service import TestsetsService -from oss.src.core.applications.service import ApplicationsService -from oss.src.core.workflows.service import WorkflowsService -from oss.src.core.evaluators.service import SimpleEvaluatorsService -from oss.src.core.evaluations.service import EvaluationsService - -from oss.src.core.tracing.service import TracingService - - -from oss.src.core.evaluations.types import ( - EvaluationStatus, - EvaluationRun, - EvaluationRunEdit, - EvaluationScenarioCreate, - EvaluationScenarioEdit, - EvaluationResultCreate, - EvaluationMetricsRefresh, -) - -from oss.src.core.shared.dtos import Reference -from oss.src.core.workflows.dtos import ( - WorkflowServiceRequestData, - WorkflowServiceRequest, -) - - -from oss.src.core.evaluations.utils import ( - build_repeat_indices, - effective_is_split, - fetch_traces_by_hash, - fetch_trace, - make_hash, - plan_missing_traces, - required_traces_for_step, - select_traces_for_reuse, -) - - -log = get_module_logger(__name__) - - -def _resolve_runtime_uri( - *, - revision_data: Optional[Any], -) -> Optional[str]: - if revision_data is None: - return None - - return WorkflowsService._get_service_url(revision_data=revision_data) - - -def _extract_root_span(trace: Optional[Any]) -> Optional[Any]: - if not trace: - # log.debug("[TRACE] [ROOT]", reason="missing-trace") - return None - - spans = getattr(trace, "spans", None) - - if not isinstance(spans, dict): - # log.debug( - # "[TRACE] [ROOT]", - # trace_id=str(getattr(trace, "trace_id", None)) - # if getattr(trace, "trace_id", None) - # else None, - # reason="spans-not-dict", - # spans_type=type(spans).__name__ if spans is not None else None, - # ) - return None - - if not spans: - # log.debug( - # "[TRACE] [ROOT]", - # trace_id=str(getattr(trace, "trace_id", None)) - # if getattr(trace, "trace_id", None) - # else None, - # reason="spans-empty", - # ) - return None - - root_span = list(spans.values())[0] - if isinstance(root_span, list): - # log.debug( - # "[TRACE] [ROOT]", - # trace_id=str(getattr(trace, "trace_id", None)) - # if getattr(trace, "trace_id", None) - # else None, - # reason="first-span-is-list", - # span_keys=list(spans.keys()), - # first_list_len=len(root_span), - # ) - return None - - # log.debug( - # "[TRACE] [ROOT]", - # trace_id=str(getattr(trace, "trace_id", None)) - # if getattr(trace, "trace_id", None) - # else None, - # reason="resolved", - # span_keys=list(spans.keys()), - # root_span_id=str(getattr(root_span, "span_id", None)) - # if getattr(root_span, "span_id", None) - # else None, - # ) - return root_span - - -def _build_trace_context( - *, - trace: Optional[Any], - error: Optional[Dict[str, Any]] = None, -) -> Optional[Dict[str, Any]]: - root_span = _extract_root_span(trace) - trace_id = getattr(trace, "trace_id", None) if trace else None - - if not root_span or not trace_id: - # log.debug( - # "[TRACE] [CONTEXT]", - # trace_id=str(trace_id) if trace_id else None, - # has_root_span=bool(root_span), - # has_error=bool(error), - # error=error, - # ) - return None - - # log.debug( - # "[TRACE] [CONTEXT]", - # trace_id=str(trace_id), - # span_id=str(getattr(root_span, "span_id", None)) - # if getattr(root_span, "span_id", None) - # else None, - # has_error=bool(error), - # ) - return { - "trace": trace, - "trace_id": str(trace_id), - "span_id": getattr(root_span, "span_id", None), - "root_span": root_span, - "error": error, - } - - -async def _resolve_testset_input_specs( - *, - project_id: UUID, - input_steps: List[Any], - testsets_service: TestsetsService, -) -> List[Dict[str, Any]]: - input_specs: List[Dict[str, Any]] = [] - - for input_step in input_steps: - input_refs = input_step.references or {} - testset_revision_ref = input_refs.get("testset_revision") - - if not testset_revision_ref or not isinstance(testset_revision_ref.id, UUID): - raise ValueError( - f"Evaluation input step {input_step.key} missing testset_revision reference." - ) - - testset_revision = await testsets_service.fetch_testset_revision( - project_id=project_id, - testset_revision_ref=testset_revision_ref, - ) - if not testset_revision: - raise ValueError( - f"Testset revision with id {testset_revision_ref.id} not found!" - ) - if not testset_revision.data or not testset_revision.data.testcases: - raise ValueError( - f"Testset revision with id {testset_revision_ref.id} has no testcases!" - ) - - testset_variant = await testsets_service.fetch_testset_variant( - project_id=project_id, - testset_variant_ref=Reference(id=testset_revision.variant_id), - ) - if not testset_variant: - raise ValueError( - f"Testset variant with id {testset_revision.variant_id} not found!" - ) - - testset = await testsets_service.fetch_testset( - project_id=project_id, - testset_ref=Reference(id=testset_variant.testset_id), - ) - if not testset: - raise ValueError(f"Testset with id {testset_variant.testset_id} not found!") - - testcases = testset_revision.data.testcases - input_specs.append( - { - "step_key": input_step.key, - "testset": testset, - "testset_revision": testset_revision, - "testcases": testcases, - "testcases_data": [ - {**testcase.data, "id": str(testcase.id)} for testcase in testcases - ], - } - ) - - return input_specs - - -async def evaluate_batch_testset( - *, - project_id: UUID, - user_id: UUID, - # - run_id: UUID, - # - tracing_service: TracingService, - testsets_service: TestsetsService, - queries_service: QueriesService, - workflows_service: WorkflowsService, - applications_service: ApplicationsService, - evaluations_service: EvaluationsService, - # - simple_evaluators_service: SimpleEvaluatorsService, -): - """ - Annotates an application revision applied to a testset using auto evaluator(s). - - All testset, application, and evaluator information is extracted from the - evaluation run's data.steps references. - - Args: - project_id (UUID): The ID of the project. - user_id (UUID): The ID of the user. - run_id (UUID): The ID of the evaluation run. - - Returns: - None - """ - request = Request( - scope={ - "type": "http", - "http_version": "1.1", - "scheme": "http", - } - ) - request.state.project_id = str(project_id) - request.state.user_id = str(user_id) - - project = None - run = None - - try: - # ---------------------------------------------------------------------- - log.info( - "[SCOPE] ", run_id=run_id, project_id=project_id, user_id=user_id - ) - # ---------------------------------------------------------------------- - - # fetch project -------------------------------------------------------- - project = await get_project_by_id( - project_id=str(project_id), - ) - # ---------------------------------------------------------------------- - - # fetch run ------------------------------------------------------------ - run = await evaluations_service.fetch_run( - project_id=project_id, - # - run_id=run_id, - ) - - if not run: - raise ValueError(f"Evaluation run with id {run_id} not found!") - - if not run.data: - raise ValueError(f"Evaluation run with id {run_id} has no data!") - - if not run.data.steps: - raise ValueError(f"Evaluation run with id {run_id} has no steps!") - - steps = run.data.steps - repeats = run.data.repeats or 1 - repeat_indices = build_repeat_indices(repeats) - is_cached = bool(run.flags.is_cached) if run.flags else False - - input_steps = [step for step in steps if step.type == "input"] - invocation_steps = [step for step in steps if step.type == "invocation"] - annotation_steps = [step for step in steps if step.type == "annotation"] - - log.info( - "[STEPS] ", - run_id=run_id, - count=len(steps), - input_keys=[step.key for step in input_steps], - invocation_keys=[step.key for step in invocation_steps], - annotation_keys=[step.key for step in annotation_steps], - step_types=[getattr(step, "type", None) for step in steps], - ) - - if not input_steps or len(invocation_steps) != 1: - raise ValueError( - f"Evaluation run with id {run_id} must have at least one input and exactly one invocation step." - ) - - invocation_step = invocation_steps[0] - invocation_step_key = invocation_step.key - is_split = effective_is_split( - is_split=bool(run.flags.is_split) if run.flags else False, - has_application_steps=True, - has_evaluator_steps=bool(annotation_steps), - ) - application_required_count = required_traces_for_step( - repeats=repeats, - is_split=is_split, - step_kind="application", - has_evaluator_steps=bool(annotation_steps), - ) - evaluator_required_count = required_traces_for_step( - repeats=repeats, - is_split=is_split, - step_kind="evaluator", - has_evaluator_steps=bool(annotation_steps), - ) - - application_revision_ref = invocation_step.references.get( - "application_revision" - ) - if not application_revision_ref or not isinstance( - application_revision_ref.id, UUID - ): - raise ValueError( - f"Evaluation run with id {run_id} missing invocation.application_revision reference." - ) - - run_config = { - "batch_size": 10, - "max_retries": 3, - "retry_delay": 3, - "delay_between_batches": 5, - } - - input_specs = await _resolve_testset_input_specs( - project_id=project_id, - input_steps=input_steps, - testsets_service=testsets_service, - ) - testset_revision_ids = [ - str(input_spec["testset_revision"].id) for input_spec in input_specs - ] - - log.info("[TESTSET] ", run_id=run_id, ids=testset_revision_ids) - log.info( - "[APPLICATION] ", - run_id=run_id, - ids=[str(application_revision_ref.id)], - ) - # ---------------------------------------------------------------------- - - # flatten scenario sources --------------------------------------------- - scenario_specs = [ - { - "input_step_key": input_spec["step_key"], - "testset": input_spec["testset"], - "testset_revision": input_spec["testset_revision"], - "testcase": testcase, - "testcase_data": testcase_data, - } - for input_spec in input_specs - for testcase, testcase_data in zip( - input_spec["testcases"], - input_spec["testcases_data"], - ) - ] - nof_scenarios = len(scenario_specs) - # ---------------------------------------------------------------------- - - # fetch application ---------------------------------------------------- - application_revision = await applications_service.fetch_application_revision( - project_id=project_id, - application_revision_ref=application_revision_ref, - ) - - if application_revision is None: - raise ValueError( - f"App revision with id {application_revision_ref.id} not found!" - ) - - application_variant = await applications_service.fetch_application_variant( - project_id=project_id, - application_variant_ref=Reference( - id=application_revision.application_variant_id - ), - ) - - if application_variant is None: - raise ValueError( - f"Application variant with id {application_revision.application_variant_id} not found!" - ) - - application = await applications_service.fetch_application( - project_id=project_id, - application_ref=Reference(id=application_variant.application_id), - ) - - if application is None: - raise ValueError( - f"Application with id {application_variant.application_id} not found!" - ) - - uri = _resolve_runtime_uri(revision_data=application_revision.data) - - if not uri: - raise ValueError( - f"No deployment URI found for revision {application_revision_ref.id}!" - ) - - # fetch evaluators ----------------------------------------------------- - evaluator_references = {step.key: step.references for step in annotation_steps} - # log.debug( - # "[EVALUATORS] ", - # run_id=run_id, - # count=len(annotation_steps), - # refs={ - # step_key: ( - # { - # key: str(reference.id) - # if getattr(reference, "id", None) - # else None - # for key, reference in (references or {}).items() - # } - # ) - # for step_key, references in evaluator_references.items() - # }, - # ) - - evaluators = {} - for evaluator_key, evaluator_refs in evaluator_references.items(): - evaluators[evaluator_key] = await workflows_service.fetch_workflow_revision( - project_id=project_id, - # - workflow_revision_ref=evaluator_refs.get("evaluator_revision"), - ) - # log.debug( - # "[EVALUATORS] [FETCH]", - # run_id=run_id, - # resolved={ - # evaluator_key: ( - # str(evaluator_revision.id) - # if evaluator_revision and evaluator_revision.id - # else None - # ) - # for evaluator_key, evaluator_revision in evaluators.items() - # }, - # ) - # ---------------------------------------------------------------------- - - # create scenarios ----------------------------------------------------- - scenarios_create = [ - EvaluationScenarioCreate( - run_id=run_id, - # - status=EvaluationStatus.RUNNING, - ) - for _ in range(nof_scenarios) - ] - - scenarios = await evaluations_service.create_scenarios( - project_id=project_id, - user_id=user_id, - # - scenarios=scenarios_create, - ) - - if len(scenarios) != nof_scenarios: - raise ValueError(f"Failed to create evaluation scenarios for run {run_id}!") - # ---------------------------------------------------------------------- - - # create input steps --------------------------------------------------- - results_create = [ - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=scenario_specs[idx]["input_step_key"], - repeat_idx=repeat_idx, - # - status=EvaluationStatus.SUCCESS, - # - testcase_id=scenario_specs[idx]["testcase"].id, - ) - for idx, scenario in enumerate(scenarios) - for repeat_idx in repeat_indices - ] - - steps = await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - # - results=results_create, - ) - - if len(steps) != nof_scenarios * len(repeat_indices): - raise ValueError(f"Failed to create evaluation steps for run {run_id}!") - # ---------------------------------------------------------------------- - - # flatten testcases ---------------------------------------------------- - _testcases = [ - scenario_spec["testcase"].model_dump(mode="json") - for scenario_spec in scenario_specs - ] - - log.info( - "[BATCH] ", - run_id=run_id, - ids=testset_revision_ids, - count=len(_testcases), - size=len(dumps(_testcases).encode("utf-8")), - ) - # ---------------------------------------------------------------------- - - run_has_errors = 0 - run_has_pending = False - run_status = EvaluationStatus.SUCCESS - - # run invocations / evaluators ----------------------------------------- - for idx in range(nof_scenarios): - scenario = scenarios[idx] - scenario_spec = scenario_specs[idx] - testcase = scenario_spec["testcase"] - testcase_data = scenario_spec["testcase_data"] - testset = scenario_spec["testset"] - testset_revision = scenario_spec["testset_revision"] - - scenario_has_errors = 0 - scenario_has_pending = False - scenario_status = EvaluationStatus.SUCCESS - application_references = { - "testcase": {"id": str(testcase.id)}, - "testset": {"id": str(testset.id)}, - "testset_variant": {"id": str(testset_revision.variant_id)}, - "testset_revision": {"id": str(testset_revision.id)}, - "application": {"id": str(application.id)}, - "application_variant": {"id": str(application_variant.id)}, - "application_revision": {"id": str(application_revision.id)}, - } - - application_hash_id = make_hash( - references=application_references, - links=None, - ) - cached_application_traces = [] - if is_cached and application_hash_id: - cached_application_traces = await fetch_traces_by_hash( - tracing_service, - project_id, - hash_id=application_hash_id, - limit=application_required_count, - ) - - cached_application_contexts = [] - for reusable_trace in select_traces_for_reuse( - traces=cached_application_traces, - required_count=application_required_count, - ): - reusable_context = _build_trace_context(trace=reusable_trace) - if reusable_context: - cached_application_contexts.append(reusable_context) - - missing_application_count = plan_missing_traces( - required_count=application_required_count, - reusable_count=len(cached_application_contexts), - ) - - invoked_application_contexts = [] - if missing_application_count > 0: - invocations: List[ - InvokationResult - ] = await llm_apps_service.batch_invoke( - project_id=str(project_id), - user_id=str(user_id), - testset_data=[ - testcase_data for _ in range(missing_application_count) - ], # type: ignore[arg-type] - revision=application_revision, - uri=uri, - rate_limit_config=run_config, - application_id=str(application.id), - references=application_references, - scenarios=[ - scenario.model_dump( - mode="json", - exclude_none=True, - ) - for _ in range(missing_application_count) - ], - ) - - if len(invocations) != missing_application_count: - raise ValueError( - f"Unexpected batch invocation count for scenario {scenario.id}!" - ) - - for invocation in invocations: - invocation_error = ( - invocation.result.error.model_dump(mode="json") - if invocation.result and invocation.result.error - else None - ) - invoked_trace = None - if not invocation_error and invocation.trace_id: - invoked_trace = await fetch_trace( - tracing_service=tracing_service, - project_id=project_id, - trace_id=invocation.trace_id, - ) - - invocation_context = ( - _build_trace_context( - trace=invoked_trace, - error=invocation_error, - ) - if invoked_trace - else None - ) - if invocation_context: - invoked_application_contexts.append(invocation_context) - else: - invoked_application_contexts.append( - { - "trace": invoked_trace, - "trace_id": invocation.trace_id, - "span_id": invocation.span_id, - "root_span": None, - "error": invocation_error - or { - "message": "Invocation trace missing or malformed." - }, - } - ) - - application_contexts = ( - cached_application_contexts + invoked_application_contexts - ) - application_context_by_repeat: Dict[int, Dict[str, Any]] = {} - if is_split: - for repeat_idx, context in zip(repeat_indices, application_contexts): - application_context_by_repeat[repeat_idx] = context - else: - shared_context = ( - application_contexts[0] if application_contexts else None - ) - if shared_context: - for repeat_idx in repeat_indices: - application_context_by_repeat[repeat_idx] = shared_context - - invocation_results_create = [] - scenario_invocation_failed = False - for repeat_idx in repeat_indices: - application_context = application_context_by_repeat.get(repeat_idx) - application_error = ( - application_context.get("error") - if application_context - else {"message": "Invocation trace missing."} - ) - has_invocation_error = not ( - application_context - and application_context.get("trace_id") - and application_context.get("root_span") - and not application_error - ) - if has_invocation_error: - scenario_invocation_failed = True - - invocation_results_create.append( - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=invocation_step_key, - repeat_idx=repeat_idx, - status=( - EvaluationStatus.FAILURE - if has_invocation_error - else EvaluationStatus.SUCCESS - ), - trace_id=( - application_context.get("trace_id") - if application_context - else None - ), - error=application_error if has_invocation_error else None, - ) - ) - - created_invocation_results = await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - results=invocation_results_create, - ) - if len(created_invocation_results) != len(repeat_indices): - raise ValueError( - f"Failed to create invocation results for scenario {scenario.id}!" - ) - - if scenario_invocation_failed: - scenario_has_errors += 1 - - for annotation_step in annotation_steps: - annotation_step_key = annotation_step.key - - if annotation_step.origin in {"human", "custom"}: - # log.debug( - # "[EVALUATOR] [SKIP]", - # run_id=run_id, - # scenario_id=scenario.id, - # step_key=annotation_step_key, - # origin=annotation_step.origin, - # reason="non-auto-origin", - # ) - scenario_has_pending = True - run_has_pending = True - continue - - evaluator_revision = evaluators.get(annotation_step_key) - if not evaluator_revision: - # log.warning( - # "[EVALUATOR] [MISSING]", - # run_id=run_id, - # scenario_id=scenario.id, - # step_key=annotation_step_key, - # references={ - # key: str(reference.id) - # if getattr(reference, "id", None) - # else None - # for key, reference in ( - # evaluator_references.get(annotation_step_key, {}) or {} - # ).items() - # }, - # ) - log.error( - f"Evaluator revision for {annotation_step_key} not found!" - ) - scenario_has_errors += 1 - scenario_status = EvaluationStatus.ERRORS - continue - - _revision = evaluator_revision.model_dump( - mode="json", - exclude_none=True, - ) - interface = ( - dict( - uri=evaluator_revision.data.uri, - url=evaluator_revision.data.url, - headers=evaluator_revision.data.headers, - schemas=evaluator_revision.data.schemas, - ) - if evaluator_revision.data - else dict() - ) - configuration = ( - dict( - script=evaluator_revision.data.script, - parameters=evaluator_revision.data.parameters, - ) - if evaluator_revision.data - else dict() - ) - parameters = configuration.get("parameters") - flags = ( - evaluator_revision.flags.model_dump( - mode="json", - exclude_none=True, - exclude_unset=True, - ) - if evaluator_revision.flags - else None - ) - - base_references: Dict[str, Any] = { - **evaluator_references[annotation_step_key], - "testcase": {"id": str(testcase.id)}, - "testset": {"id": str(testset.id)}, - "testset_variant": {"id": str(testset_revision.variant_id)}, - "testset_revision": {"id": str(testset_revision.id)}, - } - - evaluator_results_create = [] - if not is_split: - shared_application_context = application_context_by_repeat.get( - repeat_indices[0] - ) - # log.debug( - # "[EVALUATOR] [PLAN]", - # run_id=run_id, - # scenario_id=scenario.id, - # step_key=annotation_step_key, - # repeats=repeat_indices, - # is_split=is_split, - # has_shared_application_context=bool(shared_application_context), - # has_shared_root_span=bool( - # shared_application_context - # and shared_application_context.get("root_span") - # ), - # ) - if ( - not shared_application_context - or not shared_application_context.get("root_span") - ): - scenario_has_errors += 1 - scenario_status = EvaluationStatus.ERRORS - evaluator_results_create = [ - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=annotation_step_key, - repeat_idx=repeat_idx, - status=EvaluationStatus.FAILURE, - error={ - "message": "Evaluator skipped because invocation trace is missing." - }, - ) - for repeat_idx in repeat_indices - ] - else: - shared_trace = shared_application_context["trace"] - shared_root_span = shared_application_context["root_span"] - shared_links = { - invocation_step_key: { - "trace_id": shared_application_context["trace_id"], - "span_id": shared_application_context["span_id"], - } - } - workflow_service_request_data = WorkflowServiceRequestData( - revision=_revision, - parameters=parameters, - testcase=testcase.model_dump(mode="json"), - inputs=testcase.data, - trace=shared_trace.model_dump( - mode="json", - exclude_none=True, - ) - if shared_trace - else None, - outputs=( - ( - shared_root_span.model_dump( - mode="json", - exclude_none=True, - ) - .get("attributes", {}) - .get("ag", {}) - .get("data", {}) - ).get("outputs") - if shared_root_span - else None - ), - ) - workflow_service_request = WorkflowServiceRequest( - version="2025.07.14", - flags=flags, - interface=interface, - configuration=configuration, - data=workflow_service_request_data, - references=base_references, - links=shared_links, - ) - evaluator_hash_id = make_hash( - references=base_references, - links=shared_links, - ) - cached_evaluator_traces = [] - if is_cached and evaluator_hash_id: - cached_evaluator_traces = await fetch_traces_by_hash( - tracing_service, - project_id, - hash_id=evaluator_hash_id, - limit=evaluator_required_count, - ) - - reusable_evaluator_traces = select_traces_for_reuse( - traces=cached_evaluator_traces, - required_count=evaluator_required_count, - ) - evaluator_results_create.extend( - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=annotation_step_key, - repeat_idx=repeat_idx, - status=EvaluationStatus.SUCCESS, - trace_id=str(reusable_trace.trace_id), - ) - for repeat_idx, reusable_trace in zip( - repeat_indices, - reusable_evaluator_traces, - ) - if reusable_trace and reusable_trace.trace_id - ) - - for repeat_idx in repeat_indices[ - len(reusable_evaluator_traces) : - ]: - # log.debug( - # "[EVALUATOR] [INVOKE]", - # run_id=run_id, - # scenario_id=scenario.id, - # step_key=annotation_step_key, - # repeat_idx=repeat_idx, - # cached_reuse_count=len(reusable_evaluator_traces), - # trace_links=shared_links, - # ) - workflows_service_response = ( - await workflows_service.invoke_workflow( - project_id=project_id, - user_id=user_id, - request=workflow_service_request, - annotate=True, - ) - ) - has_error = workflows_service_response.status.code != 200 - result_trace_id = workflows_service_response.trace_id - result_error = None - result_status = EvaluationStatus.SUCCESS - - if has_error: - result_status = EvaluationStatus.FAILURE - result_error = ( - workflows_service_response.status.model_dump( - mode="json", - exclude_none=True, - ) - ) - scenario_has_errors += 1 - scenario_status = EvaluationStatus.ERRORS - elif result_trace_id: - fetched_evaluator_trace = await fetch_trace( - tracing_service=tracing_service, - project_id=project_id, - trace_id=result_trace_id, - ) - if not fetched_evaluator_trace: - result_status = EvaluationStatus.FAILURE - result_error = { - "message": "Evaluator trace missing after invocation." - } - scenario_has_errors += 1 - scenario_status = EvaluationStatus.ERRORS - else: - result_status = EvaluationStatus.FAILURE - result_error = { - "message": "Evaluator trace_id is missing." - } - scenario_has_errors += 1 - scenario_status = EvaluationStatus.ERRORS - - evaluator_results_create.append( - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=annotation_step_key, - repeat_idx=repeat_idx, - status=result_status, - trace_id=result_trace_id, - error=result_error, - ) - ) - else: - for repeat_idx in repeat_indices: - application_context = application_context_by_repeat.get( - repeat_idx - ) - # log.debug( - # "[EVALUATOR] [PLAN]", - # run_id=run_id, - # scenario_id=scenario.id, - # step_key=annotation_step_key, - # repeat_idx=repeat_idx, - # is_split=is_split, - # has_application_context=bool(application_context), - # has_root_span=bool( - # application_context - # and application_context.get("root_span") - # ), - # ) - if not application_context or not application_context.get( - "root_span" - ): - scenario_has_errors += 1 - scenario_status = EvaluationStatus.ERRORS - evaluator_results_create.append( - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=annotation_step_key, - repeat_idx=repeat_idx, - status=EvaluationStatus.FAILURE, - error={ - "message": "Evaluator skipped because invocation trace is missing." - }, - ) - ) - continue - - application_trace = application_context["trace"] - application_root_span = application_context["root_span"] - application_root_span_data = ( - application_root_span.model_dump( - mode="json", - exclude_none=True, - ) - .get("attributes", {}) - .get("ag", {}) - .get("data", {}) - ) - links = { - invocation_step_key: { - "trace_id": application_context["trace_id"], - "span_id": application_context["span_id"], - } - } - workflow_service_request = WorkflowServiceRequest( - version="2025.07.14", - flags=flags, - interface=interface, - configuration=configuration, - data=WorkflowServiceRequestData( - revision=_revision, - parameters=parameters, - testcase=testcase.model_dump(mode="json"), - inputs=testcase.data, - trace=application_trace.model_dump( - mode="json", - exclude_none=True, - ) - if application_trace - else None, - outputs=application_root_span_data.get("outputs"), - ), - references=base_references, - links=links, - ) - evaluator_hash_id = make_hash( - references=base_references, - links=links, - ) - cached_evaluator_trace = None - if is_cached and evaluator_hash_id: - cached_matches = await fetch_traces_by_hash( - tracing_service, - project_id, - hash_id=evaluator_hash_id, - limit=1, - ) - reusable_match = select_traces_for_reuse( - traces=cached_matches, - required_count=1, - ) - if reusable_match: - cached_evaluator_trace = reusable_match[0] - - if cached_evaluator_trace and cached_evaluator_trace.trace_id: - evaluator_results_create.append( - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=annotation_step_key, - repeat_idx=repeat_idx, - status=EvaluationStatus.SUCCESS, - trace_id=str(cached_evaluator_trace.trace_id), - ) - ) - continue - - # log.debug( - # "[EVALUATOR] [INVOKE]", - # run_id=run_id, - # scenario_id=scenario.id, - # step_key=annotation_step_key, - # repeat_idx=repeat_idx, - # trace_links=links, - # ) - workflows_service_response = ( - await workflows_service.invoke_workflow( - project_id=project_id, - user_id=user_id, - request=workflow_service_request, - annotate=True, - ) - ) - - result_trace_id = workflows_service_response.trace_id - result_error = None - result_status = EvaluationStatus.SUCCESS - has_error = workflows_service_response.status.code != 200 - if has_error: - result_status = EvaluationStatus.FAILURE - result_error = workflows_service_response.status.model_dump( - mode="json", - exclude_none=True, - ) - scenario_has_errors += 1 - scenario_status = EvaluationStatus.ERRORS - elif result_trace_id: - fetched_evaluator_trace = await fetch_trace( - tracing_service=tracing_service, - project_id=project_id, - trace_id=result_trace_id, - ) - if not fetched_evaluator_trace: - result_status = EvaluationStatus.FAILURE - result_error = { - "message": "Evaluator trace missing after invocation." - } - scenario_has_errors += 1 - scenario_status = EvaluationStatus.ERRORS - else: - result_status = EvaluationStatus.FAILURE - result_error = {"message": "Evaluator trace_id is missing."} - scenario_has_errors += 1 - scenario_status = EvaluationStatus.ERRORS - - evaluator_results_create.append( - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=annotation_step_key, - repeat_idx=repeat_idx, - status=result_status, - trace_id=result_trace_id, - error=result_error, - ) - ) - - created_annotation_results = await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - results=evaluator_results_create, - ) - # log.debug( - # "[EVALUATOR] [RESULTS]", - # run_id=run_id, - # scenario_id=scenario.id, - # step_key=annotation_step_key, - # created=len(created_annotation_results), - # expected=len(repeat_indices), - # ) - - if len(created_annotation_results) != len(repeat_indices): - raise ValueError( - f"Failed to create evaluation results for scenario with id {scenario.id}!" - ) - - final_scenario_status = ( - EvaluationStatus.PENDING - if scenario_status == EvaluationStatus.SUCCESS and scenario_has_pending - else scenario_status - ) - - if final_scenario_status == EvaluationStatus.ERRORS: - run_has_errors += 1 - - scenario_edit = EvaluationScenarioEdit( - id=scenario.id, - tags=scenario.tags, - meta=scenario.meta, - status=final_scenario_status, - ) - - scenario = await evaluations_service.edit_scenario( - project_id=project_id, - user_id=user_id, - # - scenario=scenario_edit, - ) - - if not scenario: - raise ValueError( - f"Failed to edit evaluation scenario with id {scenario.id}!" - ) - - if scenario_status != EvaluationStatus.FAILURE: - try: - metrics = await evaluations_service.refresh_metrics( - project_id=project_id, - user_id=user_id, - # - metrics=EvaluationMetricsRefresh( - run_id=run_id, - scenario_id=scenario.id, - ), - ) - - if not metrics: - log.warning( - f"Refreshing metrics failed for {run_id} | {scenario.id}" - ) - - except Exception: - log.warning( - f"Refreshing metrics failed for {run_id} | {scenario.id}", - exc_info=True, - ) - # ---------------------------------------------------------------------- - - if run_status != EvaluationStatus.FAILURE: - if run_has_errors: - run_status = EvaluationStatus.ERRORS - elif run_has_pending: - run_status = EvaluationStatus.RUNNING - else: - run_status = EvaluationStatus.SUCCESS - - except Exception as e: # pylint: disable=broad-exception-caught - log.error( - f"An error occurred during evaluation: {e}", - exc_info=True, - ) - - run_status = EvaluationStatus.FAILURE - - if not run: - log.info("[FAIL] ", run_id=run_id, project_id=project_id, user_id=user_id) - return - - if run_status != EvaluationStatus.FAILURE: - try: - metrics = await evaluations_service.refresh_metrics( - project_id=project_id, - user_id=user_id, - # - metrics=EvaluationMetricsRefresh( - run_id=run_id, - ), - ) - - if not metrics: - log.warning(f"Refreshing metrics failed for {run_id}") - - run_status = EvaluationStatus.FAILURE - - except Exception: # pylint: disable=broad-exception-caught - log.warning(f"Refreshing metrics failed for {run_id}", exc_info=True) - - run_status = EvaluationStatus.FAILURE - - # edit evaluation run status ----------------------------------------------- - run_edit = EvaluationRunEdit( - id=run_id, - # - name=run.name, - description=run.description, - # - tags=run.tags, - meta=run.meta, - # - status=run_status, - flags=run.flags, - # - data=run.data, - ) - - await evaluations_service.edit_run( - project_id=project_id, - user_id=user_id, - # - run=run_edit, - ) - - # edit meters to avoid counting failed evaluations -------------------------- - if run_status == EvaluationStatus.FAILURE and project is not None: - if is_ee(): - await check_entitlements( - organization_id=project.organization_id, - key=Counter.EVALUATIONS, - delta=-1, - ) - - log.info("[DONE] ", run_id=run_id, project_id=project_id, user_id=user_id) - - return - - -async def evaluate_batch_invocation( - *, - project_id: UUID, - user_id: UUID, - # - run_id: UUID, - # - tracing_service: TracingService, - testsets_service: TestsetsService, - applications_service: ApplicationsService, - evaluations_service: EvaluationsService, -): - """ - Run batch invocation over a testset without evaluator steps. - - This loop creates scenarios and input/invocation results, but does not - invoke evaluator workflows and does not refresh evaluation metrics. - """ - run = None - run_status = EvaluationStatus.SUCCESS - - try: - # ---------------------------------------------------------------------- - log.info( - "[SCOPE] ", run_id=run_id, project_id=project_id, user_id=user_id - ) - # ---------------------------------------------------------------------- - - # fetch project -------------------------------------------------------- - project = await get_project_by_id( - project_id=str(project_id), - ) - # ---------------------------------------------------------------------- - - # fetch run ------------------------------------------------------------ - run = await evaluations_service.fetch_run( - project_id=project_id, - run_id=run_id, - ) - - if not run: - raise ValueError(f"Evaluation run with id {run_id} not found!") - if not run.data or not run.data.steps: - raise ValueError(f"Evaluation run with id {run_id} has no steps!") - repeats = run.data.repeats or 1 - repeat_indices = build_repeat_indices(repeats) - is_cached = bool(run.flags.is_cached) if run.flags else False - application_required_count = required_traces_for_step( - repeats=repeats, - is_split=False, - step_kind="application", - has_evaluator_steps=False, - ) - - steps = run.data.steps - input_steps = [step for step in steps if step.type == "input"] - invocation_steps = [step for step in steps if step.type == "invocation"] - annotation_steps = [step for step in steps if step.type == "annotation"] - - if annotation_steps: - raise ValueError( - f"Evaluation run with id {run_id} contains annotation steps; " - "use evaluate_batch_testset instead." - ) - if not input_steps or len(invocation_steps) != 1: - raise ValueError( - f"Evaluation run with id {run_id} must have at least one input and exactly one invocation step." - ) - - invocation_step_key = invocation_steps[0].key - invocation_refs = invocation_steps[0].references or {} - - application_revision_ref = invocation_refs.get("application_revision") - if not application_revision_ref or not isinstance( - application_revision_ref.id, UUID - ): - raise ValueError( - f"Evaluation run with id {run_id} missing invocation.application_revision reference." - ) - # ---------------------------------------------------------------------- - - input_specs = await _resolve_testset_input_specs( - project_id=project_id, - input_steps=input_steps, - testsets_service=testsets_service, - ) - scenario_specs = [ - { - "input_step_key": input_spec["step_key"], - "testset": input_spec["testset"], - "testset_revision": input_spec["testset_revision"], - "testcase": testcase, - "testcase_data": testcase_data, - } - for input_spec in input_specs - for testcase, testcase_data in zip( - input_spec["testcases"], - input_spec["testcases_data"], - ) - ] - nof_scenarios = len(scenario_specs) - # ---------------------------------------------------------------------- - - # fetch application ---------------------------------------------------- - application_revision = await applications_service.fetch_application_revision( - project_id=project_id, - application_revision_ref=application_revision_ref, - ) - if not application_revision: - raise ValueError( - f"Application revision with id {application_revision_ref.id} not found!" - ) - - application_variant = await applications_service.fetch_application_variant( - project_id=project_id, - application_variant_ref=Reference( - id=application_revision.application_variant_id - ), - ) - if not application_variant: - raise ValueError( - f"Application variant with id {application_revision.application_variant_id} not found!" - ) - - application = await applications_service.fetch_application( - project_id=project_id, - application_ref=Reference(id=application_variant.application_id), - ) - if not application: - raise ValueError( - f"Application with id {application_variant.application_id} not found!" - ) - - uri = _resolve_runtime_uri(revision_data=application_revision.data) - if not uri: - raise ValueError( - f"No deployment URI found for revision {application_revision_ref.id}!" - ) - - # create scenarios ----------------------------------------------------- - scenarios = await evaluations_service.create_scenarios( - project_id=project_id, - user_id=user_id, - scenarios=[ - EvaluationScenarioCreate( - run_id=run_id, - status=EvaluationStatus.RUNNING, - ) - for _ in range(nof_scenarios) - ], - ) - if len(scenarios) != nof_scenarios: - raise ValueError(f"Failed to create evaluation scenarios for run {run_id}!") - # ---------------------------------------------------------------------- - - # create input results ------------------------------------------------- - input_results = await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - results=[ - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=scenario_specs[idx]["input_step_key"], - repeat_idx=repeat_idx, - status=EvaluationStatus.SUCCESS, - testcase_id=scenario_specs[idx]["testcase"].id, - ) - for idx, scenario in enumerate(scenarios) - for repeat_idx in repeat_indices - ], - ) - if len(input_results) != nof_scenarios * len(repeat_indices): - raise ValueError(f"Failed to create input results for run {run_id}!") - # ---------------------------------------------------------------------- - - # resolve cache / invoke application ----------------------------------- - run_config = { - "batch_size": 10, - "max_retries": 3, - "retry_delay": 3, - "delay_between_batches": 5, - } - scenario_invocations: Dict[tuple[int, int], Dict[str, Any]] = {} - for idx, scenario in enumerate(scenarios): - scenario_spec = scenario_specs[idx] - testcase = scenario_spec["testcase"] - testcase_data = scenario_spec["testcase_data"] - testset = scenario_spec["testset"] - testset_revision = scenario_spec["testset_revision"] - references = { - "testcase": {"id": str(testcase.id)}, - "testset": {"id": str(testset.id)}, - "testset_variant": {"id": str(testset_revision.variant_id)}, - "testset_revision": {"id": str(testset_revision.id)}, - "application": {"id": str(application.id)}, - "application_variant": {"id": str(application_variant.id)}, - "application_revision": {"id": str(application_revision.id)}, - } - hash_id = make_hash(references=references, links=None) - cached_traces = [] - if is_cached and hash_id: - cached_traces = await fetch_traces_by_hash( - tracing_service, - project_id, - hash_id=hash_id, - limit=application_required_count, - ) - reusable_traces = select_traces_for_reuse( - traces=cached_traces, - required_count=application_required_count, - ) - for repeat_idx, reusable_trace in zip(repeat_indices, reusable_traces): - scenario_invocations[(idx, repeat_idx)] = { - "status": EvaluationStatus.SUCCESS, - "trace_id": ( - str(reusable_trace.trace_id) - if reusable_trace and reusable_trace.trace_id - else None - ), - "error": None, - } - - missing_repeat_indices = repeat_indices[len(reusable_traces) :] - if missing_repeat_indices: - invocations = await llm_apps_service.batch_invoke( - project_id=str(project_id), - user_id=str(user_id), - testset_data=[ - testcase_data for _ in range(len(missing_repeat_indices)) - ], # type: ignore[arg-type] - revision=application_revision, - uri=uri, - rate_limit_config=run_config, - application_id=str(application.id), - references=references, - scenarios=[ - scenario.model_dump( - mode="json", - exclude_none=True, - ) - for _ in range(len(missing_repeat_indices)) - ], - ) - if len(invocations) != len(missing_repeat_indices): - raise ValueError( - f"Unexpected batch invocation count for scenario {scenario.id}!" - ) - for repeat_idx, invocation in zip(missing_repeat_indices, invocations): - invocation_error = ( - invocation.result.error.model_dump(mode="json") - if invocation.result and invocation.result.error - else None - ) - scenario_invocations[(idx, repeat_idx)] = { - "status": ( - EvaluationStatus.FAILURE - if invocation_error - else EvaluationStatus.SUCCESS - ), - "trace_id": invocation.trace_id, - "error": invocation_error, - } - # ---------------------------------------------------------------------- - - # create invocation results + finalize scenarios ------------------------ - run_has_errors = 0 - invocation_results = await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - results=[ - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=invocation_step_key, - repeat_idx=repeat_idx, - status=( - scenario_invocations.get((idx, repeat_idx), {}).get("status") - or EvaluationStatus.FAILURE - ), - trace_id=scenario_invocations.get((idx, repeat_idx), {}).get( - "trace_id" - ), - error=scenario_invocations.get((idx, repeat_idx), {}).get("error"), - ) - for idx, scenario in enumerate(scenarios) - for repeat_idx in repeat_indices - ], - ) - if len(invocation_results) != nof_scenarios * len(repeat_indices): - raise ValueError(f"Failed to create invocation results for run {run_id}!") - - for idx, scenario in enumerate(scenarios): - scenario_status = ( - EvaluationStatus.SUCCESS - if all( - scenario_invocations.get((idx, repeat_idx), {}).get("status") - == EvaluationStatus.SUCCESS - for repeat_idx in repeat_indices - ) - else EvaluationStatus.ERRORS - ) - if not all( - scenario_invocations.get((idx, repeat_idx), {}).get("status") - == EvaluationStatus.SUCCESS - for repeat_idx in repeat_indices - ): - run_has_errors += 1 - - edited_scenario = await evaluations_service.edit_scenario( - project_id=project_id, - user_id=user_id, - scenario=EvaluationScenarioEdit( - id=scenario.id, - tags=scenario.tags, - meta=scenario.meta, - status=scenario_status, - ), - ) - if not edited_scenario: - raise ValueError( - f"Failed to edit evaluation scenario with id {scenario.id}!" - ) - - if run_has_errors: - run_status = EvaluationStatus.ERRORS - # ---------------------------------------------------------------------- - - except Exception as e: # pylint: disable=broad-exception-caught - log.error( - f"An error occurred during batch invocation: {e}", - exc_info=True, - ) - run_status = EvaluationStatus.FAILURE - - if not run: - log.info("[FAIL] ", run_id=run_id, project_id=project_id, user_id=user_id) - return - - await evaluations_service.edit_run( - project_id=project_id, - user_id=user_id, - run=EvaluationRunEdit( - id=run_id, - name=run.name, - description=run.description, - tags=run.tags, - meta=run.meta, - status=run_status, - flags=run.flags, - data=run.data, - ), - ) - - if run_status == EvaluationStatus.FAILURE and is_ee(): - await check_entitlements( - organization_id=project.organization_id, # type: ignore[attr-defined] - key=Counter.EVALUATIONS, - delta=-1, - ) - - log.info("[DONE] ", run_id=run_id, project_id=project_id, user_id=user_id) - return - - -async def _evaluate_batch_items( - *, - project_id: UUID, - user_id: UUID, - # - run_id: UUID, - # - testcase_ids: Optional[List[UUID]] = None, - trace_ids: Optional[List[str]] = None, - input_step_key: Optional[str] = None, - # - tracing_service: Optional[TracingService] = None, - testcases_service: Optional[TestcasesService] = None, - workflows_service: WorkflowsService, - evaluations_service: EvaluationsService, -): - request = Request( - scope={ - "type": "http", - "http_version": "1.1", - "scheme": "http", - } - ) - request.state.project_id = str(project_id) - request.state.user_id = str(user_id) - - run: Optional[EvaluationRun] = None - scenarios = [] - run_status = EvaluationStatus.SUCCESS - - try: - run = await evaluations_service.fetch_run( - project_id=project_id, - run_id=run_id, - ) - if not run: - raise ValueError(f"Evaluation run with id {run_id} not found!") - if not run.flags or not run.flags.is_queue: - raise ValueError( - f"Evaluation run with id {run_id} is not configured for ad-hoc batching!" - ) - if not run.data or not run.data.steps: - raise ValueError(f"Evaluation run with id {run_id} has no data steps!") - repeats = run.data.repeats or 1 - repeat_indices = build_repeat_indices(repeats) - is_cached = bool(run.flags.is_cached) - - testcase_ids = testcase_ids or [] - trace_ids = trace_ids or [] - if not testcase_ids and not trace_ids: - raise ValueError( - f"Evaluation run with id {run_id} has no testcase_ids or trace_ids!" - ) - if trace_ids and tracing_service is None: - raise ValueError("tracing_service is required for trace batches") - if testcase_ids and testcases_service is None: - raise ValueError("testcases_service is required for testcase batches") - - steps = run.data.steps - input_steps = [step for step in steps if step.type == "input"] - invocation_steps = [step for step in steps if step.type == "invocation"] - annotation_steps = [step for step in steps if step.type == "annotation"] - - if input_step_key is not None: - matching_input_step = next( - (step for step in input_steps if step.key == input_step_key), - None, - ) - if matching_input_step is None: - raise ValueError( - f"Evaluation run with id {run_id} has no input step '{input_step_key}'!" - ) - else: - input_step_key = input_steps[0].key if input_steps else None - invocation_step_key = invocation_steps[0].key if invocation_steps else None - evaluator_references = { - step.key: step.references or {} for step in annotation_steps - } - evaluator_revisions: Dict[str, Any] = {} - for annotation_step_key, annotation_refs in evaluator_references.items(): - evaluator_revision_ref = annotation_refs.get("evaluator_revision") - evaluator_revisions[annotation_step_key] = ( - await workflows_service.fetch_workflow_revision( - project_id=project_id, - workflow_revision_ref=evaluator_revision_ref, - ) - if evaluator_revision_ref - else None - ) - - testcases = ( - await testcases_service.fetch_testcases( - project_id=project_id, - testcase_ids=testcase_ids, - ) - if testcase_ids - else [] - ) - testcases_by_id = { - testcase.id: testcase for testcase in testcases if testcase.id - } - - scenario_items = [] - for testcase_id in testcase_ids: - testcase = testcases_by_id.get(testcase_id) - scenario_items.append( - dict( - kind="testcase", - testcase=testcase, - testcase_id=testcase_id, - trace_id=None, - ) - ) - for trace_id in trace_ids: - scenario_items.append( - dict( - kind="trace", - testcase=None, - testcase_id=None, - trace_id=trace_id, - ) - ) - - scenarios = await evaluations_service.create_scenarios( - project_id=project_id, - user_id=user_id, - scenarios=[ - EvaluationScenarioCreate( - run_id=run_id, - status=EvaluationStatus.RUNNING, - ) - for _ in scenario_items - ], - ) - if len(scenarios) != len(scenario_items): - raise ValueError(f"Failed to create scenarios for run {run_id}") - - run_has_errors = False - run_has_pending = False - - for idx, scenario in enumerate(scenarios): - scenario_status = EvaluationStatus.SUCCESS - scenario_has_pending = False - scenario_item = scenario_items[idx] - - source_testcase = scenario_item["testcase"] - source_testcase_id = scenario_item["testcase_id"] - source_trace_id = scenario_item["trace_id"] - - _trace = None - inputs = None - outputs = None - query_span_id = None - - if source_testcase_id and source_testcase is None: - run_has_errors = True - scenario_status = EvaluationStatus.ERRORS - await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - results=[ - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=step.key, - repeat_idx=repeat_idx, - status=EvaluationStatus.ERRORS, - testcase_id=source_testcase_id, - error={ - "message": f"Testcase {source_testcase_id} not found." - }, - ) - for step in annotation_steps - for repeat_idx in repeat_indices - ], - ) - - if source_testcase_id and source_testcase and input_step_key: - input_results = await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - results=[ - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=input_step_key, - repeat_idx=repeat_idx, - status=EvaluationStatus.SUCCESS, - testcase_id=source_testcase_id, - ) - for repeat_idx in repeat_indices - ], - ) - if len(input_results) != len(repeat_indices): - raise ValueError( - f"Failed to create input result for scenario {scenario.id}" - ) - - if source_testcase and source_testcase.data: - inputs = source_testcase.data - - if source_trace_id: - trace = await fetch_trace( - project_id=project_id, - trace_id=source_trace_id, - tracing_service=tracing_service, - ) - if not trace or not isinstance(trace.spans, dict): - scenario_status = EvaluationStatus.ERRORS - run_has_errors = True - else: - root_span = list(trace.spans.values())[0] - if isinstance(root_span, list): - scenario_status = EvaluationStatus.ERRORS - run_has_errors = True - else: - query_span_id = root_span.span_id - _trace = trace.model_dump(mode="json", exclude_none=True) - _root_span = root_span.model_dump( - mode="json", exclude_none=True - ) - - root_span_attributes: dict = _root_span.get("attributes") or {} - root_span_ag: dict = root_span_attributes.get("ag") or {} - root_span_ag_data: dict = root_span_ag.get("data") or {} - outputs = root_span_ag_data.get("outputs") - if not inputs: - inputs = root_span_ag_data.get("inputs") - - if ( - source_trace_id - and input_step_key - and scenario_status == EvaluationStatus.SUCCESS - ): - input_results = await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - results=[ - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=input_step_key, - repeat_idx=repeat_idx, - status=EvaluationStatus.SUCCESS, - trace_id=source_trace_id, - ) - for repeat_idx in repeat_indices - ], - ) - if len(input_results) != len(repeat_indices): - raise ValueError( - f"Failed to create trace input result for scenario {scenario.id}" - ) - - if ( - source_trace_id - and invocation_step_key - and scenario_status == EvaluationStatus.SUCCESS - ): - invocation_results = await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - results=[ - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=invocation_step_key, - repeat_idx=repeat_idx, - status=EvaluationStatus.SUCCESS, - trace_id=source_trace_id, - ) - for repeat_idx in repeat_indices - ], - ) - if len(invocation_results) != len(repeat_indices): - raise ValueError( - f"Failed to create invocation result for scenario {scenario.id}" - ) - - if scenario_status == EvaluationStatus.SUCCESS: - for annotation_step in annotation_steps: - annotation_step_key = annotation_step.key - if annotation_step.origin in {"human", "custom"}: - scenario_has_pending = True - run_has_pending = True - # Human/custom steps are not auto-invoked here. - # Results are created later by the annotator via the annotation submission flow. - continue - - evaluator_revision = evaluator_revisions.get(annotation_step_key) - if not evaluator_revision: - run_has_errors = True - scenario_status = EvaluationStatus.ERRORS - continue - - _revision = evaluator_revision.model_dump( - mode="json", - exclude_none=True, - ) - interface = ( - dict( - uri=evaluator_revision.data.uri, - url=evaluator_revision.data.url, - headers=evaluator_revision.data.headers, - schemas=evaluator_revision.data.schemas, - ) - if evaluator_revision.data - else dict() - ) - configuration = ( - dict( - script=evaluator_revision.data.script, - parameters=evaluator_revision.data.parameters, - ) - if evaluator_revision.data - else dict() - ) - parameters = configuration.get("parameters") - flags = ( - evaluator_revision.flags.model_dump( - mode="json", - exclude_none=True, - exclude_unset=True, - ) - if evaluator_revision.flags - else None - ) - - links: Dict[str, Any] = {} - source_step_key = invocation_step_key or input_step_key - if source_step_key and source_trace_id and query_span_id: - links[source_step_key] = dict( - trace_id=source_trace_id, - span_id=query_span_id, - ) - - workflow_service_request = WorkflowServiceRequest( - version="2025.07.14", - flags=flags, - interface=interface, - configuration=configuration, - data=WorkflowServiceRequestData( - revision=_revision, - parameters=parameters, - testcase=( - source_testcase.model_dump( - mode="json", exclude_none=True - ) - if source_testcase - else None - ), - inputs=inputs, - trace=_trace, - outputs=outputs, - ), - references=evaluator_references.get(annotation_step_key, {}), - links=links, - ) - hash_references: Dict[str, Any] = { - **(evaluator_references.get(annotation_step_key, {}) or {}) - } - if source_testcase_id: - hash_references["testcase"] = {"id": str(source_testcase_id)} - - hash_id = make_hash( - references=hash_references, - links=links, - ) - cached_traces = [] - if is_cached and hash_id and tracing_service is not None: - cached_traces = await fetch_traces_by_hash( - tracing_service, - project_id, - hash_id=hash_id, - limit=len(repeat_indices), - ) - - reusable_traces = select_traces_for_reuse( - traces=cached_traces, - required_count=len(repeat_indices), - ) - _ = plan_missing_traces( - required_count=len(repeat_indices), - reusable_count=len(reusable_traces), - ) - - results_payload = [ - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=annotation_step_key, - repeat_idx=repeat_idx, - status=EvaluationStatus.SUCCESS, - testcase_id=source_testcase_id, - trace_id=str(reusable_trace.trace_id), - ) - for repeat_idx, reusable_trace in zip( - repeat_indices, - reusable_traces, - ) - if reusable_trace and reusable_trace.trace_id - ] - - for repeat_idx in repeat_indices[len(reusable_traces) :]: - workflows_service_response = ( - await workflows_service.invoke_workflow( - project_id=project_id, - user_id=user_id, - request=workflow_service_request, - annotate=True, - ) - ) - - has_error = workflows_service_response.status.code != 200 - result_trace_id = workflows_service_response.trace_id - result_error = None - result_status = EvaluationStatus.SUCCESS - if has_error: - result_status = EvaluationStatus.FAILURE - result_error = workflows_service_response.status.model_dump( - mode="json", - exclude_none=True, - ) - scenario_status = EvaluationStatus.ERRORS - run_has_errors = True - - results_payload.append( - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario.id, - step_key=annotation_step_key, - repeat_idx=repeat_idx, - status=result_status, - testcase_id=source_testcase_id, - trace_id=result_trace_id, - error=result_error, - ) - ) - - step_results = await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - results=results_payload, - ) - if len(step_results) != len(repeat_indices): - raise ValueError( - f"Failed to create annotation results for scenario {scenario.id}" - ) - - final_scenario_status = ( - EvaluationStatus.PENDING - if scenario_status == EvaluationStatus.SUCCESS and scenario_has_pending - else scenario_status - ) - await evaluations_service.edit_scenario( - project_id=project_id, - user_id=user_id, - scenario=EvaluationScenarioEdit( - id=scenario.id, - tags=scenario.tags, - meta=scenario.meta, - status=final_scenario_status, - ), - ) - - try: - await evaluations_service.refresh_metrics( - project_id=project_id, - user_id=user_id, - metrics=EvaluationMetricsRefresh( - run_id=run_id, - scenario_id=scenario.id, - ), - ) - except Exception: # pylint: disable=broad-exception-caught - log.warning( - f"Refreshing metrics failed for {run_id} | {scenario.id}", - exc_info=True, - ) - - if run_has_errors: - run_status = EvaluationStatus.ERRORS - elif run_has_pending: - run_status = EvaluationStatus.RUNNING - else: - run_status = EvaluationStatus.SUCCESS - - except Exception as e: # pylint: disable=broad-exception-caught - log.error( - f"An error occurred during batch items evaluation: {e}", - exc_info=True, - ) - run_status = EvaluationStatus.FAILURE - - if not run: - return - - # For ad-hoc/queue runs (multiple independent batches writing to the same run), - # re-fetch the current stored status and never downgrade it to a less severe state. - # This prevents a later successful batch from overwriting ERRORS from an earlier one. - if run.flags and run.flags.is_queue and run_status != EvaluationStatus.FAILURE: - _severity = { - EvaluationStatus.FAILURE: 4, - EvaluationStatus.ERRORS: 3, - EvaluationStatus.RUNNING: 2, - EvaluationStatus.SUCCESS: 1, - EvaluationStatus.PENDING: 0, - } - current_run = await evaluations_service.fetch_run( - project_id=project_id, - run_id=run_id, - ) - if current_run and current_run.status: - stored_severity = _severity.get(current_run.status, 0) - if stored_severity > _severity.get(run_status, 0): - run_status = current_run.status - - try: - if run_status != EvaluationStatus.FAILURE: - await evaluations_service.refresh_metrics( - project_id=project_id, - user_id=user_id, - metrics=EvaluationMetricsRefresh(run_id=run_id), - ) - except Exception: # pylint: disable=broad-exception-caught - log.warning(f"Refreshing metrics failed for {run_id}", exc_info=True) - run_status = EvaluationStatus.FAILURE - - await evaluations_service.edit_run( - project_id=project_id, - user_id=user_id, - run=EvaluationRunEdit( - id=run_id, - name=run.name, - description=run.description, - tags=run.tags, - meta=run.meta, - status=run_status, - flags=run.flags, - data=run.data, - ), - ) - - log.info("[DONE] ", run_id=run_id, project_id=project_id, user_id=user_id) - - return - - -async def evaluate_batch_traces( - *, - project_id: UUID, - user_id: UUID, - # - run_id: UUID, - trace_ids: List[str], - input_step_key: Optional[str] = None, - # - tracing_service: TracingService, - workflows_service: WorkflowsService, - evaluations_service: EvaluationsService, -): - return await _evaluate_batch_items( - project_id=project_id, - user_id=user_id, - run_id=run_id, - # - trace_ids=trace_ids, - input_step_key=input_step_key, - tracing_service=tracing_service, - workflows_service=workflows_service, - evaluations_service=evaluations_service, - ) - - -async def evaluate_batch_testcases( - *, - project_id: UUID, - user_id: UUID, - # - run_id: UUID, - testcase_ids: List[UUID], - input_step_key: Optional[str] = None, - # - tracing_service: TracingService, - testcases_service: TestcasesService, - workflows_service: WorkflowsService, - evaluations_service: EvaluationsService, -): - return await _evaluate_batch_items( - project_id=project_id, - user_id=user_id, - run_id=run_id, - # - testcase_ids=testcase_ids, - input_step_key=input_step_key, - tracing_service=tracing_service, - testcases_service=testcases_service, - workflows_service=workflows_service, - evaluations_service=evaluations_service, - ) diff --git a/api/oss/src/core/evaluations/tasks/live.py b/api/oss/src/core/evaluations/tasks/live.py deleted file mode 100644 index 7aa19e1f4e..0000000000 --- a/api/oss/src/core/evaluations/tasks/live.py +++ /dev/null @@ -1,859 +0,0 @@ -from typing import Dict, Any, Optional -from uuid import UUID -from datetime import datetime, timezone - -from oss.src.utils.logging import get_module_logger - -from oss.src.dbs.postgres.queries.dbes import ( - QueryArtifactDBE, - QueryVariantDBE, - QueryRevisionDBE, -) -from oss.src.dbs.postgres.testcases.dbes import ( - TestcaseBlobDBE, -) -from oss.src.dbs.postgres.testsets.dbes import ( - TestsetArtifactDBE, - TestsetVariantDBE, - TestsetRevisionDBE, -) -from oss.src.dbs.postgres.workflows.dbes import ( - WorkflowArtifactDBE, - WorkflowVariantDBE, - WorkflowRevisionDBE, -) - -from oss.src.dbs.postgres.tracing.dao import TracingDAO -from oss.src.dbs.postgres.blobs.dao import BlobsDAO -from oss.src.dbs.postgres.git.dao import GitDAO -from oss.src.dbs.postgres.evaluations.dao import EvaluationsDAO - -from oss.src.core.tracing.service import TracingService -from oss.src.core.queries.service import QueriesService -from oss.src.core.testcases.service import TestcasesService -from oss.src.core.testsets.service import TestsetsService -from oss.src.core.testsets.service import SimpleTestsetsService -from oss.src.core.workflows.service import WorkflowsService -from oss.src.core.evaluators.service import EvaluatorsService -from oss.src.core.evaluators.service import SimpleEvaluatorsService -from oss.src.core.evaluations.service import EvaluationsService -from oss.src.core.annotations.service import AnnotationsService - - -from oss.src.core.evaluations.types import ( - EvaluationMetricsRefresh, - EvaluationStatus, - EvaluationScenarioCreate, - EvaluationScenarioEdit, - EvaluationResultCreate, -) -from oss.src.core.shared.dtos import ( - Reference, - Traces, -) -from oss.src.core.tracing.dtos import ( - Filtering, - Windowing, - Formatting, - Format, - Focus, - TracingQuery, - LogicalOperator, -) -from oss.src.core.workflows.dtos import ( - WorkflowServiceRequestData, - WorkflowServiceRequest, -) -from oss.src.core.queries.dtos import ( - QueryRevisionData, - QueryRevision, -) -from oss.src.core.evaluators.dtos import ( - EvaluatorRevisionData, - EvaluatorRevision, -) - -from oss.src.core.evaluations.utils import ( - build_repeat_indices, - fetch_trace, - fetch_traces_by_hash, - make_hash, - select_traces_for_reuse, -) - - -log = get_module_logger(__name__) - - -# DBS -------------------------------------------------------------------------- - -tracing_dao = TracingDAO() - -testcases_dao = BlobsDAO( - BlobDBE=TestcaseBlobDBE, -) - -queries_dao = GitDAO( - ArtifactDBE=QueryArtifactDBE, - VariantDBE=QueryVariantDBE, - RevisionDBE=QueryRevisionDBE, -) - -testsets_dao = GitDAO( - ArtifactDBE=TestsetArtifactDBE, - VariantDBE=TestsetVariantDBE, - RevisionDBE=TestsetRevisionDBE, -) - -workflows_dao = GitDAO( - ArtifactDBE=WorkflowArtifactDBE, - VariantDBE=WorkflowVariantDBE, - RevisionDBE=WorkflowRevisionDBE, -) - -evaluations_dao = EvaluationsDAO() - -# CORE ------------------------------------------------------------------------- - -tracing_service = TracingService( - tracing_dao=tracing_dao, -) - -queries_service = QueriesService( - queries_dao=queries_dao, -) - -testcases_service = TestcasesService( - testcases_dao=testcases_dao, -) - -testsets_service = TestsetsService( - testsets_dao=testsets_dao, - testcases_service=testcases_service, -) - -simple_testsets_service = SimpleTestsetsService( - testsets_service=testsets_service, -) - -workflows_service = WorkflowsService( - workflows_dao=workflows_dao, -) - -evaluators_service = EvaluatorsService( - workflows_service=workflows_service, -) - -simple_evaluators_service = SimpleEvaluatorsService( - evaluators_service=evaluators_service, -) - -evaluations_service = EvaluationsService( - evaluations_dao=evaluations_dao, - tracing_service=tracing_service, - queries_service=queries_service, - testsets_service=testsets_service, - evaluators_service=evaluators_service, - # -) - -# APIS ------------------------------------------------------------------------- - -annotations_service = AnnotationsService( - tracing_service=tracing_service, - evaluators_service=evaluators_service, - simple_evaluators_service=simple_evaluators_service, -) - -# ------------------------------------------------------------------------------ - - -async def evaluate_live_query( - project_id: UUID, - user_id: UUID, - # - run_id: UUID, - # - newest: Optional[datetime] = None, - oldest: Optional[datetime] = None, - # - use_windowing: bool = False, -): - # count in minutes - timestamp = oldest or datetime.now(timezone.utc) - interval: Optional[int] = None - if newest and oldest: - interval = int((newest - oldest).total_seconds() / 60) - - try: - # ---------------------------------------------------------------------- - log.info( - "[SCOPE] ", - run_id=run_id, - project_id=project_id, - user_id=user_id, - ) - - log.info( - "[RANGE] ", - run_id=run_id, - timestamp=timestamp, - interval=interval, - newest=newest, - oldest=oldest, - use_windowing=use_windowing, - ) - # ---------------------------------------------------------------------- - - # fetch evaluation run ------------------------------------------------- - run = await evaluations_service.fetch_run( - project_id=project_id, - run_id=run_id, - ) - - if not run: - raise ValueError(f"Evaluation run with id {run_id} not found!") - - if not run.data: - raise ValueError(f"Evaluation run with id {run_id} has no data!") - - if not run.data.steps: - raise ValueError(f"Evaluation run with id {run_id} has no steps!") - - steps = run.data.steps - repeats = run.data.repeats or 1 - repeat_indices = build_repeat_indices(repeats) - is_cached = bool(getattr(run.flags, "is_cached", False)) - - input_steps = { - step.key: step - for step in steps - if step.type == "input" # -------- - } - invocation_steps = { - step.key: step for step in steps if step.type == "invocation" - } - annotation_steps = { - step.key: step for step in steps if step.type == "annotation" - } - - input_steps_keys = list(input_steps.keys()) - invocation_steps_keys = list(invocation_steps.keys()) # noqa: F841 - annotation_steps_keys = list(annotation_steps.keys()) - - nof_annotations = len(annotation_steps_keys) - # ---------------------------------------------------------------------- - - # initialize query variables ------------------------------------------- - query_revision_refs: Dict[str, Reference] = dict() - # - query_revisions: Dict[str, QueryRevision] = dict() - query_references: Dict[str, Dict[str, Reference]] = dict() - # - query_traces: Dict[str, Traces] = dict() - # ---------------------------------------------------------------------- - - # initialize evaluator variables --------------------------------------- - evaluator_revision_refs: Dict[str, Reference] = dict() - # - evaluator_revisions: Dict[str, EvaluatorRevision] = dict() - evaluator_references: Dict[str, Dict[str, Reference]] = dict() - # ---------------------------------------------------------------------- - - # get query steps references ------------------------------------------- - for input_step_key in input_steps_keys: - query_refs = input_steps[input_step_key].references - query_revision_ref = query_refs.get("query_revision") - - if query_revision_ref: - query_revision_refs[input_step_key] = query_revision_ref - - # ---------------------------------------------------------------------- - - # get evaluator steps references --------------------------------------- - for annotation_step_key in annotation_steps_keys: - evaluator_refs = annotation_steps[annotation_step_key].references - evaluator_revision_ref = evaluator_refs.get("evaluator_revision") - - if evaluator_revision_ref: - evaluator_revision_refs[annotation_step_key] = evaluator_revision_ref - # ---------------------------------------------------------------------- - - # fetch query revisions ------------------------------------------------ - for ( - query_step_key, - query_revision_ref, - ) in query_revision_refs.items(): - query_revision = await queries_service.fetch_query_revision( - project_id=project_id, - # - query_revision_ref=query_revision_ref, - ) - - if query_revision and not query_revision.data: - query_revision.data = QueryRevisionData() - - if ( - not query_revision - or not query_revision.id - or not query_revision.slug - or not query_revision.data - ): - log.warn( - f"Query revision with ref {query_revision_ref.model_dump(mode='json')} not found!" - ) - continue - - query_step = input_steps[query_step_key] - - query_revisions[query_step_key] = query_revision - query_references[query_step_key] = query_step.references - # ---------------------------------------------------------------------- - - # fetch evaluator revisions -------------------------------------------- - for ( - evaluator_step_key, - evaluator_revision_ref, - ) in evaluator_revision_refs.items(): - evaluator_revision = await evaluators_service.fetch_evaluator_revision( - project_id=project_id, - # - evaluator_revision_ref=evaluator_revision_ref, - ) - - if evaluator_revision and not evaluator_revision.data: - evaluator_revision.data = EvaluatorRevisionData() - - if ( - not evaluator_revision - or not evaluator_revision.id - or not evaluator_revision.slug - or not evaluator_revision.data - ): - log.warn( - f"Evaluator revision with ref {evaluator_revision_ref.model_dump(mode='json')} not found!" - ) - continue - - evaluator_step = annotation_steps[evaluator_step_key] - - evaluator_revisions[evaluator_step_key] = evaluator_revision - evaluator_references[evaluator_step_key] = evaluator_step.references - # ---------------------------------------------------------------------- - - # run query revisions -------------------------------------------------- - for query_step_key, query_revision in query_revisions.items(): - formatting = Formatting( - focus=Focus.TRACE, - format=Format.AGENTA, - ) - filtering = Filtering( - operator=LogicalOperator.AND, - conditions=list(), - ) - windowing = Windowing( - oldest=oldest, - newest=newest, - next=None, - limit=None, - order="ascending", - interval=None, - rate=None, - ) - - if query_revision.data: - if query_revision.data.filtering: - filtering = query_revision.data.filtering - - if query_revision.data.windowing: - query_windowing = query_revision.data.windowing - - if use_windowing: - windowing = Windowing( - oldest=query_windowing.oldest, - newest=query_windowing.newest, - limit=query_windowing.limit, - order=query_windowing.order, - rate=query_windowing.rate, - # next= - # interval= - ) - else: - windowing.rate = query_windowing.rate - - query = TracingQuery( - formatting=formatting, - filtering=filtering, - windowing=windowing, - ) - - query_traces_result = await tracing_service.query_traces( - project_id=project_id, - query=TracingQuery( - formatting=query.formatting, - filtering=query.filtering, - windowing=query.windowing, - ), - ) - - nof_traces = len(query_traces_result) - - log.info( - "[TRACES] ", - run_id=run_id, - count=nof_traces, - ) - - query_traces[query_step_key] = query_traces_result or [] - # ---------------------------------------------------------------------- - - total_traces = sum(len(traces) for traces in query_traces.values()) - if total_traces == 0: - return - - # run online evaluation ------------------------------------------------ - any_results_created = False - for query_step_key in query_traces.keys(): - query_step_traces = [ - trace - for trace in query_traces[query_step_key] - if trace and trace.trace_id - ] - if not query_step_traces: - continue - - # create scenarios ------------------------------------------------- - - nof_traces = len(query_step_traces) - - scenarios_create = [ - EvaluationScenarioCreate( - run_id=run_id, - timestamp=timestamp, - interval=interval, - # - status=EvaluationStatus.RUNNING, - ) - for _ in range(nof_traces) - ] - - scenarios = await evaluations_service.create_scenarios( - project_id=project_id, - user_id=user_id, - # - scenarios=scenarios_create, - ) - - if len(scenarios) != nof_traces: - log.error( - "[LIVE] Could not create evaluation scenarios", - run_id=run_id, - ) - continue - # ------------------------------------------------------------------ - - # create query steps ----------------------------------------------- - query_trace_ids = [ - trace.trace_id for trace in query_step_traces if trace.trace_id - ] - scenario_ids = [scenario.id for scenario in scenarios if scenario.id] - - results_create = [ - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario_id, - step_key=query_step_key, - repeat_idx=repeat_idx, - timestamp=timestamp, - interval=interval, - # - status=EvaluationStatus.SUCCESS, - # - trace_id=query_trace_id, - ) - for scenario_id, query_trace_id in zip(scenario_ids, query_trace_ids) - for repeat_idx in repeat_indices - ] - - results = await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - # - results=results_create, - ) - - if len(results) != nof_traces * len(repeat_indices): - raise ValueError( - f"Failed to create evaluation results for run {run_id}!" - ) - # ------------------------------------------------------------------ - - scenario_has_errors: Dict[int, int] = dict() - scenario_status: Dict[int, EvaluationStatus] = dict() - scenario_has_pending: Dict[int, bool] = dict() - - # iterate over query traces ---------------------------------------- - for idx, trace in enumerate(query_step_traces): - scenario_results_created = False - scenario_has_errors[idx] = 0 - scenario_status[idx] = EvaluationStatus.SUCCESS - scenario_has_pending[idx] = False - - scenario = scenarios[idx] - scenario_id = scenario_ids[idx] - query_trace_id = query_trace_ids[idx] - - if not isinstance(trace.spans, dict): - log.warn( - f"Trace with id {query_trace_id} has no root spans", - run_id=run_id, - ) - scenario_has_errors[idx] += 1 - scenario_status[idx] = EvaluationStatus.ERRORS - continue - - root_span = list(trace.spans.values())[0] - - if isinstance(root_span, list): - log.warn( - f"More than one root span for trace with id {query_trace_id}", - run_id=run_id, - ) - scenario_has_errors[idx] += 1 - scenario_status[idx] = EvaluationStatus.ERRORS - continue - - query_span_id = root_span.span_id - - log.info( - "[TRACE] ", - run_id=run_id, - trace_id=query_trace_id, - ) - - # run evaluator revisions -------------------------------------- - for jdx in range(nof_annotations): - annotation_step_key = annotation_steps_keys[jdx] - annotation_step = annotation_steps[annotation_step_key] - - if annotation_step.origin in {"human", "custom"}: - scenario_has_pending[idx] = True - continue - - step_status = EvaluationStatus.SUCCESS - - references: Dict[str, Any] = { - **evaluator_references[annotation_step_key], - } - links: Dict[str, Any] = { - query_step_key: dict( - trace_id=query_trace_id, - span_id=query_span_id, - ) - } - - # invoke annotation workflow ------------------------------- - evaluator_revision = evaluator_revisions[annotation_step_key] - - if not evaluator_revision: - log.error( - f"Evaluator revision for {annotation_step_key} not found!" - ) - scenario_has_errors[idx] += 1 - # run_has_errors += 1 - step_status = EvaluationStatus.FAILURE - scenario_status[idx] = EvaluationStatus.ERRORS - # run_status = EvaluationStatus.ERRORS - continue - - _revision = evaluator_revision.model_dump( - mode="json", - exclude_none=True, - ) - interface = ( - dict( - uri=evaluator_revision.data.uri, - url=evaluator_revision.data.url, - headers=evaluator_revision.data.headers, - schemas=evaluator_revision.data.schemas, - ) - if evaluator_revision.data - else dict() - ) - configuration = ( - dict( - script=evaluator_revision.data.script, - parameters=evaluator_revision.data.parameters, - ) - if evaluator_revision.data - else dict() - ) - parameters = configuration.get("parameters") - - _testcase = None - inputs = None - - _trace: Optional[dict] = ( - trace.model_dump( - mode="json", - exclude_none=True, - ) - if trace - else None - ) - - _root_span = root_span.model_dump(mode="json", exclude_none=True) - testcase_data = None - - root_span_attributes: dict = _root_span.get("attributes") or {} - root_span_attributes_ag: dict = root_span_attributes.get("ag") or {} - root_span_attributes_ag_data: dict = ( - root_span_attributes_ag.get("data") or {} - ) - root_span_attributes_ag_data_outputs = ( - root_span_attributes_ag_data.get("outputs") - ) - root_span_attributes_ag_data_inputs = ( - root_span_attributes_ag_data.get("inputs") - ) - - outputs = root_span_attributes_ag_data_outputs - inputs = testcase_data or root_span_attributes_ag_data_inputs - - workflow_service_request_data = WorkflowServiceRequestData( - revision=_revision, - parameters=parameters, - # - testcase=_testcase, - inputs=inputs, - # - trace=_trace, - outputs=outputs, - ) - - flags = ( - evaluator_revision.flags.model_dump( - mode="json", - exclude_none=True, - exclude_unset=True, - ) - if evaluator_revision.flags - else None - ) - - workflow_service_request = WorkflowServiceRequest( - version="2025.07.14", - # - flags=flags, - # - interface=interface, - configuration=configuration, - # - data=workflow_service_request_data, - # - references=references, - links=links, - ) - hash_id = make_hash(references=references, links=links) - cached_traces = [] - if is_cached and hash_id: - cached_traces = await fetch_traces_by_hash( - tracing_service, - project_id, - hash_id=hash_id, - limit=len(repeat_indices), - ) - - reusable_traces = select_traces_for_reuse( - traces=cached_traces, - required_count=len(repeat_indices), - ) - results_create = [ - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario_id, - step_key=annotation_step_key, - repeat_idx=repeat_idx, - # - timestamp=timestamp, - interval=interval, - # - status=EvaluationStatus.SUCCESS, - # - trace_id=str(reusable_trace.trace_id), - ) - for repeat_idx, reusable_trace in zip( - repeat_indices, - reusable_traces, - ) - if reusable_trace and reusable_trace.trace_id - ] - - for repeat_idx in repeat_indices[len(reusable_traces) :]: - log.info( - "Invoking evaluator... ", - scenario_id=scenario.id, - repeat_idx=repeat_idx, - trace_id=query_trace_id, - uri=interface.get("uri"), - ) - workflows_service_response = ( - await workflows_service.invoke_workflow( - project_id=project_id, - user_id=user_id, - # - request=workflow_service_request, - # - annotate=True, - ) - ) - log.info( - "Invoked evaluator ", - scenario_id=scenario.id, - repeat_idx=repeat_idx, - trace_id=workflows_service_response.trace_id, - ) - - trace_id = workflows_service_response.trace_id - error = None - step_status = EvaluationStatus.SUCCESS - has_error = workflows_service_response.status.code != 200 - - if has_error: - log.warn( - "There is an error in evaluator %s for query %s.", - annotation_step_key, - query_trace_id, - ) - step_status = EvaluationStatus.FAILURE - scenario_has_errors[idx] += 1 - scenario_status[idx] = EvaluationStatus.ERRORS - error = workflows_service_response.status.model_dump( - mode="json", - exclude_none=True, - ) - else: - annotation = workflows_service_response - trace_id = annotation.trace_id - - if not annotation.trace_id: - log.warn("annotation trace_id is missing.") - scenario_has_errors[idx] += 1 - scenario_status[idx] = EvaluationStatus.ERRORS - continue - - fetched_trace = await fetch_trace( - tracing_service=tracing_service, - project_id=project_id, - trace_id=annotation.trace_id, - ) - - if fetched_trace: - log.info( - "Trace found ", - scenario_id=scenario.id, - step_key=annotation_step_key, - trace_id=annotation.trace_id, - ) - else: - log.warn( - "Trace missing", - scenario_id=scenario.id, - step_key=annotation_step_key, - trace_id=annotation.trace_id, - ) - scenario_has_errors[idx] += 1 - scenario_status[idx] = EvaluationStatus.ERRORS - continue - - results_create.append( - EvaluationResultCreate( - run_id=run_id, - scenario_id=scenario_id, - step_key=annotation_step_key, - repeat_idx=repeat_idx, - # - timestamp=timestamp, - interval=interval, - # - status=step_status, - # - trace_id=trace_id, - error=error, - ) - ) - - results = await evaluations_service.create_results( - project_id=project_id, - user_id=user_id, - # - results=results_create, - ) - - if len(results) != len(repeat_indices): - raise ValueError( - f"Failed to create evaluation results for scenario with id {scenario.id}!" - ) - scenario_results_created = True - any_results_created = True - # -------------------------------------------------------------- - - scenario_edit = EvaluationScenarioEdit( - id=scenario.id, - tags=scenario.tags, - meta=scenario.meta, - status=( - EvaluationStatus.PENDING - if ( - scenario_status[idx] == EvaluationStatus.SUCCESS - and scenario_has_pending[idx] - ) - else scenario_status[idx] - ), - ) - - scenario = await evaluations_service.edit_scenario( - project_id=project_id, - user_id=user_id, - # - scenario=scenario_edit, - ) - - if not scenario or not scenario.id: - log.error( - f"Failed to update evaluation scenario with id {scenario_id}!", - run_id=run_id, - ) - - if scenario_results_created: - await evaluations_service.refresh_metrics( - project_id=project_id, - user_id=user_id, - # - metrics=EvaluationMetricsRefresh( - run_id=run_id, - scenario_id=scenario_id, - ), - ) - # ------------------------------------------------------------------ - - if any_results_created: - await evaluations_service.refresh_metrics( - project_id=project_id, - user_id=user_id, - # - metrics=EvaluationMetricsRefresh( - run_id=run_id, - timestamp=timestamp, - interval=interval, - ), - ) - except Exception as e: # pylint: disable=broad-exception-caught - log.error(e, exc_info=True) - - log.info( - "[DONE] ", - run_id=run_id, - ) - - return diff --git a/api/oss/src/core/evaluations/tasks/query.py b/api/oss/src/core/evaluations/tasks/query.py new file mode 100644 index 0000000000..6ca7cac0ca --- /dev/null +++ b/api/oss/src/core/evaluations/tasks/query.py @@ -0,0 +1,238 @@ +from typing import Optional +from uuid import UUID +from datetime import datetime, timezone + +from oss.src.utils.logging import get_module_logger + +from oss.src.dbs.postgres.queries.dbes import ( + QueryArtifactDBE, + QueryVariantDBE, + QueryRevisionDBE, +) +from oss.src.dbs.postgres.testcases.dbes import ( + TestcaseBlobDBE, +) +from oss.src.dbs.postgres.testsets.dbes import ( + TestsetArtifactDBE, + TestsetVariantDBE, + TestsetRevisionDBE, +) +from oss.src.dbs.postgres.workflows.dbes import ( + WorkflowArtifactDBE, + WorkflowVariantDBE, + WorkflowRevisionDBE, +) + +from oss.src.dbs.postgres.tracing.dao import TracingDAO +from oss.src.dbs.postgres.blobs.dao import BlobsDAO +from oss.src.dbs.postgres.git.dao import GitDAO +from oss.src.dbs.postgres.evaluations.dao import EvaluationsDAO + +from oss.src.core.tracing.service import TracingService +from oss.src.core.queries.service import QueriesService +from oss.src.core.testcases.service import TestcasesService +from oss.src.core.testsets.service import TestsetsService +from oss.src.core.testsets.service import SimpleTestsetsService +from oss.src.core.workflows.service import WorkflowsService +from oss.src.core.evaluators.service import EvaluatorsService +from oss.src.core.evaluators.service import SimpleEvaluatorsService +from oss.src.core.evaluations.service import EvaluationsService +from oss.src.core.annotations.service import AnnotationsService + + +from oss.src.core.evaluations.runtime.sources import resolve_query_source_items +from oss.src.core.evaluations.tasks.source_slice import process_evaluation_source_slice + + +log = get_module_logger(__name__) + + +# DBS -------------------------------------------------------------------------- + +tracing_dao = TracingDAO() + +testcases_dao = BlobsDAO( + BlobDBE=TestcaseBlobDBE, +) + +queries_dao = GitDAO( + ArtifactDBE=QueryArtifactDBE, + VariantDBE=QueryVariantDBE, + RevisionDBE=QueryRevisionDBE, +) + +testsets_dao = GitDAO( + ArtifactDBE=TestsetArtifactDBE, + VariantDBE=TestsetVariantDBE, + RevisionDBE=TestsetRevisionDBE, +) + +workflows_dao = GitDAO( + ArtifactDBE=WorkflowArtifactDBE, + VariantDBE=WorkflowVariantDBE, + RevisionDBE=WorkflowRevisionDBE, +) + +evaluations_dao = EvaluationsDAO() + +# CORE ------------------------------------------------------------------------- + +tracing_service = TracingService( + tracing_dao=tracing_dao, +) + +queries_service = QueriesService( + queries_dao=queries_dao, +) + +testcases_service = TestcasesService( + testcases_dao=testcases_dao, +) + +testsets_service = TestsetsService( + testsets_dao=testsets_dao, + testcases_service=testcases_service, +) + +simple_testsets_service = SimpleTestsetsService( + testsets_service=testsets_service, +) + +workflows_service = WorkflowsService( + workflows_dao=workflows_dao, +) + +evaluators_service = EvaluatorsService( + workflows_service=workflows_service, +) + +simple_evaluators_service = SimpleEvaluatorsService( + evaluators_service=evaluators_service, +) + +evaluations_service = EvaluationsService( + evaluations_dao=evaluations_dao, + tracing_service=tracing_service, + queries_service=queries_service, + testsets_service=testsets_service, + evaluators_service=evaluators_service, + # +) + +# APIS ------------------------------------------------------------------------- + +annotations_service = AnnotationsService( + tracing_service=tracing_service, + evaluators_service=evaluators_service, + simple_evaluators_service=simple_evaluators_service, +) + +# ------------------------------------------------------------------------------ + + +async def process_query_source_run( + project_id: UUID, + user_id: UUID, + # + run_id: UUID, + # + newest: Optional[datetime] = None, + oldest: Optional[datetime] = None, + # + use_windowing: bool = False, +): + # Backward-compatible live-query worker shell. Scheduling/windowing stays + # here for now; query source resolution and evaluator execution are routed + # through the unified runtime. + # count in minutes + timestamp = oldest or datetime.now(timezone.utc) + interval: Optional[int] = None + if newest and oldest: + interval = int((newest - oldest).total_seconds() / 60) + + try: + # ---------------------------------------------------------------------- + log.info( + "[SCOPE] ", + run_id=run_id, + project_id=project_id, + user_id=user_id, + ) + + log.info( + "[RANGE] ", + run_id=run_id, + timestamp=timestamp, + interval=interval, + newest=newest, + oldest=oldest, + use_windowing=use_windowing, + ) + # ---------------------------------------------------------------------- + + # fetch evaluation run ------------------------------------------------- + run = await evaluations_service.fetch_run( + project_id=project_id, + run_id=run_id, + ) + + if not run: + raise ValueError(f"Evaluation run with id {run_id} not found!") + + if not run.data: + raise ValueError(f"Evaluation run with id {run_id} has no data!") + + if not run.data.steps: + raise ValueError(f"Evaluation run with id {run_id} has no steps!") + + source_items_by_step = await resolve_query_source_items( + project_id=project_id, + run=run, + queries_service=queries_service, + tracing_service=tracing_service, + newest=newest, + oldest=oldest, + use_windowing=use_windowing, + ) + for query_step_key, source_items in source_items_by_step.items(): + log.info( + "[TRACES] ", + run_id=run_id, + count=len(source_items), + ) + # ---------------------------------------------------------------------- + + total_traces = sum( + len(source_items) for source_items in source_items_by_step.values() + ) + if total_traces == 0: + return + + for query_step_key, source_items in source_items_by_step.items(): + if not source_items: + continue + + await process_evaluation_source_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + source_items=source_items, + input_step_key=query_step_key, + timestamp=timestamp, + interval=interval, + require_queue=False, + update_run_status=False, + refresh_metrics_without_auto_results=False, + tracing_service=tracing_service, + workflows_service=workflows_service, + evaluations_service=evaluations_service, + ) + except Exception as e: # pylint: disable=broad-exception-caught + log.error(e, exc_info=True) + + log.info( + "[DONE] ", + run_id=run_id, + ) + + return diff --git a/api/oss/src/core/evaluations/tasks/run.py b/api/oss/src/core/evaluations/tasks/run.py new file mode 100644 index 0000000000..8b81103d17 --- /dev/null +++ b/api/oss/src/core/evaluations/tasks/run.py @@ -0,0 +1,155 @@ +from datetime import datetime +from typing import Literal, Optional +from uuid import UUID + +from oss.src.core.applications.service import ApplicationsService +from oss.src.core.evaluations.runtime.topology import classify_run_topology +from oss.src.core.evaluations.service import EvaluationsService +from oss.src.core.evaluations.tasks.source_slice import ( + process_evaluation_source_slice, + process_testset_source_run, +) +from oss.src.core.evaluations.tasks.query import process_query_source_run +from oss.src.core.evaluators.service import SimpleEvaluatorsService +from oss.src.core.queries.service import QueriesService +from oss.src.core.testcases.service import TestcasesService +from oss.src.core.testsets.service import TestsetsService +from oss.src.core.tracing.service import TracingService +from oss.src.core.workflows.service import WorkflowsService +from oss.src.utils.logging import get_module_logger + +log = get_module_logger(__name__) + +EvaluationSliceSource = Literal["traces", "testcases"] + + +async def process_evaluation_run( + *, + project_id: UUID, + user_id: UUID, + run_id: UUID, + newest: Optional[datetime] = None, + oldest: Optional[datetime] = None, + tracing_service: TracingService, + testsets_service: TestsetsService, + queries_service: QueriesService, + workflows_service: WorkflowsService, + applications_service: ApplicationsService, + evaluations_service: EvaluationsService, + simple_evaluators_service: SimpleEvaluatorsService, +) -> bool: + run = await evaluations_service.fetch_run( + project_id=project_id, + run_id=run_id, + ) + if not run: + log.warning("[EVAL] [process-run] run not found", run_id=run_id) + return False + + topology = classify_run_topology(run) + + if topology.dispatch == "live_query": + await process_query_source_run( + project_id=project_id, + user_id=user_id, + run_id=run_id, + newest=newest, + oldest=oldest, + use_windowing=False, + ) + return True + + if topology.dispatch == "batch_query": + await process_query_source_run( + project_id=project_id, + user_id=user_id, + run_id=run_id, + newest=None, + oldest=None, + use_windowing=True, + ) + return True + + if topology.dispatch == "batch_testset": + await process_testset_source_run( + project_id=project_id, + user_id=user_id, + run_id=run_id, + tracing_service=tracing_service, + testsets_service=testsets_service, + workflows_service=workflows_service, + applications_service=applications_service, + evaluations_service=evaluations_service, + ) + return True + + if topology.dispatch == "batch_invocation": + await process_testset_source_run( + project_id=project_id, + user_id=user_id, + run_id=run_id, + tracing_service=tracing_service, + testsets_service=testsets_service, + workflows_service=workflows_service, + applications_service=applications_service, + evaluations_service=evaluations_service, + ) + return True + + log.warning( + "[EVAL] [process-run] unsupported run topology", + run_id=run_id, + topology=topology.label, + topology_status=topology.status, + reason=topology.reason, + ) + return False + + +async def process_evaluation_slice( + *, + project_id: UUID, + user_id: UUID, + run_id: UUID, + source_kind: EvaluationSliceSource, + trace_ids: Optional[list[str]] = None, + testcase_ids: Optional[list[UUID]] = None, + input_step_key: Optional[str] = None, + tracing_service: TracingService, + testcases_service: TestcasesService, + workflows_service: WorkflowsService, + evaluations_service: EvaluationsService, +) -> bool: + if source_kind == "traces": + await process_evaluation_source_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + trace_ids=trace_ids or [], + input_step_key=input_step_key, + tracing_service=tracing_service, + workflows_service=workflows_service, + evaluations_service=evaluations_service, + ) + return True + + if source_kind == "testcases": + await process_evaluation_source_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + testcase_ids=testcase_ids or [], + input_step_key=input_step_key, + tracing_service=tracing_service, + testcases_service=testcases_service, + workflows_service=workflows_service, + evaluations_service=evaluations_service, + ) + return True + + log.warning( + "[EVAL] [process-slice] unsupported source kind", + run_id=run_id, + source_kind=source_kind, + ) + return False diff --git a/api/oss/src/core/evaluations/tasks/source_slice.py b/api/oss/src/core/evaluations/tasks/source_slice.py new file mode 100644 index 0000000000..17af6aab9f --- /dev/null +++ b/api/oss/src/core/evaluations/tasks/source_slice.py @@ -0,0 +1,566 @@ +from typing import Dict, List, Optional, Any + +from uuid import UUID + +from agenta.sdk.evaluations.runtime.models import ( + EvaluationStep as SdkEvaluationStep, + ResolvedSourceItem as SdkResolvedSourceItem, +) +from agenta.sdk.evaluations.runtime.source_slice import ( + process_evaluation_source_slice as sdk_process_evaluation_source_slice, +) + +from oss.src.utils.logging import get_module_logger + +from oss.src.core.testcases.service import TestcasesService +from oss.src.core.testsets.service import TestsetsService +from oss.src.core.applications.service import ApplicationsService +from oss.src.core.workflows.service import WorkflowsService +from oss.src.core.evaluations.service import EvaluationsService + +from oss.src.core.tracing.service import TracingService + + +from oss.src.core.evaluations.types import ( + EvaluationStatus, + EvaluationRun, + EvaluationRunEdit, + EvaluationScenarioEdit, +) + +from oss.src.core.evaluations.utils import ( + effective_is_split, +) +from oss.src.core.evaluations.runtime.adapters import ( + BackendCachedRunner, + BackendMetricsRefresher, + BackendResultLogger, + BackendScenarioFactory, + BackendTraceLoader, + BackendWorkflowRunner, +) +from oss.src.core.evaluations.runtime.models import ResolvedSourceItem +from oss.src.core.evaluations.runtime.sources import ( + resolve_direct_source_items, + resolve_testset_input_specs, +) + + +log = get_module_logger(__name__) + + +async def _resolve_testset_input_specs( + *, + project_id: UUID, + input_steps: List[Any], + testsets_service: TestsetsService, +) -> List[Dict[str, Any]]: + return [ + { + "step_key": spec.step_key, + "testset": spec.testset, + "testset_revision": spec.testset_revision, + "testcases": spec.testcases, + "testcases_data": spec.testcases_data, + } + for spec in await resolve_testset_input_specs( + project_id=project_id, + input_steps=input_steps, + testsets_service=testsets_service, + ) + ] + + +async def process_testset_source_run( + *, + project_id: UUID, + user_id: UUID, + # + run_id: UUID, + # + tracing_service: TracingService, + testsets_service: TestsetsService, + workflows_service: WorkflowsService, + applications_service: ApplicationsService, + evaluations_service: EvaluationsService, +): + """Resolve testset rows, then process them through the unified source loop.""" + log.info( + "[WORKER] process_testset_source_run: start", + run_id=str(run_id), + project_id=str(project_id), + ) + + run = await evaluations_service.fetch_run( + project_id=project_id, + run_id=run_id, + ) + if not run: + raise ValueError(f"Evaluation run with id {run_id} not found!") + if not run.data or not run.data.steps: + raise ValueError(f"Evaluation run with id {run_id} has no data steps!") + + log.info( + "[WORKER] process_testset_source_run: run fetched", + run_id=str(run_id), + run_name=run.name, + run_status=str(run.status), + steps=[ + {"key": s.key, "type": s.type, "origin": s.origin} for s in run.data.steps + ], + repeats=run.data.repeats, + concurrency=run.data.concurrency.model_dump() if run.data.concurrency else None, + ) + + input_steps = [step for step in run.data.steps if step.type == "input"] + input_specs = await _resolve_testset_input_specs( + project_id=project_id, + input_steps=input_steps, + testsets_service=testsets_service, + ) + + log.info( + "[WORKER] process_testset_source_run: input specs resolved", + run_id=str(run_id), + input_specs=[ + { + "step_key": spec["step_key"], + "testset_id": str(spec["testset"].id), + "testset_revision_id": str(spec["testset_revision"].id), + "testcase_count": len(spec["testcases"]), + } + for spec in input_specs + ], + ) + + source_items = [ + ResolvedSourceItem( + kind="testcase", + step_key=input_spec["step_key"], + references={ + "testcase": {"id": str(testcase.id)}, + "testset": {"id": str(input_spec["testset"].id)}, + "testset_variant": { + "id": str(input_spec["testset_revision"].variant_id) + }, + "testset_revision": {"id": str(input_spec["testset_revision"].id)}, + }, + testcase_id=testcase.id, + testcase=testcase, + inputs=testcase_data, + ) + for input_spec in input_specs + for testcase, testcase_data in zip( + input_spec["testcases"], + input_spec["testcases_data"], + ) + ] + + return await process_evaluation_source_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + source_items=source_items, + require_queue=False, + update_run_status=True, + refresh_metrics_without_auto_results=True, + tracing_service=tracing_service, + workflows_service=workflows_service, + applications_service=applications_service, + evaluations_service=evaluations_service, + ) + + +async def process_evaluation_source_slice( + *, + project_id: UUID, + user_id: UUID, + run_id: UUID, + testcase_ids: Optional[List[UUID]] = None, + trace_ids: Optional[List[str]] = None, + source_items: Optional[List[ResolvedSourceItem]] = None, + input_step_key: Optional[str] = None, + timestamp: Optional[Any] = None, + interval: Optional[int] = None, + require_queue: bool = True, + update_run_status: bool = True, + refresh_metrics_without_auto_results: bool = True, + tracing_service: Optional[TracingService] = None, + testcases_service: Optional[TestcasesService] = None, + workflows_service: Optional[WorkflowsService] = None, + applications_service: Optional[ApplicationsService] = None, + evaluations_service: EvaluationsService, +): + """Resolve backend adapters, then delegate execution to the SDK runtime.""" + log.info( + "[WORKER] process_evaluation_source_slice: start", + run_id=str(run_id), + project_id=str(project_id), + source_items_count=len(source_items) if source_items else 0, + testcase_ids_count=len(testcase_ids) if testcase_ids else 0, + trace_ids_count=len(trace_ids) if trace_ids else 0, + require_queue=require_queue, + update_run_status=update_run_status, + ) + + run: Optional[EvaluationRun] = None + run_status = EvaluationStatus.SUCCESS + + try: + run = await evaluations_service.fetch_run( + project_id=project_id, + run_id=run_id, + ) + if not run: + raise ValueError(f"Evaluation run with id {run_id} not found!") + if require_queue and ( + not run.flags or not (run.flags.has_traces or run.flags.has_testcases) + ): + raise ValueError( + f"Evaluation run with id {run_id} is not configured for ad-hoc batching!" + ) + if not run.data or not run.data.steps: + raise ValueError(f"Evaluation run with id {run_id} has no data steps!") + + steps = run.data.steps + input_steps = [step for step in steps if step.type == "input"] + invocation_steps = [step for step in steps if step.type == "invocation"] + annotation_steps = [step for step in steps if step.type == "annotation"] + + log.info( + "[WORKER] process_evaluation_source_slice: run fetched", + run_id=str(run_id), + run_name=run.name, + run_status=str(run.status), + total_steps=len(steps), + input_steps=[ + { + "key": s.key, + "references": { + k: (v.id if hasattr(v, "id") else v) + for k, v in (s.references or {}).items() + }, + } + for s in input_steps + ], + invocation_steps=[ + { + "key": s.key, + "references": { + k: str(v.id) if hasattr(v, "id") else v + for k, v in (s.references or {}).items() + }, + } + for s in invocation_steps + ], + annotation_steps=[ + { + "key": s.key, + "origin": s.origin, + "references": { + k: str(v.id) if hasattr(v, "id") else v + for k, v in (s.references or {}).items() + }, + } + for s in annotation_steps + ], + repeats=run.data.repeats, + concurrency=run.data.concurrency.model_dump() + if run.data.concurrency + else None, + flags=run.flags.model_dump() if run.flags else None, + ) + + if len(invocation_steps) > 1: + raise ValueError( + f"Evaluation run with id {run_id} has more than one invocation step." + ) + + if input_step_key is not None and not any( + step.key == input_step_key for step in input_steps + ): + raise ValueError( + f"Evaluation run with id {run_id} has no input step '{input_step_key}'!" + ) + + testcase_ids = testcase_ids or [] + trace_ids = trace_ids or [] + source_items = source_items or [] + if not source_items and not testcase_ids and not trace_ids: + raise ValueError( + f"Evaluation run with id {run_id} has no source items, testcase_ids, or trace_ids!" + ) + if trace_ids and tracing_service is None: + raise ValueError("tracing_service is required for trace batches") + if testcase_ids and testcases_service is None: + raise ValueError("testcases_service is required for testcase batches") + + if not source_items: + source_items = await resolve_direct_source_items( + project_id=project_id, + testcase_ids=testcase_ids, + trace_ids=trace_ids, + testcases_service=testcases_service, + tracing_service=tracing_service, + ) + effective_input_step_key = ( + input_step_key + or ( + source_items[0].step_key + if source_items and source_items[0].step_key + else None + ) + or (input_steps[0].key if input_steps else "") + ) + sdk_source_items = [ + SdkResolvedSourceItem( + kind=source_item.kind, + step_key=source_item.step_key or effective_input_step_key, + references=source_item.references or {}, + trace_id=source_item.trace_id, + span_id=source_item.span_id, + testcase_id=source_item.testcase_id, + testcase=source_item.testcase, + trace=source_item.trace, + inputs=source_item.inputs + or getattr(source_item.testcase, "data", None), + outputs=source_item.outputs, + ) + for source_item in source_items + ] + + sdk_steps = [ + SdkEvaluationStep( + key=step.key, + type=step.type, + origin=step.origin, + references=step.references or {}, + inputs=[step_input.key for step_input in (step.inputs or [])], + ) + for step in steps + ] + + runners: Dict[str, Any] = {} + revisions: Dict[str, Any] = {} + + if invocation_steps: + if applications_service is None: + raise ValueError( + "applications_service is required for invocation steps" + ) + if workflows_service is None: + raise ValueError("workflows_service is required for invocation steps") + invocation_step = invocation_steps[0] + application_revision_ref = invocation_step.references.get( + "application_revision" + ) + if not application_revision_ref or not isinstance( + application_revision_ref.id, UUID + ): + raise ValueError( + f"Evaluation run with id {run_id} missing invocation.application_revision reference." + ) + application_revision = ( + await applications_service.fetch_application_revision( + project_id=project_id, + application_revision_ref=application_revision_ref, + ) + ) + if application_revision is None: + raise ValueError( + f"App revision with id {application_revision_ref.id} not found!" + ) + runners[invocation_step.key] = BackendCachedRunner( + runner=BackendWorkflowRunner( + project_id=project_id, + user_id=user_id, + workflows_service=workflows_service, + ), + tracing_service=tracing_service, + project_id=project_id, + enabled=bool(run.flags and run.flags.is_cached), + ) + revisions[invocation_step.key] = application_revision + + auto_annotation_steps = [ + step for step in annotation_steps if step.origin not in {"human", "custom"} + ] + if auto_annotation_steps and workflows_service is None: + raise ValueError("workflows_service is required for auto annotation steps") + for annotation_step in auto_annotation_steps: + evaluator_revision_ref = (annotation_step.references or {}).get( + "evaluator_revision" + ) + evaluator_revision = ( + await workflows_service.fetch_workflow_revision( # type: ignore[union-attr] + project_id=project_id, + workflow_revision_ref=evaluator_revision_ref, + ) + if evaluator_revision_ref + else None + ) + if evaluator_revision is None: + continue + runners[annotation_step.key] = BackendCachedRunner( + runner=BackendWorkflowRunner( + project_id=project_id, + user_id=user_id, + workflows_service=workflows_service, + ), + tracing_service=tracing_service, + project_id=project_id, + enabled=bool(run.flags and run.flags.is_cached), + ) + revisions[annotation_step.key] = evaluator_revision + + log.info( + "[WORKER] process_evaluation_source_slice: runners/revisions resolved", + run_id=str(run_id), + runner_keys=list(runners.keys()), + revision_keys=list(revisions.keys()), + sdk_source_items_count=len(sdk_source_items), + sdk_steps=[{"key": s.key, "type": s.type} for s in sdk_steps], + ) + + processed = await sdk_process_evaluation_source_slice( + run_id=run_id, + source_items=sdk_source_items, + steps=sdk_steps, + repeats=run.data.repeats, + create_scenario=BackendScenarioFactory( + project_id=project_id, + user_id=user_id, + timestamp=timestamp, + interval=interval, + evaluations_service=evaluations_service, + ), + result_logger=BackendResultLogger( + project_id=project_id, + user_id=user_id, + timestamp=timestamp, + interval=interval, + evaluations_service=evaluations_service, + ), + refresh_metrics=BackendMetricsRefresher( + project_id=project_id, + user_id=user_id, + timestamp=timestamp, + interval=interval, + evaluations_service=evaluations_service, + ), + runners=runners, + revisions=revisions, + trace_loader=( + BackendTraceLoader( + project_id=project_id, + tracing_service=tracing_service, + ) + if tracing_service is not None + else None + ), + is_split=effective_is_split( + is_split=bool(run.flags and run.flags.is_split), + has_application_steps=bool(invocation_steps), + has_evaluator_steps=bool(annotation_steps), + ), + log_pending=False, + refresh_metrics_without_auto_results=refresh_metrics_without_auto_results, + batch_size=run.data.concurrency.batch_size + if run.data.concurrency + else None, + max_retries=run.data.concurrency.max_retries + if run.data.concurrency + else None, + retry_delay=run.data.concurrency.retry_delay + if run.data.concurrency + else None, + ) + + log.info( + "[WORKER] process_evaluation_source_slice: SDK complete", + run_id=str(run_id), + processed_count=len(processed), + scenarios_with_errors=sum(1 for i in processed if i.has_errors), + scenarios_with_pending=sum(1 for i in processed if i.has_pending), + scenarios_with_auto_results=sum( + 1 for i in processed if i.auto_results_created + ), + result_step_keys=[list(i.results.keys()) for i in processed], + ) + + for item in processed: + scenario_status = ( + EvaluationStatus.ERRORS + if item.has_errors + else EvaluationStatus.PENDING + if item.has_pending + else EvaluationStatus.SUCCESS + ) + await evaluations_service.edit_scenario( + project_id=project_id, + user_id=user_id, + scenario=EvaluationScenarioEdit( + id=item.scenario.id, + tags=getattr(item.scenario, "tags", None), + meta=getattr(item.scenario, "meta", None), + status=scenario_status, + ), + ) + + if any(item.has_errors for item in processed): + run_status = EvaluationStatus.ERRORS + elif any(item.has_pending for item in processed): + run_status = EvaluationStatus.RUNNING + else: + run_status = EvaluationStatus.SUCCESS + + except Exception as e: # pylint: disable=broad-exception-caught + log.error( + f"An error occurred during source slice evaluation: {e}", + exc_info=True, + ) + run_status = EvaluationStatus.FAILURE + + if not run: + return + + if ( + update_run_status + and run.flags + and (run.flags.has_traces or run.flags.has_testcases) + and run_status != EvaluationStatus.FAILURE + ): + severity = { + EvaluationStatus.FAILURE: 4, + EvaluationStatus.ERRORS: 3, + EvaluationStatus.RUNNING: 2, + EvaluationStatus.SUCCESS: 1, + EvaluationStatus.PENDING: 0, + } + current_run = await evaluations_service.fetch_run( + project_id=project_id, + run_id=run_id, + ) + if current_run and current_run.status: + stored_severity = severity.get(current_run.status, 0) + if stored_severity > severity.get(run_status, 0): + run_status = current_run.status + + if update_run_status: + await evaluations_service.edit_run( + project_id=project_id, + user_id=user_id, + run=EvaluationRunEdit( + id=run_id, + name=run.name, + description=run.description, + tags=run.tags, + meta=run.meta, + status=run_status, + flags=run.flags, + data=run.data, + ), + ) + + log.info("[DONE] ", run_id=run_id, project_id=project_id, user_id=user_id) + return diff --git a/api/oss/src/core/evaluations/types.py b/api/oss/src/core/evaluations/types.py index b9167f13a3..f56cfb488b 100644 --- a/api/oss/src/core/evaluations/types.py +++ b/api/oss/src/core/evaluations/types.py @@ -82,12 +82,14 @@ class EvaluationRunFlags(BaseModel): is_live: bool = False # Indicates if the run has live queries is_active: bool = False # Indicates if the run is currently active is_closed: bool = False # Indicates if the run is modifiable - is_queue: bool = False # Indicates this run belongs to a simple annotation queue + is_queue: bool = False # Indicates active default queue + active human work is_cached: bool = False # Indicates the run should reuse traces by hash is_split: bool = False # Indicates repeats fan out at the application step # - has_queries: bool = False # Indicates if the run has queries - has_testsets: bool = False # Indicates if the run has testsets + has_queries: bool = False # Indicates if the run has query-backed inputs + has_testsets: bool = False # Indicates if the run has testset-backed inputs + has_traces: bool = False # Indicates if the run has direct trace inputs + has_testcases: bool = False # Indicates if the run has direct testcase inputs has_evaluators: bool = False # Indicates if the run has evaluators # has_custom: bool = False # Indicates if the run has custom evaluators @@ -100,13 +102,19 @@ class EvaluationRunQueryFlags(BaseModel): is_active: Optional[bool] = None # Indicates if the run is currently active is_closed: Optional[bool] = None # Indicates if the run is modifiable is_queue: Optional[bool] = ( - None # Indicates this run belongs to a simple annotation queue + None # Indicates active default queue + active human work ) is_cached: Optional[bool] = None # Indicates the run should reuse traces by hash is_split: Optional[bool] = None # Indicates repeats fan out at the application step # - has_queries: Optional[bool] = None # Indicates if the run has queries - has_testsets: Optional[bool] = None # Indicates if the run has testsets + has_queries: Optional[bool] = None # Indicates if the run has query-backed inputs + has_testsets: Optional[bool] = ( + None # Indicates if the run has testset-backed inputs + ) + has_traces: Optional[bool] = None # Indicates if the run has direct trace inputs + has_testcases: Optional[bool] = ( + None # Indicates if the run has direct testcase inputs + ) has_evaluators: Optional[bool] = None # Indicates if the run has evaluators # has_custom: Optional[bool] = None # Indicates if the run has custom evaluators @@ -141,9 +149,16 @@ class EvaluationRunDataMapping(BaseModel): step: EvaluationRunDataMappingStep +class EvaluationRunDataConcurrency(BaseModel): + batch_size: Optional[int] = None + max_retries: Optional[int] = None + retry_delay: Optional[float] = None + + class EvaluationRunData(BaseModel): steps: Optional[List[EvaluationRunDataStep]] = None repeats: Optional[int] = 1 + concurrency: Optional[EvaluationRunDataConcurrency] = None mappings: Optional[List[EvaluationRunDataMapping]] = None @field_validator("repeats") @@ -387,10 +402,12 @@ class EvaluationMetricsSpecsRefresh(BaseModel): class EvaluationQueueFlags(BaseModel): is_sequential: bool = False + is_default: bool = False class EvaluationQueueQueryFlags(BaseModel): is_sequential: Optional[bool] = None + is_default: Optional[bool] = None class EvaluationQueueData(BaseModel): @@ -466,6 +483,7 @@ class EvaluationQueueEdit(Identifier, Header, Metadata): class EvaluationQueueQuery(Header, Metadata): flags: Optional[EvaluationQueueQueryFlags] = None # type: ignore + include_archived: Optional[bool] = None user_id: Optional[UUID] = None user_ids: Optional[List[UUID]] = None @@ -500,6 +518,7 @@ class SimpleEvaluationData(BaseModel): evaluator_steps: Optional[Target] = None repeats: Optional[int] = None + concurrency: Optional[EvaluationRunDataConcurrency] = None class SimpleEvaluation(Version, Identifier, Lifecycle, Header, Metadata): @@ -534,6 +553,8 @@ class SimpleEvaluationQuery(Header, Metadata): class SimpleQueueKind(str, Enum): + QUERIES = "queries" + TESTSETS = "testsets" TRACES = "traces" TESTCASES = "testcases" @@ -632,6 +653,15 @@ def validate_sources(self): if not has_kind and not has_queries and not has_testsets: raise ValueError("simple queue requires kind, queries, or testsets") + if has_queries and self.kind not in (None, SimpleQueueKind.QUERIES): + raise ValueError("query-backed queues must use kind='queries'") + if has_testsets and self.kind not in (None, SimpleQueueKind.TESTSETS): + raise ValueError("testset-backed queues must use kind='testsets'") + if self.kind == SimpleQueueKind.QUERIES and not has_queries: + raise ValueError("kind='queries' requires query sources") + if self.kind == SimpleQueueKind.TESTSETS and not has_testsets: + raise ValueError("kind='testsets' requires testset sources") + return self diff --git a/api/oss/src/core/evaluations/utils.py b/api/oss/src/core/evaluations/utils.py index 8cdc68905b..1679e15494 100644 --- a/api/oss/src/core/evaluations/utils.py +++ b/api/oss/src/core/evaluations/utils.py @@ -370,11 +370,12 @@ def effective_is_split( *, is_split: bool, is_live: bool = False, - is_queue: bool = False, + has_traces: bool = False, + has_testcases: bool = False, has_application_steps: bool = False, has_evaluator_steps: bool = False, ) -> bool: - if is_live or is_queue: + if is_live or has_traces or has_testcases: return False if not has_application_steps or not has_evaluator_steps: return False diff --git a/api/oss/src/core/events/streaming.py b/api/oss/src/core/events/streaming.py index c8d6f2b0ba..90e19726b4 100644 --- a/api/oss/src/core/events/streaming.py +++ b/api/oss/src/core/events/streaming.py @@ -4,7 +4,6 @@ from orjson import dumps, loads from pydantic import BaseModel -from redis.asyncio import Redis try: from asyncpg.pgproto.pgproto import UUID as AsyncpgUUID @@ -12,7 +11,7 @@ AsyncpgUUID = None from oss.src.core.events.dtos import Event -from oss.src.utils.env import env +from oss.src.dbs.redis.shared.engine import get_streams_engine from oss.src.utils.logging import get_module_logger log = get_module_logger(__name__) @@ -24,16 +23,9 @@ def _orjson_default(obj): raise TypeError(f"Type is not JSON serializable: {type(obj)}") -_redis: Optional[Redis] = None - - -def _get_redis() -> Optional[Redis]: - global _redis - - if _redis is None and env.redis.uri_durable: - _redis = Redis.from_url(env.redis.uri_durable, decode_responses=False) - - return _redis +def _get_redis(): + engine = get_streams_engine() + return engine.get_redis() if engine else None class EventMessage(BaseModel): diff --git a/api/oss/src/core/tracing/streaming.py b/api/oss/src/core/tracing/streaming.py index eb41d27aa7..81a145aa77 100644 --- a/api/oss/src/core/tracing/streaming.py +++ b/api/oss/src/core/tracing/streaming.py @@ -1,29 +1,20 @@ import zlib -from typing import List, Optional +from typing import List from uuid import UUID from orjson import dumps, loads from pydantic import BaseModel -from redis.asyncio import Redis -from oss.src.utils.env import env +from oss.src.dbs.redis.shared.engine import get_streams_engine from oss.src.utils.logging import get_module_logger from oss.src.core.tracing.dtos import OTelFlatSpan log = get_module_logger(__name__) -_redis: Optional[Redis] = None -def _get_redis() -> Redis: - global _redis - - if _redis is None: - if not env.redis.uri_durable: - raise RuntimeError("REDIS_URI_DURABLE is required for tracing streams.") - _redis = Redis.from_url(env.redis.uri_durable, decode_responses=False) - - return _redis +def _get_redis(): + return get_streams_engine().get_redis() class SpanMessage(BaseModel): diff --git a/api/oss/src/dbs/postgres/blobs/dao.py b/api/oss/src/dbs/postgres/blobs/dao.py index a14bc1bc5b..527e5e9e8a 100644 --- a/api/oss/src/dbs/postgres/blobs/dao.py +++ b/api/oss/src/dbs/postgres/blobs/dao.py @@ -15,7 +15,10 @@ from oss.src.dbs.postgres.shared.utils import apply_windowing from oss.src.dbs.postgres.shared.exceptions import check_entity_creation_conflict -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) from oss.src.dbs.postgres.blobs.mappings import map_dbe_to_dto, map_dto_to_dbe @@ -30,8 +33,12 @@ def __init__( self, *, BlobDBE: Type[T], + engine: TransactionsEngine = None, ): self.BlobDBE = BlobDBE # pylint: disable=invalid-name + if engine is None: + engine = get_transactions_engine() + self.engine = engine # ─ blobs ────────────────────────────────────────────────────────────────── @@ -63,7 +70,7 @@ async def add_blob( ) try: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.BlobDBE).filter( self.BlobDBE.project_id == project_id, # type: ignore ) @@ -106,7 +113,7 @@ async def fetch_blob( # blob_id: UUID, ) -> Optional[Blob]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.BlobDBE).filter( self.BlobDBE.project_id == project_id, # type: ignore ) @@ -138,7 +145,7 @@ async def edit_blob( # blob_edit: BlobEdit, ) -> Optional[Blob]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.BlobDBE).filter( self.BlobDBE.project_id == project_id, # type: ignore ) @@ -179,7 +186,7 @@ async def remove_blob( # blob_id: UUID, ) -> Optional[Blob]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.BlobDBE).filter( self.BlobDBE.project_id == project_id, # type: ignore ) @@ -239,7 +246,7 @@ async def add_blobs( blob_ids = [blob.id for blob in blobs] try: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.BlobDBE).filter( self.BlobDBE.project_id == project_id, # type: ignore ) @@ -295,7 +302,7 @@ async def fetch_blobs( # blob_ids: List[UUID], ) -> List[Blob]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.BlobDBE).filter( self.BlobDBE.project_id == project_id, # type: ignore ) @@ -331,7 +338,7 @@ async def edit_blobs( # blob_edits: List[BlobEdit], ) -> List[Blob]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.BlobDBE).filter( self.BlobDBE.project_id == project_id, # type: ignore ) @@ -381,7 +388,7 @@ async def remove_blobs( # blob_ids: List[UUID], ) -> List[Blob]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.BlobDBE).filter( self.BlobDBE.project_id == project_id, # type: ignore ) @@ -422,7 +429,7 @@ async def query_blobs( # windowing: Optional[Windowing] = None, ) -> List[Blob]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.BlobDBE).filter( self.BlobDBE.project_id == project_id, # type: ignore ) diff --git a/api/oss/src/dbs/postgres/evaluations/dao.py b/api/oss/src/dbs/postgres/evaluations/dao.py index 7fd522e215..37cd6f20e7 100644 --- a/api/oss/src/dbs/postgres/evaluations/dao.py +++ b/api/oss/src/dbs/postgres/evaluations/dao.py @@ -47,7 +47,10 @@ from oss.src.dbs.postgres.shared.utils import apply_windowing from oss.src.dbs.postgres.shared.exceptions import check_entity_creation_conflict -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) from oss.src.dbs.postgres.evaluations.utils import ( create_run_references, edit_run_references, @@ -74,8 +77,10 @@ class EvaluationsDAO(EvaluationsDAOInterface): - def __init__(self): - pass + def __init__(self, engine: TransactionsEngine = None): + if engine is None: + engine = get_transactions_engine() + self.engine = engine # - EVALUATION RUN --------------------------------------------------------- @@ -113,7 +118,7 @@ async def create_run( run_dbe.data = _run.data.model_dump(mode="json") # type: ignore try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add(run_dbe) await session.commit() @@ -172,7 +177,7 @@ async def create_runs( run_dbes.append(run_dbe) try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add_all(run_dbes) await session.commit() @@ -200,7 +205,7 @@ async def fetch_run( # run_id: UUID, ) -> Optional[EvaluationRun]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE).filter( EvaluationRunDBE.project_id == project_id, ) @@ -233,7 +238,7 @@ async def fetch_runs( # run_ids: List[UUID], ) -> List[EvaluationRun]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE).filter( EvaluationRunDBE.project_id == project_id, ) @@ -267,7 +272,7 @@ async def edit_run( # run: EvaluationRunEdit, ) -> Optional[EvaluationRun]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE).filter( EvaluationRunDBE.project_id == project_id, ) @@ -332,7 +337,7 @@ async def edit_runs( ) -> List[EvaluationRun]: run_ids = [run.id for run in runs] - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE).filter( EvaluationRunDBE.project_id == project_id, ) @@ -406,7 +411,7 @@ async def delete_run( # run_id: UUID, ) -> Optional[UUID]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE).filter( EvaluationRunDBE.project_id == project_id, ) @@ -438,7 +443,7 @@ async def delete_runs( # run_ids: List[UUID], ) -> List[UUID]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE).filter( EvaluationRunDBE.project_id == project_id, ) @@ -478,7 +483,7 @@ async def close_run( # status: Optional[EvaluationStatus] = None, ) -> Optional[EvaluationRun]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE).filter( EvaluationRunDBE.project_id == project_id, EvaluationRunDBE.id == run_id, @@ -526,7 +531,7 @@ async def close_runs( # run_ids: List[UUID], ) -> List[EvaluationRun]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE).filter( EvaluationRunDBE.project_id == project_id, EvaluationRunDBE.id.in_(run_ids), @@ -576,7 +581,7 @@ async def open_run( # run_id: UUID, ) -> Optional[EvaluationRun]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE).filter( EvaluationRunDBE.project_id == project_id, EvaluationRunDBE.id == run_id, @@ -620,7 +625,7 @@ async def open_runs( # run_ids: List[UUID], ) -> List[EvaluationRun]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE).filter( EvaluationRunDBE.project_id == project_id, EvaluationRunDBE.id.in_(run_ids), @@ -671,7 +676,7 @@ async def query_runs( # windowing: Optional[Windowing] = None, ) -> List[EvaluationRun]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE).filter( EvaluationRunDBE.project_id == project_id, ) @@ -759,7 +764,7 @@ async def fetch_live_runs( *, windowing: Optional[Windowing] = None, ) -> List[Tuple[UUID, EvaluationRun]]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationRunDBE) stmt = stmt.filter( @@ -841,7 +846,7 @@ async def create_scenario( ) try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add(scenario_dbe) await session.commit() @@ -900,7 +905,7 @@ async def create_scenarios( ] try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add_all(scenario_dbes) await session.commit() @@ -928,7 +933,7 @@ async def fetch_scenario( # scenario_id: UUID, ) -> Optional[EvaluationScenario]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationScenarioDBE).filter( EvaluationScenarioDBE.project_id == project_id, ) @@ -961,7 +966,7 @@ async def fetch_scenarios( # scenario_ids: List[UUID], ) -> List[EvaluationScenario]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationScenarioDBE).filter( EvaluationScenarioDBE.project_id == project_id, ) @@ -995,7 +1000,7 @@ async def edit_scenario( # scenario: EvaluationScenarioEdit, ) -> Optional[EvaluationScenario]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationScenarioDBE).filter( EvaluationScenarioDBE.project_id == project_id, ) @@ -1052,7 +1057,7 @@ async def edit_scenarios( ) -> List[EvaluationScenario]: scenario_ids = [scenario.id for scenario in scenarios] - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationScenarioDBE).filter( EvaluationScenarioDBE.project_id == project_id, ) @@ -1116,7 +1121,7 @@ async def delete_scenario( # scenario_id: UUID, ) -> Optional[UUID]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationScenarioDBE).filter( EvaluationScenarioDBE.project_id == project_id, ) @@ -1160,7 +1165,7 @@ async def delete_scenarios( # scenario_ids: List[UUID], ) -> List[UUID]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationScenarioDBE).filter( EvaluationScenarioDBE.project_id == project_id, ) @@ -1207,7 +1212,7 @@ async def query_scenario_ids( # scenario: Optional[EvaluationScenarioQuery] = None, ) -> List[UUID]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationScenarioDBE.id).filter( EvaluationScenarioDBE.project_id == project_id, ) @@ -1254,7 +1259,7 @@ async def query_scenarios( # windowing: Optional[Windowing] = None, ) -> List[EvaluationScenario]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationScenarioDBE).filter( EvaluationScenarioDBE.project_id == project_id, ) @@ -1381,7 +1386,7 @@ async def create_result( ) try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add(result_dbe) await session.commit() @@ -1418,7 +1423,7 @@ async def create_results( run_id=result.run_id, ) - async with engine.core_session() as session: + async with self.engine.session() as session: _results = [ EvaluationResult( **result.model_dump( @@ -1462,7 +1467,7 @@ async def fetch_result( # result_id: UUID, ) -> Optional[EvaluationResult]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationResultDBE).filter( EvaluationResultDBE.project_id == project_id, ) @@ -1495,7 +1500,7 @@ async def fetch_results( # result_ids: List[UUID], ) -> List[EvaluationResult]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationResultDBE).filter( EvaluationResultDBE.project_id == project_id, ) @@ -1529,7 +1534,7 @@ async def edit_result( # result: EvaluationResultEdit, ) -> Optional[EvaluationResult]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationResultDBE).filter( EvaluationResultDBE.project_id == project_id, ) @@ -1587,7 +1592,7 @@ async def edit_results( ) -> List[EvaluationResult]: result_ids = [result.id for result in results] - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationResultDBE).filter( EvaluationResultDBE.project_id == project_id, ) @@ -1652,7 +1657,7 @@ async def delete_result( # result_id: UUID, ) -> Optional[UUID]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationResultDBE).filter( EvaluationResultDBE.project_id == project_id, ) @@ -1697,7 +1702,7 @@ async def delete_results( # result_ids: List[UUID], ) -> List[UUID]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationResultDBE).filter( EvaluationResultDBE.project_id == project_id, ) @@ -1745,7 +1750,7 @@ async def query_results( # windowing: Optional[Windowing] = None, ) -> List[EvaluationResult]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationResultDBE).filter( EvaluationResultDBE.project_id == project_id, ) @@ -1926,7 +1931,7 @@ async def create_metrics( ] # Classify metrics into 3 groups based on NULL pattern, then batch upsert - async with engine.core_session() as session: + async with self.engine.session() as session: returned_metric_dbes = [] # Convert DBE instances to dicts using SQLAlchemy's inspection mapper = inspect(EvaluationMetricsDBE) @@ -2059,7 +2064,7 @@ async def fetch_metrics( # metrics_ids: List[UUID], ) -> List[EvaluationMetrics]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationMetricsDBE).filter( EvaluationMetricsDBE.project_id == project_id, ) @@ -2095,7 +2100,7 @@ async def edit_metrics( ) -> List[EvaluationMetrics]: metrics_ids = [metric.id for metric in metrics] - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationMetricsDBE).filter( EvaluationMetricsDBE.project_id == project_id, ) @@ -2160,7 +2165,7 @@ async def delete_metrics( # metrics_ids: Optional[List[UUID]] = None, ) -> List[UUID]: - async with engine.core_session() as session: + async with self.engine.session() as session: if metrics_ids is None: return [] @@ -2211,7 +2216,7 @@ async def query_metrics( # windowing: Optional[Windowing] = None, ) -> List[EvaluationMetrics]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationMetricsDBE).filter( EvaluationMetricsDBE.project_id == project_id, ) @@ -2366,7 +2371,7 @@ async def create_queue( queue_dbe.user_ids = _flatten_queue_user_ids(queue.data) try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add(queue_dbe) await session.commit() @@ -2422,7 +2427,7 @@ async def create_queues( queue_dbe.user_ids = _flatten_queue_user_ids(queue.data) try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add_all(queue_dbes) await session.commit() @@ -2450,7 +2455,7 @@ async def fetch_queue( # queue_id: UUID, ) -> Optional[EvaluationQueue]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationQueueDBE).filter( EvaluationQueueDBE.project_id == project_id, ) @@ -2483,7 +2488,7 @@ async def fetch_queues( # queue_ids: List[UUID], ) -> List[EvaluationQueue]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationQueueDBE).filter( EvaluationQueueDBE.project_id == project_id, ) @@ -2517,7 +2522,7 @@ async def edit_queue( # queue: EvaluationQueueEdit, ) -> Optional[EvaluationQueue]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationQueueDBE).filter( EvaluationQueueDBE.project_id == project_id, ) @@ -2579,7 +2584,7 @@ async def edit_queues( ) -> List[EvaluationQueue]: queue_ids = [queue.id for queue in queues] - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationQueueDBE).filter( EvaluationQueueDBE.project_id == project_id, ) @@ -2641,6 +2646,82 @@ async def edit_queues( return _queues + @suppress_exceptions() + async def archive_queue( + self, + *, + project_id: UUID, + user_id: UUID, + queue_id: UUID, + ) -> Optional[EvaluationQueue]: + async with self.engine.session() as session: + stmt = ( + select(EvaluationQueueDBE) + .filter( + EvaluationQueueDBE.project_id == project_id, + EvaluationQueueDBE.id == queue_id, + ) + .limit(1) + ) + queue_dbe = (await session.execute(stmt)).scalars().first() + if queue_dbe is None: + return None + + run_flags = await _get_run_flags( + session=session, + project_id=project_id, + run_id=queue_dbe.run_id, # type: ignore + ) + if run_flags.get("is_closed", False): + raise EvaluationClosedConflict( + run_id=queue_dbe.run_id, + queue_id=queue_dbe.id, # type: ignore + ) + + queue_dbe.deleted_at = datetime.now(timezone.utc) + queue_dbe.deleted_by_id = user_id + await session.commit() + return create_dto_from_dbe(DTO=EvaluationQueue, dbe=queue_dbe) + + @suppress_exceptions() + async def unarchive_queue( + self, + *, + project_id: UUID, + user_id: UUID, + queue_id: UUID, + ) -> Optional[EvaluationQueue]: + async with self.engine.session() as session: + stmt = ( + select(EvaluationQueueDBE) + .filter( + EvaluationQueueDBE.project_id == project_id, + EvaluationQueueDBE.id == queue_id, + ) + .limit(1) + ) + queue_dbe = (await session.execute(stmt)).scalars().first() + if queue_dbe is None: + return None + + run_flags = await _get_run_flags( + session=session, + project_id=project_id, + run_id=queue_dbe.run_id, # type: ignore + ) + if run_flags.get("is_closed", False): + raise EvaluationClosedConflict( + run_id=queue_dbe.run_id, + queue_id=queue_dbe.id, # type: ignore + ) + + queue_dbe.deleted_at = None + queue_dbe.deleted_by_id = None + queue_dbe.updated_at = datetime.now(timezone.utc) + queue_dbe.updated_by_id = user_id + await session.commit() + return create_dto_from_dbe(DTO=EvaluationQueue, dbe=queue_dbe) + @suppress_exceptions() async def delete_queue( self, @@ -2649,7 +2730,7 @@ async def delete_queue( # queue_id: UUID, ) -> Optional[UUID]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationQueueDBE).filter( EvaluationQueueDBE.project_id == project_id, ) @@ -2681,7 +2762,7 @@ async def delete_queues( # queue_ids: List[UUID], ) -> List[UUID]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationQueueDBE).filter( EvaluationQueueDBE.project_id == project_id, ) @@ -2716,11 +2797,14 @@ async def query_queues( # windowing: Optional[Windowing] = None, ) -> List[EvaluationQueue]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(EvaluationQueueDBE).filter( EvaluationQueueDBE.project_id == project_id, ) + if not queue or not queue.include_archived: + stmt = stmt.filter(EvaluationQueueDBE.deleted_at.is_(None)) + if queue is not None: if queue.ids is not None: stmt = stmt.filter( @@ -2812,7 +2896,8 @@ async def _get_run_flags( session: Optional[AsyncSession] = None, ) -> dict: if session is None: - async with engine.core_session() as session: + engine = get_transactions_engine() + async with engine.session() as session: return await _get_run_flags( project_id=project_id, run_id=run_id, diff --git a/api/oss/src/dbs/postgres/evaluations/dbes.py b/api/oss/src/dbs/postgres/evaluations/dbes.py index e3ad7e8490..abacd74240 100644 --- a/api/oss/src/dbs/postgres/evaluations/dbes.py +++ b/api/oss/src/dbs/postgres/evaluations/dbes.py @@ -3,6 +3,7 @@ ForeignKeyConstraint, UniqueConstraint, Index, + text, ) from oss.src.dbs.postgres.shared.base import Base @@ -286,6 +287,13 @@ class EvaluationQueueDBE( "ix_evaluation_queues_run_id", "run_id", ), # for filtering + Index( + "ux_evaluation_queues_default_per_run", + "project_id", + "run_id", + unique=True, + postgresql_where=text("(flags ->> 'is_default')::boolean = true"), + ), # one canonical default queue per run, including archived rows Index( "ix_evaluation_queues_user_ids", "user_ids", diff --git a/api/oss/src/dbs/postgres/evaluations/utils.py b/api/oss/src/dbs/postgres/evaluations/utils.py index f1d60fcfa7..686246c61b 100644 --- a/api/oss/src/dbs/postgres/evaluations/utils.py +++ b/api/oss/src/dbs/postgres/evaluations/utils.py @@ -93,6 +93,8 @@ def _make_run_flags( flags.has_queries = False flags.has_testsets = False + flags.has_traces = False + flags.has_testcases = False flags.has_evaluators = False # flags.has_custom = False @@ -103,13 +105,15 @@ def _make_run_flags( if _step.type == "input": _references = _step.references or dict() - if flags.is_queue and not _references: + if not _references: step_key = (_step.key or "").lower() - if "query" in step_key: - flags.has_queries = True - if "testset" in step_key: - flags.has_testsets = True + # Direct source inputs are explicit source families. Legacy + # direct keys remain recognized for old rows. + if step_key in {"traces", "query-direct"}: + flags.has_traces = True + if step_key in {"testcases", "testset-direct"}: + flags.has_testcases = True for _key in _references.keys(): step_key = str(_key).lower() diff --git a/api/oss/src/dbs/postgres/events/dao.py b/api/oss/src/dbs/postgres/events/dao.py index cb35cdc193..0b6f01c5c4 100644 --- a/api/oss/src/dbs/postgres/events/dao.py +++ b/api/oss/src/dbs/postgres/events/dao.py @@ -13,12 +13,14 @@ map_event_dto_to_dbe, map_event_dbe_to_dto, ) -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import AnalyticsEngine, get_analytics_engine class EventsDAO(EventsDAOInterface): - def __init__(self): - pass + def __init__(self, engine: AnalyticsEngine = None): + if engine is None: + engine = get_analytics_engine() + self.engine = engine ### EVENTS @@ -32,7 +34,7 @@ async def ingest( if not events: return 0 - async with engine.tracing_session() as session: + async with self.engine.session() as session: total_ingested = 0 for event in events: @@ -87,7 +89,7 @@ async def query( # windowing: Optional[Windowing] = None, ) -> List[Event]: - async with engine.tracing_session() as session: + async with self.engine.session() as session: # BASE stmt = select(EventDBE) diff --git a/api/oss/src/dbs/postgres/folders/dao.py b/api/oss/src/dbs/postgres/folders/dao.py index 2a03df3a93..c81b062f8d 100644 --- a/api/oss/src/dbs/postgres/folders/dao.py +++ b/api/oss/src/dbs/postgres/folders/dao.py @@ -20,7 +20,10 @@ FolderPathDepthExceeded, FolderPathLengthExceeded, ) -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) from oss.src.dbs.postgres.folders.dbes import FolderDBE from oss.src.dbs.postgres.folders.mappings import ( create_dbe_from_dto, @@ -111,8 +114,10 @@ async def _delete_folder_tree( class FoldersDAO(FoldersDAOInterface): - def __init__(self): - pass + def __init__(self, engine: TransactionsEngine = None): + if engine is None: + engine = get_transactions_engine() + self.engine = engine @suppress_exceptions( exclude=[ @@ -133,7 +138,7 @@ async def create( parent_path = None if folder_create.parent_id: - async with engine.core_session() as session: + async with self.engine.session() as session: parent = await _get_folder_row( session=session, folder_id=folder_create.parent_id, @@ -153,7 +158,7 @@ async def create( ) folder_dbe.created_by_id = user_id - async with engine.core_session() as session: + async with self.engine.session() as session: try: session.add(folder_dbe) await session.commit() @@ -175,7 +180,7 @@ async def fetch( project_id: UUID, folder_id: UUID, ) -> Optional[Folder]: - async with engine.core_session() as session: + async with self.engine.session() as session: folder = await _get_folder_row( session=session, folder_id=folder_id, @@ -207,7 +212,7 @@ async def edit( ) -> Optional[Folder]: kind = folder_edit.kind or FolderKind.APPLICATIONS - async with engine.core_session() as session: + async with self.engine.session() as session: folder = await _get_folder_row( session=session, folder_id=folder_edit.id, @@ -285,7 +290,7 @@ async def delete( user_id: UUID, folder_id: UUID, ) -> Optional[UUID]: - async with engine.core_session() as session: + async with self.engine.session() as session: folder = await _get_folder_row( session=session, folder_id=folder_id, @@ -311,7 +316,7 @@ async def query( project_id: UUID, folder_query: FolderQuery, ) -> List[Folder]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(FolderDBE).filter(FolderDBE.project_id == project_id) if folder_query.id is not None: diff --git a/api/oss/src/dbs/postgres/git/dao.py b/api/oss/src/dbs/postgres/git/dao.py index e7d81cd038..22654bac70 100644 --- a/api/oss/src/dbs/postgres/git/dao.py +++ b/api/oss/src/dbs/postgres/git/dao.py @@ -33,7 +33,10 @@ from oss.src.dbs.postgres.shared.utils import apply_windowing from oss.src.dbs.postgres.shared.exceptions import check_entity_creation_conflict from oss.src.utils.exceptions import suppress_exceptions -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) from oss.src.dbs.postgres.git.mappings import ( map_dbe_to_dto, map_dto_to_dbe, @@ -53,10 +56,14 @@ def __init__( ArtifactDBE: Type[T], VariantDBE: Type[T], RevisionDBE: Type[T], + engine: TransactionsEngine = None, ): self.ArtifactDBE = ArtifactDBE # pylint: disable=invalid-name self.VariantDBE = VariantDBE # pylint: disable=invalid-name self.RevisionDBE = RevisionDBE # pylint: disable=invalid-name + if engine is None: + engine = get_transactions_engine() + self.engine = engine # ─ artifacts ────────────────────────────────────────────────────────────── @@ -95,7 +102,7 @@ async def create_artifact( ) try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add(artifact_dbe) await session.commit() @@ -128,7 +135,7 @@ async def fetch_artifact( if not artifact_ref: return None - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.ArtifactDBE).filter( self.ArtifactDBE.project_id == project_id, # type: ignore ) @@ -166,7 +173,7 @@ async def edit_artifact( # artifact_edit: ArtifactEdit, ) -> Optional[Artifact]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.ArtifactDBE).filter( self.ArtifactDBE.project_id == project_id, # type: ignore ) @@ -223,7 +230,7 @@ async def archive_artifact( # artifact_id: UUID, ) -> Optional[Artifact]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.ArtifactDBE).filter( self.ArtifactDBE.project_id == project_id, # type: ignore ) @@ -265,7 +272,7 @@ async def unarchive_artifact( # artifact_id: UUID, ) -> Optional[Artifact]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.ArtifactDBE).filter( self.ArtifactDBE.project_id == project_id, # type: ignore ) @@ -312,7 +319,7 @@ async def query_artifacts( # windowing: Optional[Windowing] = None, ) -> List[Artifact]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.ArtifactDBE).filter( self.ArtifactDBE.project_id == project_id, # type: ignore ) @@ -451,7 +458,7 @@ async def create_variant( ) try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add(variant_dbe) await session.commit() @@ -485,7 +492,7 @@ async def fetch_variant( if not artifact_ref and not variant_ref: return None - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.VariantDBE).filter( self.VariantDBE.project_id == project_id, # type: ignore ) @@ -527,7 +534,7 @@ async def edit_variant( # variant_edit: VariantEdit, ) -> Optional[Variant]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.VariantDBE).filter( self.VariantDBE.project_id == project_id, # type: ignore ) @@ -574,7 +581,7 @@ async def archive_variant( # variant_id: UUID, ) -> Optional[Variant]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.VariantDBE).filter( self.VariantDBE.project_id == project_id, # type: ignore ) @@ -633,7 +640,7 @@ async def unarchive_variant( # variant_id: UUID, ) -> Optional[Variant]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.VariantDBE).filter( self.VariantDBE.project_id == project_id, # type: ignore ) @@ -697,7 +704,7 @@ async def query_variants( # windowing: Optional[Windowing] = None, ) -> List[Variant]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.VariantDBE).filter( self.VariantDBE.project_id == project_id, # type: ignore ) @@ -983,7 +990,7 @@ async def create_revision( ) try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add(revision_dbe) await session.commit() @@ -1029,7 +1036,7 @@ async def fetch_revision( if not variant_ref and not revision_ref: return None - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.RevisionDBE).filter( self.RevisionDBE.project_id == project_id, # type: ignore ) @@ -1084,7 +1091,7 @@ async def edit_revision( # revision_edit: RevisionEdit, ) -> Optional[Revision]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.RevisionDBE).filter( self.RevisionDBE.project_id == project_id, # type: ignore ) @@ -1131,7 +1138,7 @@ async def archive_revision( # revision_id: UUID, ) -> Optional[Revision]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.RevisionDBE).filter( self.RevisionDBE.project_id == project_id, # type: ignore ) @@ -1173,7 +1180,7 @@ async def unarchive_revision( # revision_id: UUID, ) -> Optional[Revision]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.RevisionDBE).filter( self.RevisionDBE.project_id == project_id, # type: ignore ) @@ -1224,7 +1231,7 @@ async def query_revisions( # windowing: Optional[Windowing] = None, ) -> List[Revision]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.RevisionDBE).filter( self.RevisionDBE.project_id == project_id, # type: ignore ) @@ -1497,7 +1504,7 @@ async def commit_revision( ) try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add(revision_dbe) await session.commit() @@ -1599,7 +1606,7 @@ async def log_revisions( # `include_archived`. ROW_NUMBER() over that set gives us each row's # 1-indexed position; we then keep rows up to the target's position # and limit to `depth` from the tail. - async with engine.core_session() as session: + async with self.engine.session() as session: visibility_filter = ( (self.RevisionDBE.deleted_at.is_(None),) # type: ignore if not include_archived @@ -1665,7 +1672,7 @@ async def _get_version( variant_id: UUID, revision_id: UUID, ) -> str: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = ( select(func.count()) # pylint: disable=not-callable .select_from(self.RevisionDBE) # type: ignore @@ -1689,7 +1696,7 @@ async def _set_version( revision_id: UUID, version: str, ) -> None: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = update(self.RevisionDBE).filter( self.RevisionDBE.project_id == project_id, # type: ignore ) @@ -1708,7 +1715,7 @@ async def _null_revision_fields( project_id: UUID, revision_id: UUID, ) -> None: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = ( update(self.RevisionDBE) .filter( diff --git a/api/oss/src/dbs/postgres/secrets/dao.py b/api/oss/src/dbs/postgres/secrets/dao.py index e356e84059..edcaa2a35d 100644 --- a/api/oss/src/dbs/postgres/secrets/dao.py +++ b/api/oss/src/dbs/postgres/secrets/dao.py @@ -3,8 +3,10 @@ from oss.src.dbs.postgres.secrets.dbes import SecretsDBE from oss.src.core.secrets.interfaces import SecretsDAOInterface - -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) from oss.src.core.secrets.dtos import CreateSecretDTO, UpdateSecretDTO from oss.src.dbs.postgres.secrets.mappings import ( @@ -17,8 +19,10 @@ class SecretsDAO(SecretsDAOInterface): - def __init__(self): - pass + def __init__(self, engine: TransactionsEngine = None): + if engine is None: + engine = get_transactions_engine() + self.engine = engine @staticmethod def _validate_scope(project_id: UUID | None, organization_id: UUID | None) -> None: @@ -48,7 +52,7 @@ async def create( organization_id=organization_id, secret_dto=create_secret_dto, ) - async with engine.core_session() as session: + async with self.engine.session() as session: session.add(secrets_dbe) await session.commit() @@ -61,7 +65,7 @@ async def get( project_id: UUID | None, organization_id: UUID | None, ): - async with engine.core_session() as session: + async with self.engine.session() as session: scope_filter = self._scope_filter(project_id, organization_id) stmt = select(SecretsDBE).filter_by( id=secret_id, @@ -77,7 +81,7 @@ async def get( return secrets_dto async def list(self, project_id: UUID | None, organization_id: UUID | None): - async with engine.core_session() as session: + async with self.engine.session() as session: scope_filter = self._scope_filter(project_id, organization_id) stmt = select(SecretsDBE).filter_by(**scope_filter) @@ -96,7 +100,7 @@ async def update( project_id: UUID | None, organization_id: UUID | None, ): - async with engine.core_session() as session: + async with self.engine.session() as session: scope_filter = self._scope_filter(project_id, organization_id) stmt = select(SecretsDBE).filter_by( id=secret_id, @@ -124,7 +128,7 @@ async def delete( project_id: UUID | None, organization_id: UUID | None, ): - async with engine.core_session() as session: + async with self.engine.session() as session: scope_filter = self._scope_filter(project_id, organization_id) stmt = select(SecretsDBE).filter_by( id=secret_id, diff --git a/api/oss/src/dbs/postgres/shared/engine.py b/api/oss/src/dbs/postgres/shared/engine.py index 67253fbed1..b37985ccb9 100644 --- a/api/oss/src/dbs/postgres/shared/engine.py +++ b/api/oss/src/dbs/postgres/shared/engine.py @@ -1,5 +1,5 @@ from asyncio import current_task -from typing import AsyncGenerator +from typing import AsyncGenerator, Optional from contextlib import asynccontextmanager from math import floor @@ -11,17 +11,14 @@ async_scoped_session, ) -from oss.src.dbs.postgres.shared.config import ( - POSTGRES_URI_CORE, - POSTGRES_URI_TRACING, -) +from oss.src.utils.env import env DATABASE_MEMORY = 32 * 1024 * 1024 * 1024 # 32 GB DATABASE_FACTOR = 8 * 1024 * 1024 * 1.15 # 8 MB + 15% overhead -DATABASE_MAX_CONNECTIONS = 5000 # 5000 connections +DATABASE_MAX_CONNECTIONS = 5000 MAX_CONNECTIONS = min(DATABASE_MEMORY / DATABASE_FACTOR, DATABASE_MAX_CONNECTIONS) -APP_CONNECTIONS = MAX_CONNECTIONS * 0.9 # reserve 10% for non-app connections +APP_CONNECTIONS = MAX_CONNECTIONS * 0.9 NOF_CONSUMERS = 2 * 4 # 2 engines x 4 containers NOF_CONNECTIONS = floor(APP_CONNECTIONS / NOF_CONSUMERS) POOL_SIZE = floor(NOF_CONNECTIONS * 0.25) @@ -29,98 +26,99 @@ POOL_RECYCLE = 30 * 60 # 30 minutes -class Engine: - def __init__(self) -> None: - self.postgres_uri_core = POSTGRES_URI_CORE +class TransactionsEngine: + """Postgres core DB — application data.""" - self.async_core_engine: AsyncEngine = create_async_engine( - url=self.postgres_uri_core, + def __init__(self) -> None: + self._engine: AsyncEngine = create_async_engine( + url=env.postgres.uri_core, pool_pre_ping=True, pool_recycle=POOL_RECYCLE, pool_size=POOL_SIZE, max_overflow=MAX_OVERFLOW, ) - self.async_core_session_maker = async_sessionmaker( + _session_maker = async_sessionmaker( autocommit=False, autoflush=False, class_=AsyncSession, expire_on_commit=False, - bind=self.async_core_engine, + bind=self._engine, ) - self.async_core_session = async_scoped_session( - session_factory=self.async_core_session_maker, + self._session = async_scoped_session( + session_factory=_session_maker, scopefunc=current_task, ) - self.postgres_uri_tracing = POSTGRES_URI_TRACING + async def close(self) -> None: + if self._engine is not None: + await self._engine.dispose() - self.async_tracing_engine: AsyncEngine = create_async_engine( - url=self.postgres_uri_tracing, + @asynccontextmanager + async def session(self) -> AsyncGenerator[AsyncSession, None]: + session: AsyncSession = self._session() + try: + yield session + await session.commit() + except Exception as e: + await session.rollback() + raise e + finally: + await session.close() + + +class AnalyticsEngine: + """Postgres tracing DB — observability data.""" + + def __init__(self) -> None: + self._engine: AsyncEngine = create_async_engine( + url=env.postgres.uri_tracing, pool_pre_ping=True, pool_recycle=POOL_RECYCLE, pool_size=POOL_SIZE, max_overflow=MAX_OVERFLOW, ) - self.async_tracing_session_maker = async_sessionmaker( + _session_maker = async_sessionmaker( autocommit=False, autoflush=False, class_=AsyncSession, expire_on_commit=False, - bind=self.async_tracing_engine, + bind=self._engine, ) - - self.async_tracing_session = async_scoped_session( - session_factory=self.async_tracing_session_maker, + self._session = async_scoped_session( + session_factory=_session_maker, scopefunc=current_task, ) - async def open(self): - raise NotImplementedError() - - async def close(self): - if self.async_core_engine is not None: - await self.async_core_engine.dispose() - - if self.async_tracing_engine is not None: - await self.async_tracing_engine.dispose() + async def close(self) -> None: + if self._engine is not None: + await self._engine.dispose() @asynccontextmanager - async def core_session(self) -> AsyncGenerator[AsyncSession, None]: - session: AsyncSession = self.async_core_session() - + async def session(self) -> AsyncGenerator[AsyncSession, None]: + session: AsyncSession = self._session() try: yield session await session.commit() - except Exception as e: await session.rollback() raise e - finally: await session.close() - @asynccontextmanager - async def tracing_session(self) -> AsyncGenerator[AsyncSession, None]: - session: AsyncSession = self.async_tracing_session() - - try: - yield session - await session.commit() - - except Exception as e: - await session.rollback() - raise e - - finally: - await session.close() - ### LEGACY CODE ### +_transactions_engine: Optional[TransactionsEngine] = None +_analytics_engine: Optional[AnalyticsEngine] = None - async def init_db(self): - self.open() - async def close_db(self): - self.close() +def get_transactions_engine() -> TransactionsEngine: + global _transactions_engine + if _transactions_engine is None: + _transactions_engine = TransactionsEngine() + return _transactions_engine -engine = Engine() +def get_analytics_engine() -> AnalyticsEngine: + global _analytics_engine + if _analytics_engine is None: + _analytics_engine = AnalyticsEngine() + return _analytics_engine diff --git a/api/oss/src/dbs/postgres/tools/dao.py b/api/oss/src/dbs/postgres/tools/dao.py index f94e87f273..c3cefe279c 100644 --- a/api/oss/src/dbs/postgres/tools/dao.py +++ b/api/oss/src/dbs/postgres/tools/dao.py @@ -16,7 +16,10 @@ ToolConnectionCreate, ) -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) from oss.src.dbs.postgres.tools.dbes import ToolConnectionDBE from oss.src.dbs.postgres.tools.mappings import ( map_connection_create_to_dbe, @@ -28,8 +31,16 @@ class ToolsDAO(ToolsDAOInterface): - def __init__(self, *, ToolConnectionDBE: type = ToolConnectionDBE): + def __init__( + self, + *, + ToolConnectionDBE: type = ToolConnectionDBE, + engine: TransactionsEngine = None, + ): self.ToolConnectionDBE = ToolConnectionDBE + if engine is None: + engine = get_transactions_engine() + self.engine = engine @suppress_exceptions(exclude=[EntityCreationConflict]) async def create_connection( @@ -49,7 +60,7 @@ async def create_connection( ) try: - async with engine.core_session() as session: + async with self.engine.session() as session: session.add(dbe) await session.commit() await session.refresh(dbe) @@ -80,7 +91,7 @@ async def get_connection( connection_id: UUID, ) -> Optional[ToolConnection]: """Fetch a connection by ID scoped to project_id. Returns None if not found.""" - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = ( select(self.ToolConnectionDBE) .filter(self.ToolConnectionDBE.project_id == project_id) @@ -109,7 +120,7 @@ async def update_connection( data_update: Optional[dict] = None, ) -> Optional[ToolConnection]: """Partially update flags and/or data for a connection. Returns updated DTO or None.""" - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = ( select(self.ToolConnectionDBE) .filter(self.ToolConnectionDBE.project_id == project_id) @@ -158,7 +169,7 @@ async def delete_connection( connection_id: UUID, ) -> bool: """Hard-delete a connection row. Returns True if a row was deleted.""" - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = ( delete(self.ToolConnectionDBE) .where(self.ToolConnectionDBE.project_id == project_id) @@ -181,7 +192,7 @@ async def query_connections( is_active: Optional[bool] = True, ) -> List[ToolConnection]: """List connections with optional filters. Defaults to active-only (is_active=True).""" - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.ToolConnectionDBE).filter( self.ToolConnectionDBE.project_id == project_id, ) @@ -215,7 +226,7 @@ async def activate_connection_by_provider_id( project_id: Optional[UUID] = None, ) -> Optional[ToolConnection]: """Set is_valid=True and is_active=True for the connection matching the provider ID.""" - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(self.ToolConnectionDBE).filter( self.ToolConnectionDBE.data["connected_account_id"].astext == provider_connection_id @@ -252,7 +263,7 @@ async def find_connection_by_provider_id( provider_connection_id: str, ) -> Optional[ToolConnection]: """Lookup any connection by provider-side connected_account_id (no project scope).""" - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = ( select(self.ToolConnectionDBE) .filter( diff --git a/api/oss/src/dbs/postgres/tracing/dao.py b/api/oss/src/dbs/postgres/tracing/dao.py index e2295cac8a..e10ce0f31d 100644 --- a/api/oss/src/dbs/postgres/tracing/dao.py +++ b/api/oss/src/dbs/postgres/tracing/dao.py @@ -30,7 +30,7 @@ ) from oss.src.dbs.postgres.shared.utils import apply_windowing -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import AnalyticsEngine, get_analytics_engine from oss.src.dbs.postgres.tracing.dbes import SpanDBE from oss.src.dbs.postgres.tracing.mappings import ( map_span_dbe_to_link_dto, @@ -67,8 +67,10 @@ class TracingDAO(TracingDAOInterface): - def __init__(self): - pass + def __init__(self, engine: AnalyticsEngine = None): + if engine is None: + engine = get_analytics_engine() + self.engine = engine ### SPANS @@ -85,7 +87,7 @@ async def ingest( if not span_dtos: return [] - async with engine.tracing_session() as session: + async with self.engine.session() as session: link_dtos: List[OTelLink] = [] for span_dto in span_dtos: @@ -158,7 +160,7 @@ async def query( # --------- try: - async with engine.tracing_session() as session: + async with self.engine.session() as session: # TIMEOUT await session.execute(TIMEOUT_STMT) # ------- @@ -413,7 +415,7 @@ async def analytics( # log.trace(str(statistics_stmt.compile(**DEBUG_ARGS)).replace("\n", " ")) # --------- - async with engine.tracing_session() as session: + async with self.engine.session() as session: await session.execute(TIMEOUT_STMT) rows = (await session.execute(select(statistics_stmt))).mappings().all() @@ -544,7 +546,7 @@ async def legacy_analytics( ) = parse_windowing(query.windowing) try: - async with engine.tracing_session() as session: + async with self.engine.session() as session: await session.execute(TIMEOUT_STMT) # BASE QUERY HELPERS @@ -747,7 +749,7 @@ async def fetch( if not trace_ids and not span_ids: return [] - async with engine.tracing_session() as session: + async with self.engine.session() as session: stmt = select(SpanDBE).filter(SpanDBE.project_id == project_id) if trace_ids: @@ -785,7 +787,7 @@ async def delete( if not trace_ids: return [] - async with engine.tracing_session() as session: + async with self.engine.session() as session: stmt = select(SpanDBE).filter( SpanDBE.project_id == project_id, SpanDBE.trace_id.in_(trace_ids), @@ -874,7 +876,7 @@ async def _query_by_group( realtime: If True, use last_active (mutable, shows recent activity but unstable cursors). If False/None, use first_active (immutable, stable cursors but doesn't reflect new activity). """ - async with engine.tracing_session() as session: + async with self.engine.session() as session: # TIMEOUT await session.execute(TIMEOUT_STMT) diff --git a/api/oss/src/dbs/postgres/users/dao.py b/api/oss/src/dbs/postgres/users/dao.py index 2fbe1c71b9..5da796a824 100644 --- a/api/oss/src/dbs/postgres/users/dao.py +++ b/api/oss/src/dbs/postgres/users/dao.py @@ -3,7 +3,10 @@ from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) from oss.src.dbs.postgres.users.dbes import UserIdentityDBE from oss.src.dbs.postgres.users.mappings import ( map_identity_dbe_to_dto, @@ -13,10 +16,15 @@ class IdentitiesDAO: + def __init__(self, engine: Optional[TransactionsEngine] = None): + if engine is None: + engine = get_transactions_engine() + self.engine = engine + async def create(self, dto: UserIdentityCreate) -> UserIdentity: identity_dbe = map_create_dto_to_dbe(dto) - async with engine.core_session() as session: + async with self.engine.session() as session: try: session.add(identity_dbe) await session.commit() @@ -37,7 +45,7 @@ async def create(self, dto: UserIdentityCreate) -> UserIdentity: async def get_by_method_subject( self, method: str, subject: str ) -> Optional[UserIdentity]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(UserIdentityDBE).filter_by( method=method, subject=subject, @@ -51,7 +59,7 @@ async def get_by_method_subject( return map_identity_dbe_to_dto(identity_dbe) async def list_by_user(self, user_id: UUID) -> List[UserIdentity]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(UserIdentityDBE).filter_by(user_id=user_id) result = await session.execute(stmt) identity_dbes = result.scalars().all() @@ -59,7 +67,7 @@ async def list_by_user(self, user_id: UUID) -> List[UserIdentity]: return [map_identity_dbe_to_dto(dbe) for dbe in identity_dbes] async def list_by_domain(self, domain: str) -> List[UserIdentity]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(UserIdentityDBE).filter_by(domain=domain) result = await session.execute(stmt) identity_dbes = result.scalars().all() diff --git a/api/oss/src/dbs/postgres/webhooks/dao.py b/api/oss/src/dbs/postgres/webhooks/dao.py index 47a31154af..6df48b8e4e 100644 --- a/api/oss/src/dbs/postgres/webhooks/dao.py +++ b/api/oss/src/dbs/postgres/webhooks/dao.py @@ -17,7 +17,10 @@ ) from oss.src.core.webhooks.interfaces import WebhooksDAOInterface -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) from oss.src.dbs.postgres.shared.utils import apply_windowing from oss.src.dbs.postgres.webhooks.dbes import ( WebhookSubscriptionDBE, @@ -33,8 +36,10 @@ class WebhooksDAO(WebhooksDAOInterface): - def __init__(self): - pass + def __init__(self, engine: TransactionsEngine = None): + if engine is None: + engine = get_transactions_engine() + self.engine = engine # --- SUBSCRIPTIONS ------------------------------------------------------ # @@ -57,7 +62,7 @@ async def create_subscription( secret_id=secret_id, ) - async with engine.core_session() as session: + async with self.engine.session() as session: session.add(subscription_dbe) await session.commit() @@ -75,7 +80,7 @@ async def fetch_subscription( # subscription_id: UUID, ) -> Optional[WebhookSubscription]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(WebhookSubscriptionDBE).where( WebhookSubscriptionDBE.project_id == project_id, WebhookSubscriptionDBE.id == subscription_id, @@ -102,7 +107,7 @@ async def edit_subscription( # secret_id: UUID | None = None, ) -> Optional[WebhookSubscription]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(WebhookSubscriptionDBE).where( WebhookSubscriptionDBE.id == subscription.id, WebhookSubscriptionDBE.project_id == project_id, @@ -140,7 +145,7 @@ async def delete_subscription( # subscription_id: UUID, ) -> bool: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(WebhookSubscriptionDBE).where( WebhookSubscriptionDBE.project_id == project_id, WebhookSubscriptionDBE.id == subscription_id, @@ -168,7 +173,7 @@ async def query_subscriptions( # windowing: Optional[Windowing] = None, ) -> List[WebhookSubscription]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(WebhookSubscriptionDBE).filter( WebhookSubscriptionDBE.project_id == project_id, ) @@ -234,7 +239,7 @@ async def create_delivery( delivery=delivery, ) - async with engine.core_session() as session: + async with self.engine.session() as session: values = { c.name: getattr(delivery_dbe, c.name) for c in WebhookDeliveryDBE.__table__.columns @@ -275,7 +280,7 @@ async def fetch_delivery( # delivery_id: UUID, ) -> Optional[WebhookDelivery]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(WebhookDeliveryDBE).where( WebhookDeliveryDBE.project_id == project_id, WebhookDeliveryDBE.id == delivery_id, @@ -301,7 +306,7 @@ async def query_deliveries( # windowing: Optional[Windowing] = None, ) -> List[WebhookDelivery]: - async with engine.core_session() as session: + async with self.engine.session() as session: stmt = select(WebhookDeliveryDBE).filter( WebhookDeliveryDBE.project_id == project_id, ) diff --git a/api/oss/src/dbs/redis/shared/engine.py b/api/oss/src/dbs/redis/shared/engine.py new file mode 100644 index 0000000000..02153a9772 --- /dev/null +++ b/api/oss/src/dbs/redis/shared/engine.py @@ -0,0 +1,89 @@ +from typing import TYPE_CHECKING, Optional + +from oss.src.utils.env import env + +if TYPE_CHECKING: + from redis.asyncio import Redis + + +class CacheEngine: + """Redis volatile — caching and distributed locks.""" + + def __init__(self) -> None: + from redis.asyncio import Redis + + self._r: Optional[Redis] = None + self._r_lock: Optional[Redis] = None + + def get_r(self) -> "Redis": + if self._r is None: + from redis.asyncio import Redis + + self._r = Redis.from_url( + url=env.redis.uri_volatile, + decode_responses=False, + socket_timeout=0.5, + ) + return self._r + + def get_r_lock(self) -> "Redis": + if self._r_lock is None: + from redis.asyncio import Redis + + AGENTA_LOCK_SOCKET_TIMEOUT = 30 + + self._r_lock = Redis.from_url( + url=env.redis.uri_volatile, + decode_responses=False, + socket_timeout=AGENTA_LOCK_SOCKET_TIMEOUT, + ) + return self._r_lock + + async def close(self) -> None: + if self._r is not None: + await self._r.close() + self._r = None + if self._r_lock is not None: + await self._r_lock.close() + self._r_lock = None + + +class StreamsEngine: + """Redis durable — persistent streams for tracing/events.""" + + def __init__(self) -> None: + from redis.asyncio import Redis + + self._redis: Optional[Redis] = None + + def get_redis(self) -> "Redis": + if self._redis is None: + from redis.asyncio import Redis + + if not env.redis.uri_durable: + raise RuntimeError("REDIS_URI_DURABLE is required for streams.") + self._redis = Redis.from_url(env.redis.uri_durable, decode_responses=False) + return self._redis + + async def close(self) -> None: + if self._redis is not None: + await self._redis.close() + self._redis = None + + +_cache_engine: Optional[CacheEngine] = None +_streams_engine: Optional[StreamsEngine] = None + + +def get_cache_engine() -> CacheEngine: + global _cache_engine + if _cache_engine is None: + _cache_engine = CacheEngine() + return _cache_engine + + +def get_streams_engine() -> StreamsEngine: + global _streams_engine + if _streams_engine is None: + _streams_engine = StreamsEngine() + return _streams_engine diff --git a/api/oss/src/services/admin_manager.py b/api/oss/src/services/admin_manager.py index 222bd063df..747349936b 100644 --- a/api/oss/src/services/admin_manager.py +++ b/api/oss/src/services/admin_manager.py @@ -10,7 +10,7 @@ from oss.src.utils.logging import get_module_logger -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import get_transactions_engine from oss.src.services import db_manager from oss.src.models.db_models import UserDB @@ -159,7 +159,9 @@ async def legacy_create_organization( return_org_wrk: bool = False, return_org_wrk_prj: bool = False, ) -> Union[OrganizationDB, WorkspaceDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: create_org_data = payload.model_dump(exclude_unset=True) create_org_data.pop("is_demo", None) @@ -262,7 +264,9 @@ async def user_exists(user_email: str) -> bool: async def check_user( request: UserRequest, ) -> Optional[UserRequest]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(UserDB).filter_by( email=request.email, @@ -279,7 +283,9 @@ async def check_user( async def create_user( request: UserRequest, ) -> Reference: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: user_db = UserDB( # id=uuid7() # use default # @@ -306,7 +312,9 @@ async def create_organization( request: OrganizationRequest, created_by_id: uuid.UUID, ) -> Reference: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: organization_db = OrganizationDB( name=request.name, description=request.description, @@ -332,7 +340,9 @@ async def create_organization( async def create_workspace( request: WorkspaceRequest, ) -> Reference: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: workspace_db = WorkspaceDB( # id=uuid7() # use default # @@ -361,7 +371,9 @@ async def create_workspace( async def create_project( request: ProjectRequest, ) -> Reference: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: project_db = ProjectDB( # id=uuid7() # use default # diff --git a/api/oss/src/services/analytics_service.py b/api/oss/src/services/analytics_service.py index 77b08fe8c7..a462ffdb84 100644 --- a/api/oss/src/services/analytics_service.py +++ b/api/oss/src/services/analytics_service.py @@ -3,12 +3,12 @@ from datetime import datetime from typing import Callable, Optional -import posthog from fastapi import Request from oss.src.utils.caching import get_cache, set_cache from oss.src.utils.common import is_oss from oss.src.utils.env import env from oss.src.utils.logging import get_module_logger +from oss.src.utils.lazy import _load_posthog log = get_module_logger(__name__) @@ -44,15 +44,6 @@ } -# Initialize PostHog only if enabled -if env.posthog.enabled: - posthog.api_key = env.posthog.api_key - posthog.host = env.posthog.api_url - log.info("✓ PostHog enabled") -else: - log.warn("✗ PostHog disabled") - - async def _set_activation_property( distinct_id: str, property_name: str, @@ -66,6 +57,10 @@ async def _set_activation_property( if not distinct_id or not env.posthog.enabled: return + posthog = _load_posthog() + if posthog is None: + return + project_id = getattr(request.state, "project_id", None) user_id = getattr(request.state, "user_id", None) @@ -121,6 +116,10 @@ def capture_oss_deployment_created(user_email: str, organization_id: str): """ if is_oss() and env.posthog.enabled: + posthog = _load_posthog() + if posthog is None: + return + try: posthog.capture( distinct_id=user_email, @@ -247,28 +246,32 @@ async def analytics_middleware(request: Request, call_next: Callable): pass if distinct_id and env.posthog.api_key: - properties["$set"] = {"email": distinct_id} - - posthog.capture( - distinct_id=distinct_id, - event=event_name, - properties=properties or {}, - ) - - # Check if this is an activation event - if event_name in ACTIVATION_EVENTS: - property_name, allowed_auth_methods = ACTIVATION_EVENTS[event_name] - - # Check if auth method is allowed for this activation - if ( - allowed_auth_methods is None - or auth_method in allowed_auth_methods - ): - await _set_activation_property( - distinct_id=distinct_id, - property_name=property_name, - request=request, - ) + posthog = _load_posthog() + if posthog is not None: + properties["$set"] = {"email": distinct_id} + + posthog.capture( + distinct_id=distinct_id, + event=event_name, + properties=properties or {}, + ) + + # Check if this is an activation event + if event_name in ACTIVATION_EVENTS: + property_name, allowed_auth_methods = ACTIVATION_EVENTS[ + event_name + ] + + # Check if auth method is allowed for this activation + if ( + allowed_auth_methods is None + or auth_method in allowed_auth_methods + ): + await _set_activation_property( + distinct_id=distinct_id, + property_name=property_name, + request=request, + ) except Exception as e: log.error(f"❌ Error capturing event in PostHog: {e}") diff --git a/api/oss/src/services/api_key_service.py b/api/oss/src/services/api_key_service.py index f8d78af6c7..46be3b8904 100644 --- a/api/oss/src/services/api_key_service.py +++ b/api/oss/src/services/api_key_service.py @@ -13,7 +13,7 @@ from oss.src.utils.logging import get_module_logger from oss.src.models.db_models import APIKeyDB, UserDB -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import get_transactions_engine # from oss.src.utils.redis_utils import redis_connection @@ -38,7 +38,9 @@ async def _generate_unique_prefix(): # Define the characters to use for the prefix alphabet = string.ascii_letters + string.digits - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: while True: # Generate a random 8-character prefix prefix = "".join(secrets.choice(alphabet) for _ in range(8)) @@ -82,7 +84,9 @@ async def create_api_key( # get rate limit from env rate_limit = 0 - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: # Create an APIKeyDB instance with the prefix, hashed API key, and user_id api_key = APIKeyDB( prefix=prefix, @@ -115,7 +119,9 @@ async def is_valid_api_key(key: str): - The API Key object if the API key is valid, False otherwise. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: # Check if the API key is valid (not blacklisted and not expired) result = await session.execute( select(APIKeyDB) @@ -234,7 +240,9 @@ async def list_api_keys(user_id: str, project_id: str) -> List[APIKeyDB]: List[APIKeyDB]: A list of APIKeyDB objects associated with the user, sorted by most recently created first. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(APIKeyDB) .filter_by( @@ -260,7 +268,9 @@ async def delete_api_key(user_id: str, key_prefix: str): KeyError: If the API key does not exist or does not belong to the user. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(APIKeyDB).filter_by( created_by_id=uuid.UUID(user_id), prefix=key_prefix diff --git a/api/oss/src/services/db_manager.py b/api/oss/src/services/db_manager.py index 40c1e57bd9..54a91f6b12 100644 --- a/api/oss/src/services/db_manager.py +++ b/api/oss/src/services/db_manager.py @@ -20,7 +20,9 @@ from oss.src.services import user_service, analytics_service from oss.src.utils.common import is_ee from oss.src.utils.env import env -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import ( + get_transactions_engine, +) from oss.src.utils.helpers import get_slug_from_name_and_id from oss.src.dbs.postgres.blobs.dao import BlobsDAO @@ -64,7 +66,8 @@ async def fetch_project_by_id( project_id: str, ) -> Optional[ProjectDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + async with engine.session() as session: project = ( ( await session.execute( @@ -91,7 +94,8 @@ async def fetch_projects_by_workspace( List[ProjectDB]: Projects scoped to the workspace. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + async with engine.session() as session: result = await session.execute( select(ProjectDB) .filter(ProjectDB.workspace_id == uuid.UUID(workspace_id)) @@ -103,7 +107,8 @@ async def fetch_projects_by_workspace( async def fetch_workspace_by_id( workspace_id: str, ) -> Optional[WorkspaceDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + async with engine.session() as session: workspace = ( ( await session.execute( @@ -122,7 +127,9 @@ async def fetch_workspace_by_id( async def fetch_organization_by_id( organization_id: str, ) -> Optional[OrganizationDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: organization = ( ( await session.execute( @@ -312,7 +319,9 @@ async def get_user(user_uid: str) -> UserDB: UserDB: instance of user """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: # NOTE: Backward Compatibility # --------------------------- # Previously, the user_id field in the api_keys collection in MongoDB used the @@ -332,7 +341,9 @@ async def get_user(user_uid: str) -> UserDB: async def is_first_user_signup() -> bool: """Check if this is the first user signing up (no users exist yet).""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: total_users = ( await session.scalar(select(func.count()).select_from(UserDB)) or 0 ) @@ -402,7 +413,9 @@ async def setup_oss_organization_for_first_user( # org with SELECT ... FOR UPDATE inside a single transaction — the # second caller blocks until the first commits, then sees the # workspace and skips the insert. No schema change required. - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: await session.execute( select(OrganizationDB.id).filter_by(id=organization_db.id).with_for_update() ) @@ -471,7 +484,9 @@ async def check_if_user_invitation_exists(email: str, organization_id: str): "Default project not found for user invitation in organization." ) - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(InvitationDB).filter_by( email=email, @@ -598,7 +613,9 @@ async def get_default_workspace_id_oss() -> str: orgs) cannot shadow the real singleton workspace and steer auth scope resolution to the wrong tenant. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(WorkspaceDB) .join(OrganizationDB, WorkspaceDB.organization_id == OrganizationDB.id) @@ -640,7 +657,9 @@ async def create_organization( EE keeps the previous behavior (one org per signup, slug left NULL). """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: # For bootstrap scenario, use a placeholder UUID if not provided _owner_id = owner_id or uuid.uuid4() _created_by_id = created_by_id or _owner_id @@ -710,7 +729,9 @@ async def create_workspace(name: str, organization_id: str): WorkspaceDB: instance of workspace """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: workspace_db = WorkspaceDB( name=name, organization_id=uuid.UUID(organization_id), @@ -740,7 +761,9 @@ async def update_organization(organization_id: str, values_to_update: Dict[str, values_to_update (Dict[str, Any]): The values to update in the organization """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(OrganizationDB).filter_by(id=uuid.UUID(organization_id)) ) @@ -778,7 +801,9 @@ async def create_or_update_default_project(values_to_update: Dict[str, Any]): "create_or_update_default_project requires 'organization_id' in values_to_update" ) - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(ProjectDB).filter_by( organization_id=organization_id, @@ -808,7 +833,9 @@ async def get_organizations() -> List[OrganizationDB]: List: A list of organizations. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(OrganizationDB)) organizations = result.scalars().all() return organizations @@ -825,7 +852,9 @@ async def get_organization_by_id(organization_id: str) -> OrganizationDB: OrganizationDB: The organization object if found, None otherwise. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(OrganizationDB).filter_by(id=uuid.UUID(organization_id)) ) @@ -844,7 +873,9 @@ async def get_organization_by_slug(organization_slug: str) -> OrganizationDB: OrganizationDB: The organization object if found, None otherwise. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(OrganizationDB).filter_by(slug=organization_slug) ) @@ -863,7 +894,9 @@ async def get_organization_owner(organization_id: str): UserDB: The owner of the organization if found, None otherwise. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(OrganizationDB).filter_by(id=uuid.UUID(organization_id)) ) @@ -888,7 +921,9 @@ async def get_user_organizations(user_id: str) -> List[OrganizationDB]: if is_ee(): from ee.src.models.db_models import OrganizationMemberDB - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: # Query organizations through organization_members table result = await session.execute( select(OrganizationDB) @@ -917,7 +952,9 @@ async def get_workspace(workspace_id: str) -> WorkspaceDB: Workspace: The retrieved workspace. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: query = select(WorkspaceDB).filter_by(id=uuid.UUID(workspace_id)) result = await session.execute(query) @@ -933,7 +970,9 @@ async def get_workspaces() -> List[WorkspaceDB]: List: A list of workspaces. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(WorkspaceDB)) workspaces = result.scalars().all() return workspaces @@ -959,7 +998,9 @@ async def remove_user_from_workspace(project_id: str, email: str): if not project: raise NoResultFound(f"Project with ID {project_id} not found") - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: if user: await session.delete(user) @@ -1005,7 +1046,9 @@ async def get_user_with_id(user_id: str) -> UserDB: Exception: If an error occurs while getting the user from the database. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(UserDB).filter_by(id=uuid.UUID(user_id))) user = result.scalars().first() if user is None: @@ -1017,7 +1060,9 @@ async def get_user_with_id(user_id: str) -> UserDB: async def update_user_username(user_id: str, username: str) -> UserDB: """Update a user's username.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(UserDB).filter_by(id=uuid.UUID(user_id))) user = result.scalars().first() if user is None: @@ -1052,7 +1097,9 @@ async def get_user_with_email(email: str): if "@" not in email: raise Exception("Please provide a valid email address") - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(UserDB).filter_by(email=email)) user = result.scalars().first() return user @@ -1090,7 +1137,9 @@ async def create_user_invitation_to_organization( if not project: raise NoResultFound(f"Project with ID {project_id} not found") - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: invitation = InvitationDB( token=token, email=email, @@ -1126,7 +1175,9 @@ async def get_project_by_id(project_id: str) -> ProjectDB: str: The retrieve project or None """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: project_query = await session.execute( select(ProjectDB) .options(joinedload(ProjectDB.organization).load_only(OrganizationDB.name)) @@ -1150,7 +1201,9 @@ async def get_default_project_id_from_workspace( Union[str, Exception]: The default project ID or an exception error message. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: project_query = await session.execute( select(ProjectDB) .where( @@ -1248,7 +1301,8 @@ async def _create( if session is not None: return await _create(session) - async with engine.core_session() as new_session: + engine = get_transactions_engine() + async with engine.session() as new_session: return await _create(new_session) @@ -1260,7 +1314,9 @@ async def delete_project(project_id: str) -> None: project_id (str): Identifier of project to delete. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: project = await session.get(ProjectDB, uuid.UUID(project_id)) if project is None: raise NoResultFound(f"Project with ID {project_id} not found") @@ -1291,7 +1347,9 @@ async def set_default_project(project_id: str) -> ProjectDB: ProjectDB: Updated project. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: project = await session.get(ProjectDB, uuid.UUID(project_id)) if project is None: raise NoResultFound(f"Project with ID {project_id} not found") @@ -1334,7 +1392,9 @@ async def update_project_name(project_id: str, *, project_name: str) -> ProjectD if not project_name.strip(): raise ValueError("Project name cannot be empty") - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: project = await session.get(ProjectDB, uuid.UUID(project_id)) if project is None: raise NoResultFound(f"Project with ID {project_id} not found") @@ -1365,7 +1425,9 @@ async def get_project_invitation_by_email(project_id: str, email: str) -> Invita InvitationDB: invitation object """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(InvitationDB).filter_by( project_id=uuid.UUID(project_id), email=email @@ -1385,7 +1447,9 @@ async def get_project_invitations(project_id: str) -> InvitationDB: List[InvitationDB]: invitation objects """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(InvitationDB).filter_by(project_id=uuid.UUID(project_id)) ) @@ -1405,7 +1469,9 @@ async def update_invitation(invitation_id: str, values_to_update: dict) -> bool: bool: True if the invitation was successfully updated, False otherwise. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(InvitationDB).filter_by(id=uuid.UUID(invitation_id)) ) @@ -1446,7 +1512,9 @@ async def delete_invitation(invitation_id: str) -> bool: bool: True if the invitation was successfully deleted, False otherwise. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(InvitationDB).filter_by(id=uuid.UUID(invitation_id)) ) @@ -1498,7 +1566,9 @@ async def get_project_by_organization_id(organization_id: str): ProjectDB: project object """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(ProjectDB).filter_by(organization_id=uuid.UUID(organization_id)) ) @@ -1513,8 +1583,9 @@ async def get_default_project_by_organization_id(organization_id: str): so callers that depend on the OSS singleton invariant don't accidentally pick up an ephemeral per-account project. """ + engine = get_transactions_engine() - async with engine.core_session() as session: + async with engine.session() as session: result = await session.execute( select(ProjectDB).filter_by( organization_id=uuid.UUID(organization_id), @@ -1538,7 +1609,9 @@ async def get_project_invitation_by_token_and_email( InvitationDB: invitation object """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(InvitationDB).filter_by( project_id=uuid.UUID(project_id), token=token, email=email @@ -1562,7 +1635,9 @@ async def get_user_api_key_by_prefix( The user api key by prefix. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(APIKeyDB).filter_by( prefix=api_key_prefix, created_by_id=uuid.UUID(user_id) @@ -1578,56 +1653,74 @@ async def get_user_api_key_by_prefix( async def admin_get_user_by_id(user_id: uuid.UUID) -> Optional[UserDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(UserDB).filter_by(id=user_id)) return result.scalars().first() async def admin_get_user_by_email(email: str) -> Optional[UserDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(UserDB).filter_by(email=email)) return result.scalars().first() async def admin_get_org_by_id(org_id: uuid.UUID) -> Optional[OrganizationDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(OrganizationDB).filter_by(id=org_id)) return result.scalars().first() async def admin_get_org_by_slug(slug: str) -> Optional[OrganizationDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(OrganizationDB).filter_by(slug=slug)) return result.scalars().first() async def admin_get_workspace_by_id(ws_id: uuid.UUID) -> Optional[WorkspaceDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(WorkspaceDB).filter_by(id=ws_id)) return result.scalars().first() async def admin_get_project_by_id(proj_id: uuid.UUID) -> Optional[ProjectDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(ProjectDB).filter_by(id=proj_id)) return result.scalars().first() async def admin_get_api_key_by_id(key_id: uuid.UUID) -> Optional[APIKeyDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(APIKeyDB).filter_by(id=key_id)) return result.scalars().first() async def admin_get_api_key_by_prefix(prefix: str) -> Optional[APIKeyDB]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(APIKeyDB).filter_by(prefix=prefix)) return result.scalars().first() async def admin_get_orgs_owned_by_user(user_id: uuid.UUID) -> List[OrganizationDB]: """Return orgs where user is owner OR creator (both carry RESTRICT FK).""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(OrganizationDB).where( or_( @@ -1642,7 +1735,9 @@ async def admin_get_orgs_owned_by_user(user_id: uuid.UUID) -> List[OrganizationD async def admin_get_workspace_ids_for_orgs( org_ids: List[uuid.UUID], ) -> List[uuid.UUID]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(WorkspaceDB.id).where(WorkspaceDB.organization_id.in_(org_ids)) ) @@ -1652,7 +1747,9 @@ async def admin_get_workspace_ids_for_orgs( async def admin_get_project_ids_for_orgs( org_ids: List[uuid.UUID], ) -> List[uuid.UUID]: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute( select(ProjectDB.id).where(ProjectDB.organization_id.in_(org_ids)) ) @@ -1666,7 +1763,9 @@ async def admin_get_or_create_user( existing = await admin_get_user_by_email(email) if existing: return existing - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: user_db = UserDB( uid=str(uuid.uuid4()), username=username or email.split("@")[0], @@ -1698,7 +1797,9 @@ async def admin_create_organization( On EE behavior is unchanged: a new row is inserted with the supplied ``name``/``slug``. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: if not is_ee(): stmt = ( pg_insert(OrganizationDB) @@ -1754,7 +1855,9 @@ async def admin_create_workspace( On EE behavior is unchanged: a fresh workspace row is always inserted. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: if not is_ee(): await session.execute( select(OrganizationDB.id).filter_by(id=org_id).with_for_update() @@ -1800,7 +1903,9 @@ async def admin_create_project( *, is_default: bool = False, ) -> ProjectDB: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: proj_db = ProjectDB( project_name=name, is_default=is_default, @@ -1815,31 +1920,41 @@ async def admin_create_project( async def admin_delete_user(user_id: uuid.UUID) -> None: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: await session.execute(delete(UserDB).where(UserDB.id == user_id)) await session.commit() async def admin_delete_organization(org_id: uuid.UUID) -> None: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: await session.execute(delete(OrganizationDB).where(OrganizationDB.id == org_id)) await session.commit() async def admin_delete_workspace(ws_id: uuid.UUID) -> None: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: await session.execute(delete(WorkspaceDB).where(WorkspaceDB.id == ws_id)) await session.commit() async def admin_delete_project(proj_id: uuid.UUID) -> None: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: await session.execute(delete(ProjectDB).where(ProjectDB.id == proj_id)) await session.commit() async def admin_delete_api_key(key_id: uuid.UUID) -> None: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: await session.execute(delete(APIKeyDB).where(APIKeyDB.id == key_id)) await session.commit() @@ -1852,7 +1967,9 @@ async def admin_delete_accounts_batch( user_ids: List[uuid.UUID], ) -> None: """Delete a batch of entities atomically, in dependency order.""" - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: for proj_id in project_ids: await session.execute(delete(ProjectDB).where(ProjectDB.id == proj_id)) for ws_id in workspace_ids: @@ -1894,7 +2011,9 @@ async def admin_transfer_org_ownership_batch( of that user does not destroy orgs now owned by the target. """ now = datetime.now(timezone.utc) - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: for org_id in org_ids: await session.execute( update(OrganizationDB) diff --git a/api/oss/src/services/user_service.py b/api/oss/src/services/user_service.py index d254510e72..d6803fe50a 100644 --- a/api/oss/src/services/user_service.py +++ b/api/oss/src/services/user_service.py @@ -5,7 +5,7 @@ from oss.src.utils.env import env from oss.src.models.db_models import UserDB from oss.src.utils.logging import get_module_logger -from oss.src.dbs.postgres.shared.engine import engine +from oss.src.dbs.postgres.shared.engine import get_transactions_engine from oss.src.models.api.user_models import UserUpdate from oss.src.services import db_manager, email_service @@ -36,7 +36,9 @@ async def delete_user(user_id: str) -> None: Raises: NoResultFound: If user with the given ID is not found. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(UserDB).filter_by(id=user_id)) user = result.scalars().first() @@ -68,7 +70,9 @@ async def create_new_user(payload: dict) -> UserDB: # Attempt to create new user try: - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: user = UserDB(**payload) session.add(user) @@ -110,7 +114,9 @@ async def update_user(user_uid: str, payload: UserUpdate) -> UserDB: NoResultFound: User with session id xxxx not found. """ - async with engine.core_session() as session: + engine = get_transactions_engine() + + async with engine.session() as session: result = await session.execute(select(UserDB).filter_by(uid=user_uid)) user = result.scalars().first() diff --git a/api/oss/src/tasks/taskiq/evaluations/worker.py b/api/oss/src/tasks/taskiq/evaluations/worker.py index 4c70fa8d25..085f9a9d9f 100644 --- a/api/oss/src/tasks/taskiq/evaluations/worker.py +++ b/api/oss/src/tasks/taskiq/evaluations/worker.py @@ -15,14 +15,10 @@ from oss.src.core.workflows.service import WorkflowsService from oss.src.core.evaluations.service import EvaluationsService -from oss.src.core.evaluations.tasks.legacy import ( - evaluate_batch_testset as evaluate_batch_testset_impl, - evaluate_batch_invocation as evaluate_batch_invocation_impl, - evaluate_batch_testcases as evaluate_batch_testcases_impl, - evaluate_batch_traces as evaluate_batch_traces_impl, -) -from oss.src.core.evaluations.tasks.live import ( - evaluate_live_query as evaluate_live_query_impl, +from oss.src.core.evaluations.tasks.run import ( + EvaluationSliceSource, + process_evaluation_run, + process_evaluation_slice, ) from oss.src.core.evaluations.runtime.locks import ( acquire_job_lock, @@ -225,55 +221,12 @@ def _register_tasks(self): """Register all evaluation tasks with the broker.""" @self.broker.task( - task_name="evaluations.legacy.annotate", + task_name="evaluations.run.process", retry_on_error=False, max_retries=0, # Never retry - handle errors in application logic ) - async def evaluate_batch_testset( + async def process_run( *, - project_id: UUID, - user_id: UUID, - # - run_id: UUID, - context: Context = TaskiqDepends(), - ) -> Any: - """Legacy annotation task - wraps the existing annotate function.""" - log.info( - "[TASK] Starting evaluate_batch_testset", - project_id=str(project_id), - user_id=str(user_id), - ) - - result = await self._with_job_lock( - run_id, - job_id=context.message.task_id or str(uuid4()), - job_type="api", - allow_concurrency=False, - runner=lambda: evaluate_batch_testset_impl( - project_id=project_id, - user_id=user_id, - # - run_id=run_id, - # - tracing_service=self.tracing_service, - testsets_service=self.testsets_service, - queries_service=self.queries_service, - workflows_service=self.workflows_service, - applications_service=self.applications_service, - evaluations_service=self.evaluations_service, - # - simple_evaluators_service=self.simple_evaluators_service, - ), - ) - log.info("[TASK] Completed evaluate_batch_testset") - return result - - @self.broker.task( - task_name="evaluations.live.evaluate", - retry_on_error=False, - max_retries=0, # Never retry - handle errors in application logic - ) - async def evaluate_live_query( project_id: UUID, user_id: UUID, # @@ -283,8 +236,8 @@ async def evaluate_live_query( oldest: Optional[datetime] = None, context: Context = TaskiqDepends(), ) -> Any: - """Live evaluation task - evaluates traces against evaluators.""" - log.info("[TASK] Starting evaluate_live_query") + """Process one evaluation run using the unified topology dispatcher.""" + log.info("[TASK] Starting process_run") if newest is None: newest = datetime.now(timezone.utc) @@ -296,168 +249,64 @@ async def evaluate_live_query( job_id=context.message.task_id or str(uuid4()), job_type="api", allow_concurrency=False, - runner=lambda: evaluate_live_query_impl( + runner=lambda: process_evaluation_run( project_id=project_id, user_id=user_id, - # run_id=run_id, - # newest=newest, oldest=oldest, - ), - ) - log.info("[TASK] Completed evaluate_live_query") - return result - - @self.broker.task( - task_name="evaluations.queries.batch", - retry_on_error=False, - max_retries=0, - ) - async def evaluate_batch_query( - *, - project_id: UUID, - user_id: UUID, - # - run_id: UUID, - context: Context = TaskiqDepends(), - ) -> Any: - """One-shot query evaluation task for non-live runs.""" - log.info("[TASK] Starting evaluate_batch_query") - - result = await self._with_job_lock( - run_id, - job_id=context.message.task_id or str(uuid4()), - job_type="api", - allow_concurrency=False, - runner=lambda: evaluate_live_query_impl( - project_id=project_id, - user_id=user_id, - # - run_id=run_id, - # - newest=None, - oldest=None, - # - use_windowing=True, - ), - ) - log.info("[TASK] Completed evaluate_batch_query") - return result - - @self.broker.task( - task_name="evaluations.invocations.batch", - retry_on_error=False, - max_retries=0, - ) - async def evaluate_batch_invocation( - *, - project_id: UUID, - user_id: UUID, - # - run_id: UUID, - context: Context = TaskiqDepends(), - ) -> Any: - log.info("[TASK] Starting evaluate_batch_invocation") - result = await self._with_job_lock( - run_id, - job_id=context.message.task_id or str(uuid4()), - job_type="api", - allow_concurrency=False, - runner=lambda: evaluate_batch_invocation_impl( - project_id=project_id, - user_id=user_id, - # - run_id=run_id, - # tracing_service=self.tracing_service, testsets_service=self.testsets_service, + queries_service=self.queries_service, + workflows_service=self.workflows_service, applications_service=self.applications_service, evaluations_service=self.evaluations_service, + simple_evaluators_service=self.simple_evaluators_service, ), ) - log.info("[TASK] Completed evaluate_batch_invocation") + log.info("[TASK] Completed process_run") return result @self.broker.task( - task_name="evaluations.traces.batch", + task_name="evaluations.slice.process", retry_on_error=False, max_retries=0, ) - async def evaluate_batch_traces( + async def process_slice( *, project_id: UUID, user_id: UUID, # run_id: UUID, - trace_ids: list[str], + source_kind: EvaluationSliceSource, + trace_ids: Optional[list[str]] = None, + testcase_ids: Optional[list[UUID]] = None, input_step_key: Optional[str] = None, context: Context = TaskiqDepends(), ) -> Any: - log.info("[TASK] Starting evaluate_batch_traces") + log.info("[TASK] Starting process_slice", source_kind=source_kind) result = await self._with_job_lock( run_id, job_id=context.message.task_id or str(uuid4()), job_type="api", allow_concurrency=True, - runner=lambda: evaluate_batch_traces_impl( + runner=lambda: process_evaluation_slice( project_id=project_id, user_id=user_id, - # run_id=run_id, + source_kind=source_kind, trace_ids=trace_ids, - input_step_key=input_step_key, - # - tracing_service=self.tracing_service, - workflows_service=self.workflows_service, - evaluations_service=self.evaluations_service, - ), - ) - log.info("[TASK] Completed evaluate_batch_traces") - return result - - @self.broker.task( - task_name="evaluations.testcases.batch", - retry_on_error=False, - max_retries=0, - ) - async def evaluate_batch_testcases( - *, - project_id: UUID, - user_id: UUID, - # - run_id: UUID, - testcase_ids: list[UUID], - input_step_key: Optional[str] = None, - context: Context = TaskiqDepends(), - ) -> Any: - log.info("[TASK] Starting evaluate_batch_testcases") - result = await self._with_job_lock( - run_id, - job_id=context.message.task_id or str(uuid4()), - job_type="api", - allow_concurrency=True, - runner=lambda: evaluate_batch_testcases_impl( - project_id=project_id, - user_id=user_id, - # - run_id=run_id, testcase_ids=testcase_ids, input_step_key=input_step_key, - # tracing_service=self.tracing_service, testcases_service=self.testcases_service, workflows_service=self.workflows_service, evaluations_service=self.evaluations_service, ), ) - log.info("[TASK] Completed evaluate_batch_testcases") + log.info("[TASK] Completed process_slice", source_kind=source_kind) return result # Store task references for external access - self.evaluate_batch_testset = evaluate_batch_testset - self.evaluate_live_query = evaluate_live_query - self.evaluate_batch_query = evaluate_batch_query - self.evaluate_batch_invocation = evaluate_batch_invocation - self.evaluate_batch_traces = evaluate_batch_traces - self.evaluate_batch_testcases = evaluate_batch_testcases + self.process_run = process_run + self.process_slice = process_slice diff --git a/api/oss/src/utils/caching.py b/api/oss/src/utils/caching.py index 5fb83e1a65..96be2adaab 100644 --- a/api/oss/src/utils/caching.py +++ b/api/oss/src/utils/caching.py @@ -5,12 +5,11 @@ import orjson -# from cachetools import TTLCache -from redis.asyncio import Redis from pydantic import BaseModel from oss.src.utils.logging import get_module_logger from oss.src.utils.env import env +from oss.src.dbs.redis.shared.engine import get_cache_engine log = get_module_logger(__name__) @@ -38,20 +37,8 @@ # Original L1 cache: # local_cache: TTLCache = TTLCache(maxsize=4096, ttl=AGENTA_CACHE_LOCAL_TTL) -# Use volatile Redis instance for caching (prefix-based separation) -# decode_responses=False: orjson operates on bytes for 3x performance vs json -r = Redis.from_url( - url=env.redis.uri_volatile, - decode_responses=False, - socket_timeout=0.5, # read/write timeout -) - -# Dedicated Redis client for distributed locks with a longer timeout. -r_lock = Redis.from_url( - url=env.redis.uri_volatile, - decode_responses=False, - socket_timeout=AGENTA_LOCK_SOCKET_TIMEOUT, -) +_cache_engine = get_cache_engine() + # Ownership-safe lock scripts. Owner token must match to renew/release. _LOCK_RENEW_IF_OWNER_SCRIPT = """ @@ -129,7 +116,7 @@ async def _scan(pattern: str) -> list[str]: keys: list[str] = [] while True: # TODO: Really ? - cursor, batch = await r.scan( + cursor, batch = await _cache_engine.get_r().scan( cursor=cursor, match=pattern, count=AGENTA_CACHE_SCAN_BATCH_SIZE, @@ -219,7 +206,7 @@ async def _try_get_and_maybe_renew( # return _deserialize(raw, model=model, is_list=is_list) # Layer 2: Check Redis (distributed, 5min TTL, ~1ms latency) - raw = await r.get(cache_name) + raw = await _cache_engine.get_r().get(cache_name) if raw is not None: if CACHE_DEBUG: @@ -240,7 +227,7 @@ async def _try_get_and_maybe_renew( name=cache_name, ) - await r.expire(cache_name, ttl) + await _cache_engine.get_r().expire(cache_name, ttl) else: if CACHE_DEBUG: log.debug( @@ -306,7 +293,7 @@ async def _maybe_retry_get( lock_name = f"lock::{cache_name}" lock_ex = int(lock_ttl * 1000) # convert seconds to milliseconds - got_lock = await r.set(lock_name, "1", nx=True, ex=lock_ex) + got_lock = await _cache_engine.get_r().set(lock_name, "1", nx=True, ex=lock_ex) if got_lock: if CACHE_DEBUG: @@ -379,7 +366,7 @@ async def set_cache( # # Original L1 write path: # local_cache[cache_name] = cache_value - await r.set(cache_name, cache_value, px=cache_px) + await _cache_engine.get_r().set(cache_name, cache_value, px=cache_px) if CACHE_DEBUG: log.debug( @@ -392,7 +379,7 @@ async def set_cache( lock_name = f"lock::{cache_name}" - check = await r.delete(lock_name) + check = await _cache_engine.get_r().delete(lock_name) if check: if CACHE_DEBUG: @@ -513,7 +500,7 @@ async def invalidate_cache( # # Original L1 invalidation path: # local_cache.pop(cache_name, None) - await r.delete(cache_name) + await _cache_engine.get_r().delete(cache_name) else: cache_name = _pack( @@ -560,7 +547,7 @@ async def invalidate_cache( redis_keys_deleted = 0 for i in range(0, len(keys), AGENTA_CACHE_DELETE_BATCH_SIZE): batch = keys[i : i + AGENTA_CACHE_DELETE_BATCH_SIZE] - deleted_count = await r.delete(*batch) + deleted_count = await _cache_engine.get_r().delete(*batch) redis_keys_deleted += deleted_count if CACHE_DEBUG: @@ -642,7 +629,9 @@ async def acquire_lock( lock_owner = uuid4().hex # Atomic SET NX: Returns True if lock acquired, False if already held - acquired = await r_lock.set(lock_key, lock_owner, nx=True, ex=ttl) + acquired = await _cache_engine.get_r_lock().set( + lock_key, lock_owner, nx=True, ex=ttl + ) if acquired: if CACHE_DEBUG: @@ -704,7 +693,7 @@ async def renew_lock( ) if owner: - renewed = await r_lock.eval( + renewed = await _cache_engine.get_r_lock().eval( _LOCK_RENEW_IF_OWNER_SCRIPT, 1, lock_key, @@ -712,7 +701,7 @@ async def renew_lock( str(ttl), ) else: - renewed = await r_lock.expire(lock_key, ttl) + renewed = await _cache_engine.get_r_lock().expire(lock_key, ttl) if renewed: if CACHE_DEBUG: @@ -774,14 +763,14 @@ async def release_lock( ) if owner: - deleted = await r_lock.eval( + deleted = await _cache_engine.get_r_lock().eval( _LOCK_RELEASE_IF_OWNER_SCRIPT, 1, lock_key, owner, ) else: - deleted = await r_lock.delete(lock_key) + deleted = await _cache_engine.get_r_lock().delete(lock_key) if deleted: if CACHE_DEBUG: diff --git a/api/oss/src/utils/lazy.py b/api/oss/src/utils/lazy.py new file mode 100644 index 0000000000..9f72abfc18 --- /dev/null +++ b/api/oss/src/utils/lazy.py @@ -0,0 +1,65 @@ +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + import stripe + import posthog + + +_stripe_module: Optional["stripe"] = None +_stripe_checked = False + +_posthog_module: Optional["posthog"] = None +_posthog_checked = False + + +def _load_stripe() -> Optional["stripe"]: + global _stripe_module, _stripe_checked + + if _stripe_checked: + return _stripe_module + + _stripe_checked = True + try: + import stripe as _stripe + from oss.src.utils.env import env + from oss.src.utils.logging import get_module_logger + + log = get_module_logger(__name__) + + if env.stripe.enabled: + _stripe.api_key = env.stripe.api_key + log.info("✓ Stripe enabled:", target=env.stripe.webhook_target) + + _stripe_module = _stripe + except Exception: + _stripe_module = None + + return _stripe_module + + +def _load_posthog() -> Optional["posthog"]: + global _posthog_module, _posthog_checked + + if _posthog_checked: + return _posthog_module + + _posthog_checked = True + try: + import posthog as _posthog + from oss.src.utils.env import env + from oss.src.utils.logging import get_module_logger + + log = get_module_logger(__name__) + + if env.posthog.enabled: + _posthog.api_key = env.posthog.api_key + _posthog.host = env.posthog.api_url + log.info("✓ PostHog enabled") + else: + log.warn("✗ PostHog disabled") + + _posthog_module = _posthog + except Exception: + _posthog_module = None + + return _posthog_module diff --git a/api/oss/tests/pytest/acceptance/accounts/test_actions.py b/api/oss/tests/pytest/acceptance/accounts/test_actions.py index 3a2d6acc79..97e8328380 100644 --- a/api/oss/tests/pytest/acceptance/accounts/test_actions.py +++ b/api/oss/tests/pytest/acceptance/accounts/test_actions.py @@ -8,6 +8,8 @@ """ from uuid import uuid4 +import os +import pytest # --------------------------------------------------------------------------- @@ -54,6 +56,10 @@ def _delete_account_by_email(admin_api, *, email): class TestResetPassword: + @pytest.mark.skipif( + not os.getenv("POSTHOG_API_KEY"), + reason="PostHog API key not configured", + ) def test_reset_password_for_existing_identity(self, admin_api): uid = uuid4().hex[:12] email = f"reset-{uid}@test.agenta.ai" diff --git a/api/oss/tests/pytest/acceptance/evaluations/test_simple_evaluations_workflows.py b/api/oss/tests/pytest/acceptance/evaluations/test_simple_evaluations_workflows.py index 8f994024b1..a932a26c21 100644 --- a/api/oss/tests/pytest/acceptance/evaluations/test_simple_evaluations_workflows.py +++ b/api/oss/tests/pytest/acceptance/evaluations/test_simple_evaluations_workflows.py @@ -63,6 +63,44 @@ def _create_simple_evaluator(authed_api) -> dict: return response.json()["evaluator"] +def _create_simple_testset(authed_api) -> dict: + slug = uuid4().hex + response = authed_api( + "POST", + "/simple/testsets/", + json={ + "testset": { + "slug": f"testset-{slug}", + "name": f"Testset {slug}", + "data": { + "testcases": [ + {"data": {"input": "hello", "expected": "world"}}, + {"data": {"input": "hola", "expected": "mundo"}}, + ] + }, + } + }, + ) + assert response.status_code == 200 + return response.json()["testset"] + + +def _create_simple_application(authed_api) -> dict: + slug = uuid4().hex + response = authed_api( + "POST", + "/simple/applications/", + json={ + "application": { + "slug": f"application-{slug}", + "name": f"Application {slug}", + } + }, + ) + assert response.status_code == 200 + return response.json()["application"] + + class TestSimpleEvaluationsWorkflowReferences: def test_create_live_simple_evaluation_accepts_query_and_evaluator_revision_ids( self, authed_api @@ -124,3 +162,146 @@ def test_create_live_simple_evaluation_rejects_non_invocation_query_revision( body = response.json() assert body["count"] == 0 assert body.get("evaluation") is None + + def test_create_batch_inference_evaluation_preserves_flags_repeats_and_refs( + self, authed_api + ): + testset = _create_simple_testset(authed_api) + application = _create_simple_application(authed_api) + + response = authed_api( + "POST", + "/simple/evaluations/", + json={ + "evaluation": { + "name": "batch-inference-setup", + "flags": { + "is_cached": True, + "is_split": True, + }, + "data": { + "testset_steps": [testset["revision_id"]], + "application_steps": [application["revision_id"]], + "repeats": 3, + }, + } + }, + ) + + assert response.status_code == 200 + body = response.json() + assert body["count"] == 1 + + evaluation = body["evaluation"] + assert evaluation["flags"]["is_cached"] is True + assert evaluation["flags"]["is_split"] is True + assert evaluation["data"]["repeats"] == 3 + assert set(evaluation["data"]["testset_steps"].keys()) == { + testset["revision_id"] + } + assert set(evaluation["data"]["application_steps"].keys()) == { + application["revision_id"] + } + + def test_create_batch_evaluation_preserves_application_evaluator_matrix( + self, authed_api + ): + testset = _create_simple_testset(authed_api) + application = _create_simple_application(authed_api) + evaluator = _create_simple_evaluator(authed_api) + + response = authed_api( + "POST", + "/simple/evaluations/", + json={ + "evaluation": { + "name": "batch-application-evaluator-setup", + "flags": { + "is_cached": False, + "is_split": False, + }, + "data": { + "testset_steps": {testset["revision_id"]: "custom"}, + "application_steps": {application["revision_id"]: "custom"}, + "evaluator_steps": {evaluator["revision_id"]: "auto"}, + "repeats": 2, + }, + } + }, + ) + + assert response.status_code == 200 + body = response.json() + assert body["count"] == 1 + + evaluation = body["evaluation"] + assert evaluation["flags"]["is_cached"] is False + assert evaluation["flags"]["is_split"] is False + assert evaluation["data"]["repeats"] == 2 + assert evaluation["data"]["testset_steps"] == {testset["revision_id"]: "custom"} + assert evaluation["data"]["application_steps"] == { + application["revision_id"]: "custom" + } + assert evaluation["data"]["evaluator_steps"] == { + evaluator["revision_id"]: "auto" + } + + def test_create_testset_to_evaluator_evaluation_does_not_fail_setup( + self, authed_api + ): + testset = _create_simple_testset(authed_api) + evaluator = _create_simple_evaluator(authed_api) + + response = authed_api( + "POST", + "/simple/evaluations/", + json={ + "evaluation": { + "name": "testset-to-evaluator-setup", + "data": { + "testset_steps": [testset["revision_id"]], + "evaluator_steps": [evaluator["revision_id"]], + }, + } + }, + ) + + assert response.status_code == 200 + body = response.json() + assert body["count"] == 1 + evaluation = body["evaluation"] + assert set(evaluation["data"]["testset_steps"].keys()) == { + testset["revision_id"] + } + assert set(evaluation["data"]["evaluator_steps"].keys()) == { + evaluator["revision_id"] + } + + def test_create_query_to_application_evaluation_does_not_fail_setup( + self, authed_api + ): + query = _create_simple_query(authed_api) + application = _create_simple_application(authed_api) + + response = authed_api( + "POST", + "/simple/evaluations/", + json={ + "evaluation": { + "name": "query-to-application-setup", + "data": { + "query_steps": [query["revision_id"]], + "application_steps": [application["revision_id"]], + }, + } + }, + ) + + assert response.status_code == 200 + body = response.json() + assert body["count"] == 1 + evaluation = body["evaluation"] + assert set(evaluation["data"]["query_steps"].keys()) == {query["revision_id"]} + assert set(evaluation["data"]["application_steps"].keys()) == { + application["revision_id"] + } diff --git a/api/oss/tests/pytest/acceptance/evaluations/test_simple_queues_basics.py b/api/oss/tests/pytest/acceptance/evaluations/test_simple_queues_basics.py index d8fdabef7a..2bc3ff9b53 100644 --- a/api/oss/tests/pytest/acceptance/evaluations/test_simple_queues_basics.py +++ b/api/oss/tests/pytest/acceptance/evaluations/test_simple_queues_basics.py @@ -198,6 +198,39 @@ def test_create_simple_queue_from_testsets(self, authed_api): assert queue["data"]["kind"] == "testcases" assert queue["data"]["testsets"] == [testset_revision_id] + def test_create_source_backed_queue_preserves_repeats_and_assignments( + self, authed_api + ): + evaluator_revision_id = _create_evaluator(authed_api) + testset_revision_id = _create_testset(authed_api) + user_id_1 = str(uuid4()) + user_id_2 = str(uuid4()) + + response = authed_api( + "POST", + "/simple/queues/", + json={ + "queue": { + "name": "testset-backed-queue-with-repeats", + "data": { + "testsets": [testset_revision_id], + "evaluators": {evaluator_revision_id: "human"}, + "repeats": 2, + "assignments": [[user_id_1], [user_id_2]], + }, + } + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["count"] == 1 + queue = data["queue"] + assert queue["data"]["kind"] == "testcases" + assert queue["data"]["testsets"] == [testset_revision_id] + assert queue["data"]["repeats"] == 2 + assert queue["data"]["assignments"] == [[user_id_1], [user_id_2]] + def test_create_simple_queue_without_evaluators_returns_empty(self, authed_api): # ACT ------------------------------------------------------------------ response = authed_api( @@ -308,6 +341,30 @@ def test_create_simple_queue_rejects_kind_with_queries(self, authed_api): in response.text ) + def test_create_simple_queue_rejects_mixed_query_and_testset_sources( + self, authed_api + ): + evaluator_revision_id = _create_evaluator(authed_api) + query_revision_id = _create_query(authed_api) + testset_revision_id = _create_testset(authed_api) + + response = authed_api( + "POST", + "/simple/queues/", + json={ + "queue": { + "data": { + "queries": [query_revision_id], + "testsets": [testset_revision_id], + "evaluators": [evaluator_revision_id], + }, + } + }, + ) + + assert response.status_code == 422 + assert "simple queue source must be either queries or testsets" in response.text + def test_create_simple_queue_with_assignments(self, authed_api): # ARRANGE -------------------------------------------------------------- evaluator_revision_id = _create_evaluator(authed_api) diff --git a/api/oss/tests/pytest/unit/auth/test_helper.py b/api/oss/tests/pytest/unit/auth/test_helper.py index ce4fa55a34..795af7c18f 100644 --- a/api/oss/tests/pytest/unit/auth/test_helper.py +++ b/api/oss/tests/pytest/unit/auth/test_helper.py @@ -72,7 +72,7 @@ async def test_get_blocked_domains_accepts_string_posthog_payload(monkeypatch): blocked_emails=set(), allowed_domains=set(), ), - posthog=SimpleNamespace(enabled=True), + posthog=SimpleNamespace(enabled=True, api_key="posthog-key"), ), ) monkeypatch.setattr( @@ -107,7 +107,7 @@ async def test_get_blocked_domains_splits_comma_separated_posthog_payload(monkey blocked_emails=set(), allowed_domains=set(), ), - posthog=SimpleNamespace(enabled=True), + posthog=SimpleNamespace(enabled=True, api_key="posthog-key"), ), ) monkeypatch.setattr( @@ -131,6 +131,36 @@ async def test_get_blocked_domains_splits_comma_separated_posthog_payload(monkey assert blocked_domains == {"spica.asia", "agenta.ai", "yopmail.com"} +@pytest.mark.asyncio +async def test_get_blocked_emails_treats_posthog_errors_as_empty(monkeypatch): + monkeypatch.setattr( + auth_helper, + "env", + SimpleNamespace( + agenta=SimpleNamespace( + blocked_domains=set(), + blocked_emails=set(), + allowed_domains=set(), + ), + posthog=SimpleNamespace(enabled=True), + ), + ) + monkeypatch.setattr( + auth_helper, + "get_cache", + AsyncMock(return_value=None), + ) + monkeypatch.setattr( + auth_helper.posthog, + "get_feature_flag_payload", + lambda feature_flag, distinct_id: (_ for _ in ()).throw( + ValueError("API key is required") + ), + ) + + assert await auth_helper.get_blocked_emails() == set() + + @pytest.mark.asyncio async def test_thirdparty_sign_in_up_checks_blocking_before_auth(monkeypatch): called = False diff --git a/api/oss/tests/pytest/unit/evaluations/test_cache_split_utils.py b/api/oss/tests/pytest/unit/evaluations/test_cache_split_utils.py index 7c0cef9d24..06ae1026e9 100644 --- a/api/oss/tests/pytest/unit/evaluations/test_cache_split_utils.py +++ b/api/oss/tests/pytest/unit/evaluations/test_cache_split_utils.py @@ -110,7 +110,7 @@ def test_repeat_and_fanout_planning_helpers_follow_split_rules(): assert ( effective_is_split( is_split=True, - is_queue=True, + has_traces=True, has_application_steps=True, has_evaluator_steps=True, ) diff --git a/api/oss/tests/pytest/unit/evaluations/test_query_eval_loops.py b/api/oss/tests/pytest/unit/evaluations/test_query_eval_loops.py index 4f88be8851..fe9f8eda1c 100644 --- a/api/oss/tests/pytest/unit/evaluations/test_query_eval_loops.py +++ b/api/oss/tests/pytest/unit/evaluations/test_query_eval_loops.py @@ -24,7 +24,7 @@ SimpleQueueData, ) from oss.src.core.evaluations.service import SimpleQueuesService -from oss.src.core.evaluations.tasks import live as live_module +from oss.src.core.evaluations.tasks import query as query_module @pytest.mark.asyncio @@ -92,8 +92,8 @@ async def test_simple_queue_create_dispatches_each_query_source_with_step_key(): ) ), testsets_service=SimpleNamespace(fetch_testset_revision=AsyncMock()), - evaluate_batch_traces=AsyncMock(return_value=True), - evaluate_batch_testcases=AsyncMock(return_value=True), + dispatch_trace_slice=AsyncMock(return_value=True), + dispatch_testcase_slice=AsyncMock(return_value=True), ) service = SimpleQueuesService( @@ -118,7 +118,7 @@ async def test_simple_queue_create_dispatches_each_query_source_with_step_key(): assert created_queue.data is not None assert created_queue.data.kind == "traces" assert created_queue.data.queries == [query_revision_id_1, query_revision_id_2] - assert simple_evaluations_service.evaluate_batch_traces.await_args_list == [ + assert simple_evaluations_service.dispatch_trace_slice.await_args_list == [ call( project_id=project_id, user_id=user_id, @@ -137,7 +137,7 @@ async def test_simple_queue_create_dispatches_each_query_source_with_step_key(): @pytest.mark.asyncio -async def test_evaluate_live_query_marks_human_steps_pending(monkeypatch): +async def test_process_query_source_run_marks_human_steps_pending(monkeypatch): project_id = uuid4() user_id = uuid4() run_id = uuid4() @@ -200,7 +200,7 @@ async def test_evaluate_live_query_marks_human_steps_pending(monkeypatch): refresh_metrics = AsyncMock() monkeypatch.setattr( - live_module, + query_module, "evaluations_service", SimpleNamespace( fetch_run=fetch_run, @@ -211,7 +211,7 @@ async def test_evaluate_live_query_marks_human_steps_pending(monkeypatch): ), ) monkeypatch.setattr( - live_module, + query_module, "queries_service", SimpleNamespace( fetch_query_revision=AsyncMock( @@ -224,7 +224,7 @@ async def test_evaluate_live_query_marks_human_steps_pending(monkeypatch): ), ) monkeypatch.setattr( - live_module, + query_module, "evaluators_service", SimpleNamespace( fetch_evaluator_revision=AsyncMock( @@ -237,22 +237,16 @@ async def test_evaluate_live_query_marks_human_steps_pending(monkeypatch): ), ) monkeypatch.setattr( - live_module, + query_module, "tracing_service", SimpleNamespace(query_traces=AsyncMock(return_value=[trace])), ) monkeypatch.setattr( - live_module, + query_module, "workflows_service", SimpleNamespace(invoke_workflow=AsyncMock()), ) - monkeypatch.setattr( - live_module, - "fetch_traces_by_hash", - AsyncMock(return_value=[]), - ) - - await live_module.evaluate_live_query( + await query_module.process_query_source_run( project_id=project_id, user_id=user_id, run_id=run_id, @@ -268,3 +262,59 @@ async def test_evaluate_live_query_marks_human_steps_pending(monkeypatch): assert isinstance(scenario_edit, EvaluationScenarioEdit) assert scenario_edit.status == EvaluationStatus.PENDING refresh_metrics.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_process_query_source_run_skips_empty_query_results(monkeypatch): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + query_revision_id = uuid4() + evaluator_revision_id = uuid4() + run = EvaluationRun( + id=run_id, + flags=EvaluationRunFlags(has_queries=True, has_evaluators=True), + status=EvaluationStatus.RUNNING, + data=EvaluationRunData( + steps=[ + EvaluationRunDataStep( + key="query-main", + type="input", + origin="custom", + references={"query_revision": Reference(id=query_revision_id)}, + ), + EvaluationRunDataStep( + key="evaluator-auto", + type="annotation", + origin="auto", + references={ + "evaluator_revision": Reference(id=evaluator_revision_id) + }, + ), + ] + ), + ) + process_source_slice = AsyncMock() + monkeypatch.setattr( + query_module, + "evaluations_service", + SimpleNamespace(fetch_run=AsyncMock(return_value=run)), + ) + monkeypatch.setattr( + query_module, + "resolve_query_source_items", + AsyncMock(return_value={"query-main": []}), + ) + monkeypatch.setattr( + query_module, + "process_evaluation_source_slice", + process_source_slice, + ) + + await query_module.process_query_source_run( + project_id=project_id, + user_id=user_id, + run_id=run_id, + ) + + process_source_slice.assert_not_awaited() diff --git a/api/oss/tests/pytest/unit/evaluations/test_queue_dao_serialization.py b/api/oss/tests/pytest/unit/evaluations/test_queue_dao_serialization.py index cad44f6ee8..c7444a2b71 100644 --- a/api/oss/tests/pytest/unit/evaluations/test_queue_dao_serialization.py +++ b/api/oss/tests/pytest/unit/evaluations/test_queue_dao_serialization.py @@ -51,10 +51,14 @@ async def test_create_queue_serializes_queue_data_without_uuid_warnings(monkeypa "_get_run_flags", AsyncMock(return_value={}), ) + # Mock get_transactions_engine to return an engine with session method + mock_engine = type( + "MockEngine", (), {"session": lambda self: _DummySessionContext(session)} + )() monkeypatch.setattr( - dao_module.engine, - "core_session", - lambda: _DummySessionContext(session), + dao_module, + "get_transactions_engine", + lambda: mock_engine, ) def fake_create_dto_from_dbe(*, DTO, dbe): diff --git a/api/oss/tests/pytest/unit/evaluations/test_run_flags.py b/api/oss/tests/pytest/unit/evaluations/test_run_flags.py index 03bb40cb1a..a3b87b3350 100644 --- a/api/oss/tests/pytest/unit/evaluations/test_run_flags.py +++ b/api/oss/tests/pytest/unit/evaluations/test_run_flags.py @@ -56,3 +56,56 @@ def test_evaluation_run_query_flags_include_cache_and_split_when_explicit(): "is_cached": False, "is_split": False, } + + +def test_create_run_flags_keeps_direct_source_families_distinct_from_backed_sources(): + direct_run = EvaluationRun( + data=EvaluationRunData( + steps=[ + EvaluationRunDataStep( + key="traces", + type="input", + origin="custom", + references={}, + ), + EvaluationRunDataStep( + key="testcases", + type="input", + origin="custom", + references={}, + ), + ] + ) + ) + backed_run = EvaluationRun( + data=EvaluationRunData( + steps=[ + EvaluationRunDataStep( + key="query-main", + type="input", + origin="custom", + references={"query_revision": {"id": str(uuid4())}}, + ), + EvaluationRunDataStep( + key="testset-main", + type="input", + origin="custom", + references={"testset_revision": {"id": str(uuid4())}}, + ), + ] + ) + ) + + direct_flags = create_run_flags(direct_run) + backed_flags = create_run_flags(backed_run) + + assert direct_flags is not None + assert direct_flags.has_traces is True + assert direct_flags.has_testcases is True + assert direct_flags.has_queries is False + assert direct_flags.has_testsets is False + assert backed_flags is not None + assert backed_flags.has_queries is True + assert backed_flags.has_testsets is True + assert backed_flags.has_traces is False + assert backed_flags.has_testcases is False diff --git a/api/oss/tests/pytest/unit/evaluations/test_runtime_topology_planner.py b/api/oss/tests/pytest/unit/evaluations/test_runtime_topology_planner.py new file mode 100644 index 0000000000..4b2a52679e --- /dev/null +++ b/api/oss/tests/pytest/unit/evaluations/test_runtime_topology_planner.py @@ -0,0 +1,2065 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, call +from uuid import uuid4 + +import pytest + +from oss.src.core.evaluations.runtime.adapters import ( + BackendCachedRunner, + BackendEvaluatorRunner, + BackendWorkflowRunner, + BackendWorkflowServiceRunner, +) +from oss.src.core.evaluations.runtime.cache import RunnableCacheResolver +from oss.src.core.evaluations.runtime.executor import ( + ApplicationBatchRunnableStepExecutor, + WorkflowRunnableStepExecutor, +) +from oss.src.core.evaluations.runtime.models import ( + ProcessSummary, + ResolvedSourceItem, + TensorSlice, + TensorProbeSummary, +) +from oss.src.core.evaluations.runtime.planner import ( + EvaluationPlanner, + make_scenario_bindings, + plan_source_input_result_creates, + planned_cells_to_result_creates, +) +from oss.src.core.evaluations.runtime.sources import ( + resolve_direct_source_items, + resolve_live_query_traces, + resolve_queue_source_batches, + resolve_testset_input_specs, +) +from oss.src.core.evaluations.runtime.tensor import TensorSliceOperations +from oss.src.core.evaluations.runtime.task_runner import TaskiqEvaluationTaskRunner +from oss.src.core.evaluations.runtime.topology import classify_run_topology +from agenta.sdk.evaluations.runtime.models import ( + EvaluationStep as SdkEvaluationStep, + PlannedCell as SdkPlannedCell, + ResolvedSourceItem as SdkResolvedSourceItem, + WorkflowExecutionRequest, + WorkflowExecutionResult, +) +from agenta.sdk.evaluations.runtime.source_slice import ( + ProcessedScenario as SdkProcessedScenario, +) +from agenta.sdk.models.evaluations import EvaluationStatus as SdkEvaluationStatus +from oss.src.core.evaluations.types import ( + EvaluationResult, + EvaluationResultCreate, + EvaluationRun, + EvaluationRunData, + EvaluationRunDataStep, + EvaluationRunFlags, + EvaluationStatus, + SimpleEvaluation, + SimpleEvaluationData, + SimpleEvaluationFlags, +) +from oss.src.core.shared.dtos import Reference +from oss.src.core.tracing.dtos import Windowing +from oss.src.core.evaluations.service import SimpleEvaluationsService +from oss.src.core.evaluations.tasks import source_slice as source_slice_tasks +from oss.src.core.evaluations.tasks import run as run_tasks + + +def _run(*, steps, flags=None, repeats=1): + return EvaluationRun( + id=uuid4(), + flags=flags or EvaluationRunFlags(), + data=EvaluationRunData(steps=list(steps), repeats=repeats), + ) + + +def _step(key, type_, origin="custom", references=None): + return EvaluationRunDataStep( + key=key, + type=type_, + origin=origin, + references=references or {}, + ) + + +def test_topology_classifier_preserves_current_batch_dispatch_shapes(): + query_eval = _run( + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step( + "evaluator-auto", + "annotation", + origin="auto", + references={"evaluator_revision": Reference(id=uuid4())}, + ), + ] + ) + testset_eval = _run( + steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + _step( + "application-main", + "invocation", + references={"application_revision": Reference(id=uuid4())}, + ), + _step( + "evaluator-auto", + "annotation", + origin="auto", + references={"evaluator_revision": Reference(id=uuid4())}, + ), + ] + ) + batch_inference = _run( + steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + _step( + "application-main", + "invocation", + references={"application_revision": Reference(id=uuid4())}, + ), + ] + ) + + assert classify_run_topology(query_eval).dispatch == "batch_query" + assert classify_run_topology(testset_eval).dispatch == "batch_testset" + assert classify_run_topology(batch_inference).dispatch == "batch_invocation" + + +def test_topology_classifier_names_deferred_shapes(): + query_to_app = _run( + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step( + "application-main", + "invocation", + references={"application_revision": Reference(id=uuid4())}, + ), + ] + ) + testset_to_eval = _run( + steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + _step( + "evaluator-auto", + "annotation", + origin="auto", + references={"evaluator_revision": Reference(id=uuid4())}, + ), + ] + ) + multi_app = _run( + steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + _step("application-a", "invocation"), + _step("application-b", "invocation"), + ] + ) + + assert classify_run_topology(query_to_app).status == "potential" + assert classify_run_topology(testset_to_eval).status == "potential" + assert classify_run_topology(multi_app).status == "not_planned" + + +def test_topology_classifier_names_not_planned_source_shapes(): + mixed_sources = _run( + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + _step("evaluator-auto", "annotation", origin="auto"), + ] + ) + live_testset = _run( + flags=EvaluationRunFlags(is_live=True), + steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + _step("evaluator-auto", "annotation", origin="auto"), + ], + ) + + assert classify_run_topology(mixed_sources).status == "not_planned" + assert classify_run_topology(live_testset).status == "not_planned" + + +def test_planner_creates_repeat_aware_slots_and_keeps_manual_annotations_pending(): + run = _run( + repeats=3, + flags=EvaluationRunFlags(is_split=False), + steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + _step( + "application-main", + "invocation", + references={"application_revision": Reference(id=uuid4())}, + ), + _step( + "evaluator-auto", + "annotation", + origin="auto", + references={"evaluator_revision": Reference(id=uuid4())}, + ), + _step( + "evaluator-human", + "annotation", + origin="human", + references={"evaluator_revision": Reference(id=uuid4())}, + ), + ], + ) + scenario_id = uuid4() + bindings = make_scenario_bindings( + scenario_ids=[scenario_id], + source_items=[ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + ) + ], + ) + + plan = EvaluationPlanner().plan(run=run, bindings=bindings) + cells_by_step = {} + for cell in plan.cells: + cells_by_step.setdefault(cell.step_key, []).append(cell) + + assert [cell.repeat_idx for cell in cells_by_step["testset-main"]] == [0, 1, 2] + assert [cell.repeat_idx for cell in cells_by_step["application-main"]] == [0] + assert [cell.repeat_idx for cell in cells_by_step["evaluator-auto"]] == [0, 1, 2] + assert [cell.status for cell in cells_by_step["evaluator-human"]] == [ + EvaluationStatus.PENDING, + EvaluationStatus.PENDING, + EvaluationStatus.PENDING, + ] + assert {(cell.step_key, cell.repeat_idx) for cell in plan.executable_cells} == { + ("application-main", 0), + ("evaluator-auto", 0), + ("evaluator-auto", 1), + ("evaluator-auto", 2), + } + + +def test_planner_fans_out_application_for_batch_inference_without_evaluators(): + run = _run( + repeats=2, + steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + _step( + "application-main", + "invocation", + references={"application_revision": Reference(id=uuid4())}, + ), + ], + ) + bindings = make_scenario_bindings( + scenario_ids=[uuid4()], + source_items=[ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + ) + ], + ) + + plan = EvaluationPlanner().plan(run=run, bindings=bindings) + + assert [ + cell.repeat_idx for cell in plan.cells if cell.step_key == "application-main" + ] == [0, 1] + + +def test_planner_fans_out_application_when_split_is_enabled(): + run = _run( + repeats=3, + flags=EvaluationRunFlags(is_split=True), + steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + _step( + "application-main", + "invocation", + references={"application_revision": Reference(id=uuid4())}, + ), + _step("evaluator-auto", "annotation", origin="auto"), + ], + ) + plan = EvaluationPlanner().plan( + run=run, + bindings=make_scenario_bindings( + scenario_ids=[uuid4()], + source_items=[ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + ) + ], + ), + ) + + assert [ + cell.repeat_idx for cell in plan.cells if cell.step_key == "application-main" + ] == [0, 1, 2] + + +def test_planned_cells_convert_to_result_create_payloads(): + run = _run( + repeats=1, + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step( + "evaluator-human", + "annotation", + origin="human", + references={"evaluator_revision": Reference(id=uuid4())}, + ), + ], + ) + trace_id = "trace-1" + scenario_id = uuid4() + plan = EvaluationPlanner().plan( + run=run, + bindings=make_scenario_bindings( + scenario_ids=[scenario_id], + source_items=[ + ResolvedSourceItem( + kind="trace", + step_key="query-main", + trace_id=trace_id, + ) + ], + ), + ) + + result_creates = planned_cells_to_result_creates(plan.cells) + + assert [(result.step_key, result.status) for result in result_creates] == [ + ("query-main", EvaluationStatus.SUCCESS), + ("evaluator-human", EvaluationStatus.PENDING), + ] + assert result_creates[0].trace_id == trace_id + assert result_creates[1].trace_id is None + + +def test_plan_source_input_result_creates_filters_to_source_step(): + run = _run( + repeats=2, + steps=[ + _step("query-other", "input"), + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step("evaluator-auto", "annotation", origin="auto"), + ], + ) + scenario_id = uuid4() + + result_creates = plan_source_input_result_creates( + run=run, + scenario_id=scenario_id, + source_item=ResolvedSourceItem( + kind="trace", + step_key="query-main", + trace_id="trace-main", + ), + ) + + assert [result.step_key for result in result_creates] == [ + "query-main", + "query-main", + ] + assert [result.repeat_idx for result in result_creates] == [0, 1] + assert [result.trace_id for result in result_creates] == [ + "trace-main", + "trace-main", + ] + + +@pytest.mark.asyncio +async def test_cache_resolver_skips_lookup_when_disabled_and_fetches_when_enabled(): + project_id = uuid4() + + class DummyTracingService: + async def query_traces(self, *, project_id, query): + return [SimpleNamespace(trace_id="trace-1"), SimpleNamespace(trace_id=None)] + + disabled = await RunnableCacheResolver().resolve( + tracing_service=DummyTracingService(), + project_id=project_id, + enabled=False, + references={"evaluator_revision": Reference(id=uuid4())}, + required_count=2, + ) + enabled = await RunnableCacheResolver().resolve( + tracing_service=DummyTracingService(), + project_id=project_id, + enabled=True, + references={"evaluator_revision": Reference(id=uuid4())}, + required_count=2, + ) + + assert disabled.reusable_traces == [] + assert disabled.missing_count == 2 + assert [trace.trace_id for trace in enabled.reusable_traces] == ["trace-1"] + assert enabled.missing_count == 1 + + +@pytest.mark.asyncio +async def test_cache_resolver_zero_required_count_does_not_query_traces(): + tracing_service = SimpleNamespace(query_traces=AsyncMock()) + + resolution = await RunnableCacheResolver().resolve( + tracing_service=tracing_service, + project_id=uuid4(), + enabled=True, + references={"evaluator_revision": Reference(id=uuid4())}, + required_count=0, + ) + + assert resolution.reusable_traces == [] + assert resolution.missing_count == 0 + tracing_service.query_traces.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_queue_source_resolver_resolves_query_and_testset_batches(): + project_id = uuid4() + query_revision_id = uuid4() + testset_revision_id = uuid4() + testcase_id_1 = uuid4() + testcase_id_2 = uuid4() + run = _run( + steps=[ + _step( + "query-source", + "input", + references={"query_revision": Reference(id=query_revision_id)}, + ), + _step( + "testset-source", + "input", + references={"testset_revision": Reference(id=testset_revision_id)}, + ), + _step("evaluator-human", "annotation", origin="human"), + ], + ) + queries_service = SimpleNamespace( + fetch_query_revision=AsyncMock( + return_value=SimpleNamespace( + data=SimpleNamespace(trace_ids=["trace-1", "trace-2"]) + ) + ) + ) + testsets_service = SimpleNamespace( + fetch_testset_revision=AsyncMock( + return_value=SimpleNamespace( + data=SimpleNamespace(testcase_ids=[testcase_id_1, testcase_id_2]) + ) + ) + ) + + batches = await resolve_queue_source_batches( + project_id=project_id, + run=run, + queries_service=queries_service, + testsets_service=testsets_service, + ) + + assert [batch.kind for batch in batches] == ["traces", "testcases"] + assert batches[0].step_key == "query-source" + assert batches[0].trace_ids == ["trace-1", "trace-2"] + assert batches[1].step_key == "testset-source" + assert batches[1].testcase_ids == [testcase_id_1, testcase_id_2] + queries_service.fetch_query_revision.assert_awaited_once() + testsets_service.fetch_testset_revision.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_queue_source_resolver_skips_empty_sources(): + run = _run( + steps=[ + _step( + "query-source", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step( + "testset-source", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + ], + ) + + batches = await resolve_queue_source_batches( + project_id=uuid4(), + run=run, + queries_service=SimpleNamespace( + fetch_query_revision=AsyncMock( + return_value=SimpleNamespace(data=SimpleNamespace(trace_ids=[])) + ) + ), + testsets_service=SimpleNamespace( + fetch_testset_revision=AsyncMock( + return_value=SimpleNamespace(data=SimpleNamespace(testcase_ids=[])) + ) + ), + ) + + assert batches == [] + + +@pytest.mark.asyncio +async def test_testset_payload_source_resolver_preserves_testcase_payloads(): + project_id = uuid4() + testset_id = uuid4() + testset_variant_id = uuid4() + testset_revision_id = uuid4() + testcase_id = uuid4() + testcase = SimpleNamespace(id=testcase_id, data={"prompt": "hello"}) + testsets_service = SimpleNamespace( + fetch_testset_revision=AsyncMock( + return_value=SimpleNamespace( + id=testset_revision_id, + variant_id=testset_variant_id, + data=SimpleNamespace(testcases=[testcase]), + ) + ), + fetch_testset_variant=AsyncMock( + return_value=SimpleNamespace( + id=testset_variant_id, + testset_id=testset_id, + ) + ), + fetch_testset=AsyncMock( + return_value=SimpleNamespace( + id=testset_id, + slug="testset-main", + ) + ), + ) + + specs = await resolve_testset_input_specs( + project_id=project_id, + input_steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=testset_revision_id)}, + ) + ], + testsets_service=testsets_service, + ) + + assert len(specs) == 1 + assert specs[0].step_key == "testset-main" + assert specs[0].testcases == [testcase] + assert specs[0].testcases_data == [ + {"prompt": "hello", "testcase_id": str(testcase_id)} + ] + + +@pytest.mark.asyncio +async def test_direct_source_resolver_preserves_order_and_missing_testcases(): + project_id = uuid4() + testcase_id_1 = uuid4() + testcase_id_2 = uuid4() + testcase = SimpleNamespace(id=testcase_id_1, data={"input": "a"}) + testcases_service = SimpleNamespace( + fetch_testcases=AsyncMock(return_value=[testcase]) + ) + + source_items = await resolve_direct_source_items( + project_id=project_id, + testcase_ids=[testcase_id_1, testcase_id_2], + trace_ids=["trace-1"], + testcases_service=testcases_service, + ) + + assert [source_item.kind for source_item in source_items] == [ + "testcase", + "testcase", + "trace", + ] + assert source_items[0].testcase == testcase + assert source_items[1].testcase is None + assert source_items[2].trace_id == "trace-1" + + +@pytest.mark.asyncio +async def test_direct_source_resolver_loads_trace_context(): + project_id = uuid4() + trace_id = "trace-1" + span_id = "span-1" + trace_payload = { + "trace_id": trace_id, + "spans": { + span_id: { + "trace_id": trace_id, + "span_id": span_id, + "attributes": { + "ag": { + "data": { + "inputs": {"prompt": "hello"}, + "outputs": {"answer": "world"}, + } + } + }, + } + }, + } + trace = SimpleNamespace( + trace_id=trace_id, + spans={ + span_id: SimpleNamespace( + trace_id=trace_id, + span_id=span_id, + attributes=trace_payload["spans"][span_id]["attributes"], + ) + }, + model_dump=lambda **_: trace_payload, + ) + tracing_service = SimpleNamespace(fetch_trace=AsyncMock(return_value=trace)) + + source_items = await resolve_direct_source_items( + project_id=project_id, + trace_ids=[trace_id], + tracing_service=tracing_service, + ) + + assert len(source_items) == 1 + assert source_items[0].kind == "trace" + assert source_items[0].trace_id == trace_id + assert source_items[0].span_id == span_id + assert source_items[0].trace is not None + assert source_items[0].inputs == {"prompt": "hello"} + assert source_items[0].outputs == {"answer": "world"} + + +@pytest.mark.asyncio +async def test_live_query_trace_resolver_applies_default_windowing(): + project_id = uuid4() + traces = [SimpleNamespace(trace_id="trace-1")] + + class DummyTracingService: + def __init__(self): + self.query = None + + async def query_traces(self, *, project_id, query): + self.query = query + return traces + + tracing_service = DummyTracingService() + + resolved = await resolve_live_query_traces( + project_id=project_id, + query_revisions={ + "query-main": SimpleNamespace(data=SimpleNamespace()), + }, + tracing_service=tracing_service, + ) + + assert resolved == {"query-main": traces} + assert tracing_service.query.windowing.order == "ascending" + assert tracing_service.query.windowing.limit is None + + +@pytest.mark.asyncio +async def test_live_query_trace_resolver_uses_revision_windowing_when_requested(): + class DummyTracingService: + def __init__(self): + self.query = None + + async def query_traces(self, *, project_id, query): + self.query = query + return [] + + tracing_service = DummyTracingService() + revision_windowing = Windowing( + oldest=None, + newest=None, + limit=25, + order="descending", + rate=0.5, + ) + + await resolve_live_query_traces( + project_id=uuid4(), + query_revisions={ + "query-main": SimpleNamespace( + data=SimpleNamespace(filtering=None, windowing=revision_windowing) + ), + }, + tracing_service=tracing_service, + use_windowing=True, + ) + + assert tracing_service.query.windowing.limit == 25 + assert tracing_service.query.windowing.order == "descending" + assert tracing_service.query.windowing.rate == 0.5 + + +@pytest.mark.asyncio +async def test_tensor_slice_operations_probe_populate_prune_and_process(): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + scenario_id = uuid4() + result_id = uuid4() + result = EvaluationResult( + id=result_id, + run_id=run_id, + scenario_id=scenario_id, + step_key="evaluator-auto", + repeat_idx=0, + status=EvaluationStatus.SUCCESS, + ) + evaluations_service = SimpleNamespace( + query_results=AsyncMock(return_value=[result]), + create_results=AsyncMock(return_value=[result]), + delete_results=AsyncMock(return_value=[result_id]), + refresh_metrics=AsyncMock(return_value=[]), + ) + operations = TensorSliceOperations(evaluations_service=evaluations_service) + tensor_slice = TensorSlice( + run_id=run_id, + scenario_ids=[scenario_id], + step_keys=["evaluator-auto"], + repeat_idxs=[0], + ) + + probed = await operations.probe( + project_id=project_id, + tensor_slice=tensor_slice, + ) + populated = await operations.populate( + project_id=project_id, + user_id=user_id, + results=[ + EvaluationResultCreate( + run_id=run_id, + scenario_id=scenario_id, + step_key="evaluator-auto", + repeat_idx=0, + status=EvaluationStatus.SUCCESS, + ) + ], + ) + pruned = await operations.prune( + project_id=project_id, + user_id=user_id, + tensor_slice=tensor_slice, + ) + summary = await operations.process( + project_id=project_id, + user_id=user_id, + tensor_slice=tensor_slice, + ) + + assert probed == [result] + assert populated == [result] + assert pruned == [result_id] + assert summary == ProcessSummary() + assert evaluations_service.query_results.await_count == 2 + assert evaluations_service.create_results.await_count == 1 + assert evaluations_service.delete_results.await_count == 1 + assert evaluations_service.refresh_metrics.await_count == 3 + + +@pytest.mark.asyncio +async def test_tensor_slice_probe_summary_counts_statuses_and_missing_cells(): + project_id = uuid4() + run_id = uuid4() + scenario_id = uuid4() + evaluations_service = SimpleNamespace( + query_results=AsyncMock( + return_value=[ + EvaluationResult( + id=uuid4(), + run_id=run_id, + scenario_id=scenario_id, + step_key="step-success", + repeat_idx=0, + status=EvaluationStatus.SUCCESS, + ), + EvaluationResult( + id=uuid4(), + run_id=run_id, + scenario_id=scenario_id, + step_key="step-failure", + repeat_idx=0, + status=EvaluationStatus.FAILURE, + ), + EvaluationResult( + id=uuid4(), + run_id=run_id, + scenario_id=scenario_id, + step_key="step-pending", + repeat_idx=0, + status=EvaluationStatus.PENDING, + ), + ] + ) + ) + + summary = await TensorSliceOperations( + evaluations_service=evaluations_service + ).probe_summary( + project_id=project_id, + tensor_slice=TensorSlice(run_id=run_id), + expected_count=5, + ) + + assert summary == TensorProbeSummary( + existing_count=3, + missing_count=2, + success_count=1, + failure_count=1, + pending_count=1, + any_count=3, + ) + + +@pytest.mark.asyncio +async def test_tensor_slice_empty_dimension_short_circuits_probe_and_process(): + project_id = uuid4() + user_id = uuid4() + operations = TensorSliceOperations( + evaluations_service=SimpleNamespace( + query_results=AsyncMock(), + refresh_metrics=AsyncMock(), + ) + ) + tensor_slice = TensorSlice(run_id=uuid4(), scenario_ids=[]) + + assert ( + await operations.probe( + project_id=project_id, + tensor_slice=tensor_slice, + ) + == [] + ) + assert ( + await operations.process( + project_id=project_id, + user_id=user_id, + tensor_slice=tensor_slice, + ) + == ProcessSummary() + ) + operations.evaluations_service.query_results.assert_not_awaited() + operations.evaluations_service.refresh_metrics.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_workflow_runnable_executor_normalizes_success_and_failure(): + success_response = SimpleNamespace( + status=SimpleNamespace(code=200), + trace_id="trace-success", + outputs={"score": 1}, + ) + failure_status = SimpleNamespace( + code=500, + model_dump=lambda **kwargs: {"code": 500, "message": "failed"}, + ) + failure_response = SimpleNamespace( + status=failure_status, + trace_id="trace-failure", + outputs=None, + ) + workflows_service = SimpleNamespace( + invoke_workflow=AsyncMock(side_effect=[success_response, failure_response]) + ) + executor = WorkflowRunnableStepExecutor(workflows_service=workflows_service) + + success = await executor.execute(project_id=uuid4(), user_id=uuid4(), request={}) + failure = await executor.execute(project_id=uuid4(), user_id=uuid4(), request={}) + + assert success.status == EvaluationStatus.SUCCESS + assert success.trace_id == "trace-success" + assert success.error is None + assert failure.status == EvaluationStatus.FAILURE + assert failure.error == {"code": 500, "message": "failed"} + assert workflows_service.invoke_workflow.await_count == 2 + + +@pytest.mark.asyncio +async def test_backend_workflow_service_runner_adapts_sdk_runtime_request(): + workflows_service = SimpleNamespace( + invoke_workflow=AsyncMock( + return_value=SimpleNamespace( + status=SimpleNamespace(code=200), + trace_id="trace-success", + span_id="span-success", + outputs={"score": 1}, + ) + ) + ) + runner = BackendWorkflowServiceRunner( + workflows_service=workflows_service, + request_builder=lambda request: { + "project_id": "project", + "step_key": request.step.key, + }, + ) + request = WorkflowExecutionRequest( + step=SdkEvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + cell=SdkPlannedCell( + run_id=uuid4(), + scenario_id=uuid4(), + step_key="evaluator-auto", + step_type="annotation", + origin="auto", + repeat_idx=0, + status=SdkEvaluationStatus.QUEUED, + ), + source=SdkResolvedSourceItem(kind="trace", step_key="query-main"), + revision={"slug": "evaluator-auto"}, + ) + + result = await runner.execute(request) + + assert result.status == SdkEvaluationStatus.SUCCESS + assert result.trace_id == "trace-success" + assert result.span_id == "span-success" + workflows_service.invoke_workflow.assert_awaited_once_with( + project_id="project", + step_key="evaluator-auto", + ) + + +@pytest.mark.asyncio +async def test_taskiq_evaluation_task_runner_omits_empty_optional_kwargs(): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + worker = SimpleNamespace( + process_run=SimpleNamespace(kiq=AsyncMock(return_value="run-task")), + process_slice=SimpleNamespace(kiq=AsyncMock(return_value="slice-task")), + ) + runner = TaskiqEvaluationTaskRunner(worker=worker) + + assert ( + await runner.process_run( + project_id=project_id, + user_id=user_id, + run_id=run_id, + ) + == "run-task" + ) + assert ( + await runner.process_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + source_kind="traces", + trace_ids=["trace-1"], + ) + == "slice-task" + ) + + worker.process_run.kiq.assert_awaited_once_with( + project_id=project_id, + user_id=user_id, + run_id=run_id, + ) + worker.process_slice.kiq.assert_awaited_once_with( + project_id=project_id, + user_id=user_id, + run_id=run_id, + source_kind="traces", + trace_ids=["trace-1"], + ) + + +@pytest.mark.asyncio +async def test_backend_workflow_runner_invokes_application_through_workflow_service(): + project_id = uuid4() + user_id = uuid4() + application_revision_id = uuid4() + workflows_service = SimpleNamespace( + invoke_workflow=AsyncMock( + return_value=SimpleNamespace( + status=SimpleNamespace(code=200), + trace_id="app-trace", + span_id="app-span", + outputs={"answer": "world"}, + ) + ) + ) + runner = BackendWorkflowRunner( + project_id=project_id, + user_id=user_id, + workflows_service=workflows_service, + ) + revision = { + "id": str(application_revision_id), + "data": { + "uri": "http://application", + "schemas": { + "inputs": { + "type": "object", + "properties": {"input": {"type": "string"}}, + } + }, + "parameters": {"temperature": 0.1}, + }, + "flags": {"is_chat": True}, + } + request = WorkflowExecutionRequest( + step=SdkEvaluationStep(key="application-main", type="invocation"), + cell=SdkPlannedCell( + run_id=uuid4(), + scenario_id=uuid4(), + step_key="application-main", + step_type="invocation", + origin="custom", + repeat_idx=0, + status=SdkEvaluationStatus.QUEUED, + ), + source=SdkResolvedSourceItem( + kind="testcase", + step_key="testset-main", + inputs={ + "input": "hello", + "correct_answer": "world", + "testcase_id": "testcase-id", + "testcase_dedup_id": "dedup-id", + }, + ), + revision=revision, + references={"application_revision": {"id": str(application_revision_id)}}, + ) + + result = await runner.execute(request) + + assert result.status == SdkEvaluationStatus.SUCCESS + assert result.trace_id == "app-trace" + workflows_service.invoke_workflow.assert_awaited_once() + kwargs = workflows_service.invoke_workflow.await_args.kwargs + assert kwargs["project_id"] == project_id + assert kwargs["user_id"] == user_id + assert "annotate" not in kwargs + workflow_request = kwargs["request"] + assert workflow_request.flags == {"is_chat": True} + assert workflow_request.data.interface["uri"] == "http://application" + assert workflow_request.data.interface["schemas"] == { + "inputs": { + "type": "object", + "properties": {"input": {"type": "string"}}, + } + } + assert workflow_request.data.configuration["parameters"] == {"temperature": 0.1} + assert workflow_request.data.revision == revision + assert workflow_request.data.parameters == {"temperature": 0.1} + assert workflow_request.data.inputs == {"input": "hello"} + assert ( + workflow_request.references["application_revision"].id + == application_revision_id + ) + + +@pytest.mark.asyncio +async def test_backend_evaluator_runner_sends_normalized_workflow_request(): + project_id = uuid4() + user_id = uuid4() + workflow_revision_id = uuid4() + workflows_service = SimpleNamespace( + invoke_workflow=AsyncMock( + return_value=SimpleNamespace( + status=SimpleNamespace(code=200), + trace_id="eval-trace", + span_id="eval-span", + outputs={"score": 1}, + ) + ) + ) + runner = BackendEvaluatorRunner( + project_id=project_id, + user_id=user_id, + workflows_service=workflows_service, + ) + revision = SimpleNamespace( + id=workflow_revision_id, + data=SimpleNamespace( + uri="http://evaluator", + url=None, + headers={"authorization": "secret"}, + schemas={"outputs": {"type": "object"}}, + script="return score", + parameters={"threshold": 0.5}, + ), + flags=SimpleNamespace(model_dump=lambda **kwargs: {"is_custom": True}), + model_dump=lambda **kwargs: {"id": str(workflow_revision_id)}, + ) + request = WorkflowExecutionRequest( + step=SdkEvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + cell=SdkPlannedCell( + run_id=uuid4(), + scenario_id=uuid4(), + step_key="evaluator-auto", + step_type="annotation", + origin="auto", + repeat_idx=0, + status=SdkEvaluationStatus.QUEUED, + ), + source=SdkResolvedSourceItem( + kind="testcase", + step_key="testset-main", + inputs={"input": "hello"}, + outputs={"answer": "world"}, + ), + revision=revision, + references={"evaluator_revision": {"id": str(workflow_revision_id)}}, + links={"invocation": {"trace_id": "app-trace", "span_id": "app-span"}}, + upstream_trace={"trace_id": "app-trace"}, + upstream_outputs={"answer": "world"}, + ) + + result = await runner.execute(request) + + assert result.status == SdkEvaluationStatus.SUCCESS + assert result.trace_id == "eval-trace" + workflows_service.invoke_workflow.assert_awaited_once() + kwargs = workflows_service.invoke_workflow.await_args.kwargs + assert kwargs["project_id"] == project_id + assert kwargs["user_id"] == user_id + assert "annotate" not in kwargs + workflow_request = kwargs["request"] + assert workflow_request.flags == {"is_custom": True} + assert workflow_request.data.revision == {"id": str(workflow_revision_id)} + assert workflow_request.data.parameters == {"threshold": 0.5} + assert workflow_request.data.inputs == {"input": "hello"} + assert workflow_request.data.outputs == {"answer": "world"} + assert workflow_request.links["invocation"].trace_id == "app-trace" + assert workflow_request.links["invocation"].span_id == "app-span" + + +@pytest.mark.asyncio +async def test_backend_evaluator_runner_preserves_dict_revision_data(): + project_id = uuid4() + user_id = uuid4() + workflow_revision_id = uuid4() + workflows_service = SimpleNamespace( + invoke_workflow=AsyncMock( + return_value=SimpleNamespace( + status=SimpleNamespace(code=200), + trace_id="eval-trace", + span_id="eval-span", + outputs={"score": 1}, + ) + ) + ) + runner = BackendEvaluatorRunner( + project_id=project_id, + user_id=user_id, + workflows_service=workflows_service, + ) + revision = { + "id": str(workflow_revision_id), + "data": { + "uri": "http://evaluator", + "url": None, + "headers": {"authorization": "secret"}, + "schemas": {"outputs": {"type": "object"}}, + "script": "return score", + "parameters": {"threshold": 0.5}, + }, + "flags": {"is_custom": True}, + } + request = WorkflowExecutionRequest( + step=SdkEvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + cell=SdkPlannedCell( + run_id=uuid4(), + scenario_id=uuid4(), + step_key="evaluator-auto", + step_type="annotation", + origin="auto", + repeat_idx=0, + status=SdkEvaluationStatus.QUEUED, + ), + source=SdkResolvedSourceItem( + kind="testcase", + step_key="testset-main", + inputs={"input": "hello"}, + ), + revision=revision, + ) + + result = await runner.execute(request) + + assert result.status == SdkEvaluationStatus.SUCCESS + workflows_service.invoke_workflow.assert_awaited_once() + workflow_request = workflows_service.invoke_workflow.await_args.kwargs["request"] + assert workflow_request.flags == {"is_custom": True} + assert workflow_request.data.revision["data"]["uri"] == "http://evaluator" + assert workflow_request.data.revision["data"]["headers"] == { + "authorization": "secret" + } + assert workflow_request.data.revision["data"]["schemas"] == { + "outputs": {"type": "object"} + } + assert workflow_request.data.revision["data"]["script"] == "return score" + assert workflow_request.data.revision["data"]["parameters"] == {"threshold": 0.5} + assert workflow_request.data.parameters == {"threshold": 0.5} + + +@pytest.mark.asyncio +async def test_backend_cached_runner_preserves_partial_hit_order(): + project_id = uuid4() + cached_trace = SimpleNamespace(trace_id="cached-trace") + tracing_service = SimpleNamespace( + query_traces=AsyncMock(side_effect=[[cached_trace], []]) + ) + + class BatchRunner: + def __init__(self): + self.requests = [] + + async def execute_batch(self, requests): + self.requests.append(requests) + return [ + WorkflowExecutionResult( + status=SdkEvaluationStatus.SUCCESS, + trace_id="fresh-trace", + ) + ] + + batch_runner = BatchRunner() + runner = BackendCachedRunner( + runner=batch_runner, + tracing_service=tracing_service, + project_id=project_id, + enabled=True, + ) + requests = [ + WorkflowExecutionRequest( + step=SdkEvaluationStep(key="evaluator-auto", type="annotation"), + cell=SdkPlannedCell( + run_id=uuid4(), + scenario_id=uuid4(), + step_key="evaluator-auto", + step_type="annotation", + origin="auto", + repeat_idx=idx, + status=SdkEvaluationStatus.QUEUED, + ), + source=SdkResolvedSourceItem(kind="trace", step_key="query-main"), + revision={"id": "evaluator-revision"}, + references={"evaluator_revision": {"id": f"revision-{idx}"}}, + ) + for idx in range(2) + ] + + results = await runner.execute_batch(requests) + + assert [result.trace_id for result in results] == ["cached-trace", "fresh-trace"] + assert len(batch_runner.requests) == 1 + assert [request.cell.repeat_idx for request in batch_runner.requests[0]] == [1] + + +@pytest.mark.asyncio +async def test_application_batch_runnable_executor_delegates_batch_invocation(): + batch_invoke = AsyncMock(return_value=["invocation-1", "invocation-2"]) + executor = ApplicationBatchRunnableStepExecutor(batch_invoke=batch_invoke) + + invocations = await executor.execute_batch( + project_id="project", + user_id="user", + testset_data=[{"input": 1}], + ) + + assert invocations == ["invocation-1", "invocation-2"] + batch_invoke.assert_awaited_once_with( + project_id="project", + user_id="user", + testset_data=[{"input": 1}], + ) + + +@pytest.mark.asyncio +async def test_simple_evaluation_start_dispatches_batch_invocation_by_topology(): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + run = _run( + steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + _step( + "application-main", + "invocation", + references={"application_revision": Reference(id=uuid4())}, + ), + ], + ) + run.id = run_id + worker = SimpleNamespace( + process_run=SimpleNamespace(kiq=AsyncMock()), + ) + service = SimpleEvaluationsService( + testsets_service=None, # type: ignore[arg-type] + queries_service=None, # type: ignore[arg-type] + applications_service=None, # type: ignore[arg-type] + evaluators_service=None, # type: ignore[arg-type] + evaluations_service=None, # type: ignore[arg-type] + evaluations_worker=worker, + ) + service.fetch = AsyncMock( + return_value=SimpleEvaluation( + id=run_id, + flags=SimpleEvaluationFlags(is_live=False), + data=SimpleEvaluationData( + status=None, + testset_steps={uuid4(): "custom"}, + application_steps={uuid4(): "custom"}, + ), + ) + ) + service._activate_evaluation_run = AsyncMock(return_value=run) + + await service.start( + project_id=project_id, + user_id=user_id, + evaluation_id=run_id, + ) + + worker.process_run.kiq.assert_awaited_once_with( + project_id=project_id, + user_id=user_id, + run_id=run_id, + ) + + +@pytest.mark.asyncio +async def test_simple_evaluation_start_does_not_dispatch_potential_topology(): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + run = _run( + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step( + "application-main", + "invocation", + references={"application_revision": Reference(id=uuid4())}, + ), + ], + ) + run.id = run_id + worker = SimpleNamespace( + process_run=SimpleNamespace(kiq=AsyncMock()), + ) + service = SimpleEvaluationsService( + testsets_service=None, # type: ignore[arg-type] + queries_service=None, # type: ignore[arg-type] + applications_service=None, # type: ignore[arg-type] + evaluators_service=None, # type: ignore[arg-type] + evaluations_service=None, # type: ignore[arg-type] + evaluations_worker=worker, + ) + service.fetch = AsyncMock( + return_value=SimpleEvaluation( + id=run_id, + flags=SimpleEvaluationFlags(is_live=False), + data=SimpleEvaluationData( + status=None, + query_steps={uuid4(): "custom"}, + application_steps={uuid4(): "custom"}, + ), + ) + ) + service._activate_evaluation_run = AsyncMock(return_value=run) + + await service.start( + project_id=project_id, + user_id=user_id, + evaluation_id=run_id, + ) + + worker.process_run.kiq.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_simple_evaluation_queue_batches_dispatch_through_slice_processor(): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + testcase_id = uuid4() + run = _run( + flags=EvaluationRunFlags(is_queue=True), + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step("evaluator-human", "annotation", origin="human"), + ], + ) + run.id = run_id + worker = SimpleNamespace( + process_slice=SimpleNamespace(kiq=AsyncMock()), + ) + evaluations_service = SimpleNamespace(fetch_run=AsyncMock(return_value=run)) + service = SimpleEvaluationsService( + testsets_service=None, # type: ignore[arg-type] + queries_service=None, # type: ignore[arg-type] + applications_service=None, # type: ignore[arg-type] + evaluators_service=None, # type: ignore[arg-type] + evaluations_service=evaluations_service, # type: ignore[arg-type] + evaluations_worker=worker, + ) + service._ensure_human_annotation_queue = AsyncMock() + + traces_ok = await service.dispatch_trace_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + trace_ids=["trace-1"], + input_step_key="query-main", + ) + testcases_ok = await service.dispatch_testcase_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + testcase_ids=[testcase_id], + input_step_key="testset-main", + ) + + assert traces_ok is True + assert testcases_ok is True + assert worker.process_slice.kiq.await_args_list == [ + call( + project_id=project_id, + user_id=user_id, + run_id=run_id, + source_kind="traces", + trace_ids=["trace-1"], + input_step_key="query-main", + ), + call( + project_id=project_id, + user_id=user_id, + run_id=run_id, + source_kind="testcases", + testcase_ids=[testcase_id], + input_step_key="testset-main", + ), + ] + + +@pytest.mark.asyncio +async def test_slice_processor_calls_source_slice_loop_directly(monkeypatch): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + testcase_id = uuid4() + tracing_service = object() + testcases_service = object() + workflows_service = object() + evaluations_service = object() + process_source_slice = AsyncMock() + monkeypatch.setattr( + run_tasks, + "process_evaluation_source_slice", + process_source_slice, + ) + + traces_ok = await run_tasks.process_evaluation_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + source_kind="traces", + trace_ids=["trace-1"], + input_step_key="query-main", + tracing_service=tracing_service, # type: ignore[arg-type] + testcases_service=testcases_service, # type: ignore[arg-type] + workflows_service=workflows_service, # type: ignore[arg-type] + evaluations_service=evaluations_service, # type: ignore[arg-type] + ) + testcases_ok = await run_tasks.process_evaluation_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + source_kind="testcases", + testcase_ids=[testcase_id], + input_step_key="testset-main", + tracing_service=tracing_service, # type: ignore[arg-type] + testcases_service=testcases_service, # type: ignore[arg-type] + workflows_service=workflows_service, # type: ignore[arg-type] + evaluations_service=evaluations_service, # type: ignore[arg-type] + ) + + assert traces_ok is True + assert testcases_ok is True + assert process_source_slice.await_args_list == [ + call( + project_id=project_id, + user_id=user_id, + run_id=run_id, + trace_ids=["trace-1"], + input_step_key="query-main", + tracing_service=tracing_service, + workflows_service=workflows_service, + evaluations_service=evaluations_service, + ), + call( + project_id=project_id, + user_id=user_id, + run_id=run_id, + testcase_ids=[testcase_id], + input_step_key="testset-main", + tracing_service=tracing_service, + testcases_service=testcases_service, + workflows_service=workflows_service, + evaluations_service=evaluations_service, + ), + ] + + +@pytest.mark.asyncio +async def test_run_processor_routes_batch_inference_through_testset_application_loop( + monkeypatch, +): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + run = _run( + steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + _step( + "application-main", + "invocation", + references={"application_revision": Reference(id=uuid4())}, + ), + ], + ) + run.id = run_id + process_testset_source_run = AsyncMock() + monkeypatch.setattr( + run_tasks, + "process_testset_source_run", + process_testset_source_run, + ) + + processed = await run_tasks.process_evaluation_run( + project_id=project_id, + user_id=user_id, + run_id=run_id, + tracing_service=object(), # type: ignore[arg-type] + testsets_service=object(), # type: ignore[arg-type] + queries_service=object(), # type: ignore[arg-type] + workflows_service=object(), # type: ignore[arg-type] + applications_service=object(), # type: ignore[arg-type] + evaluations_service=SimpleNamespace(fetch_run=AsyncMock(return_value=run)), + simple_evaluators_service=object(), # type: ignore[arg-type] + ) + + assert processed is True + process_testset_source_run.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_testset_source_run_resolves_rows_and_uses_source_slice_processor( + monkeypatch, +): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + testcase_id = uuid4() + testset_id = uuid4() + testset_variant_id = uuid4() + testset_revision_id = uuid4() + run = _run( + steps=[ + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=testset_revision_id)}, + ), + _step( + "application-main", + "invocation", + references={"application_revision": Reference(id=uuid4())}, + ), + _step( + "evaluator-auto", + "annotation", + origin="auto", + references={"evaluator_revision": Reference(id=uuid4())}, + ), + ], + ) + run.id = run_id + testcase = SimpleNamespace(id=testcase_id, data={"prompt": "hello"}) + resolved_specs = [ + { + "step_key": "testset-main", + "testset": SimpleNamespace(id=testset_id), + "testset_revision": SimpleNamespace( + id=testset_revision_id, + variant_id=testset_variant_id, + ), + "testcases": [testcase], + "testcases_data": [{"prompt": "hello"}], + } + ] + process_source_slice = AsyncMock() + monkeypatch.setattr( + source_slice_tasks, + "_resolve_testset_input_specs", + AsyncMock(return_value=resolved_specs), + ) + monkeypatch.setattr( + source_slice_tasks, + "process_evaluation_source_slice", + process_source_slice, + ) + + await source_slice_tasks.process_testset_source_run( + project_id=project_id, + user_id=user_id, + run_id=run_id, + tracing_service=object(), # type: ignore[arg-type] + testsets_service=object(), # type: ignore[arg-type] + workflows_service=object(), # type: ignore[arg-type] + applications_service=object(), # type: ignore[arg-type] + evaluations_service=SimpleNamespace(fetch_run=AsyncMock(return_value=run)), + ) + + process_source_slice.assert_awaited_once() + kwargs = process_source_slice.await_args.kwargs + assert kwargs["project_id"] == project_id + assert kwargs["user_id"] == user_id + assert kwargs["run_id"] == run_id + assert kwargs["require_queue"] is False + source_item = kwargs["source_items"][0] + assert source_item.kind == "testcase" + assert source_item.step_key == "testset-main" + assert source_item.testcase_id == testcase_id + assert source_item.testcase is testcase + assert source_item.inputs == {"prompt": "hello"} + assert source_item.references == { + "testcase": {"id": str(testcase_id)}, + "testset": {"id": str(testset_id)}, + "testset_variant": {"id": str(testset_variant_id)}, + "testset_revision": {"id": str(testset_revision_id)}, + } + + +@pytest.mark.asyncio +async def test_source_slice_processor_maps_scenario_and_run_statuses(monkeypatch): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + scenario_success = SimpleNamespace(id=uuid4(), tags=None, meta=None) + scenario_pending = SimpleNamespace(id=uuid4(), tags=None, meta=None) + scenario_errors = SimpleNamespace(id=uuid4(), tags=None, meta=None) + run = _run( + flags=EvaluationRunFlags(is_queue=True), + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step("evaluator-human", "annotation", origin="human"), + ], + ) + run.id = run_id + evaluations_service = SimpleNamespace( + fetch_run=AsyncMock(return_value=run), + edit_scenario=AsyncMock(), + edit_run=AsyncMock(), + ) + monkeypatch.setattr( + source_slice_tasks, + "sdk_process_evaluation_source_slice", + AsyncMock( + return_value=[ + SdkProcessedScenario(scenario=scenario_success), + SdkProcessedScenario( + scenario=scenario_pending, + has_pending=True, + ), + SdkProcessedScenario( + scenario=scenario_errors, + has_errors=True, + ), + ] + ), + ) + + await source_slice_tasks.process_evaluation_source_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + source_items=[ + ResolvedSourceItem( + kind="trace", + step_key="query-main", + trace_id="trace-1", + ) + ], + tracing_service=SimpleNamespace(), + workflows_service=SimpleNamespace(), + evaluations_service=evaluations_service, + ) + + assert [ + call.kwargs["scenario"].status + for call in evaluations_service.edit_scenario.await_args_list + ] == [ + EvaluationStatus.SUCCESS, + EvaluationStatus.PENDING, + EvaluationStatus.ERRORS, + ] + assert evaluations_service.edit_run.await_args.kwargs["run"].status == ( + EvaluationStatus.ERRORS + ) + + +@pytest.mark.asyncio +async def test_source_slice_processor_hydrates_direct_trace_batches(monkeypatch): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + trace_id = "trace-1" + span_id = "span-1" + scenario = SimpleNamespace(id=uuid4(), tags=None, meta=None) + run = _run( + flags=EvaluationRunFlags(is_queue=True), + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step("evaluator-human", "annotation", origin="human"), + ], + ) + run.id = run_id + trace_payload = { + "trace_id": trace_id, + "spans": { + span_id: { + "trace_id": trace_id, + "span_id": span_id, + "attributes": { + "ag": { + "data": { + "inputs": {"prompt": "hello"}, + "outputs": {"answer": "world"}, + } + } + }, + } + }, + } + trace = SimpleNamespace( + trace_id=trace_id, + spans={ + span_id: SimpleNamespace( + trace_id=trace_id, + span_id=span_id, + attributes=trace_payload["spans"][span_id]["attributes"], + ) + }, + model_dump=lambda **_: trace_payload, + ) + evaluations_service = SimpleNamespace( + fetch_run=AsyncMock(side_effect=[run, run]), + edit_scenario=AsyncMock(), + edit_run=AsyncMock(), + ) + sdk_process = AsyncMock(return_value=[SdkProcessedScenario(scenario=scenario)]) + monkeypatch.setattr( + source_slice_tasks, + "sdk_process_evaluation_source_slice", + sdk_process, + ) + + await source_slice_tasks.process_evaluation_source_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + trace_ids=[trace_id], + tracing_service=SimpleNamespace(fetch_trace=AsyncMock(return_value=trace)), + workflows_service=SimpleNamespace(), + evaluations_service=evaluations_service, + ) + + sdk_source_item = sdk_process.await_args.kwargs["source_items"][0] + assert sdk_source_item.trace_id == trace_id + assert sdk_source_item.span_id == span_id + assert sdk_source_item.trace is not None + assert sdk_source_item.inputs == {"prompt": "hello"} + assert sdk_source_item.outputs == {"answer": "world"} + + +@pytest.mark.asyncio +async def test_source_slice_processor_preserves_higher_queue_status(monkeypatch): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + scenario = SimpleNamespace(id=uuid4(), tags=None, meta=None) + run = _run( + flags=EvaluationRunFlags(is_queue=True), + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step("evaluator-human", "annotation", origin="human"), + ], + ) + run.id = run_id + current_run = run.model_copy(update={"status": EvaluationStatus.ERRORS}) + evaluations_service = SimpleNamespace( + fetch_run=AsyncMock(side_effect=[run, current_run]), + edit_scenario=AsyncMock(), + edit_run=AsyncMock(), + ) + monkeypatch.setattr( + source_slice_tasks, + "sdk_process_evaluation_source_slice", + AsyncMock(return_value=[SdkProcessedScenario(scenario=scenario)]), + ) + + await source_slice_tasks.process_evaluation_source_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + source_items=[ + ResolvedSourceItem( + kind="trace", + step_key="query-main", + trace_id="trace-1", + ) + ], + tracing_service=SimpleNamespace(), + workflows_service=SimpleNamespace(), + evaluations_service=evaluations_service, + ) + + assert evaluations_service.edit_run.await_args.kwargs["run"].status == ( + EvaluationStatus.ERRORS + ) + + +@pytest.mark.asyncio +async def test_source_slice_processor_marks_run_failure_on_invalid_batch(): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + run = _run( + flags=EvaluationRunFlags(is_queue=True), + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step("evaluator-human", "annotation", origin="human"), + ], + ) + run.id = run_id + evaluations_service = SimpleNamespace( + fetch_run=AsyncMock(return_value=run), + edit_run=AsyncMock(), + ) + + await source_slice_tasks.process_evaluation_source_slice( + project_id=project_id, + user_id=user_id, + run_id=run_id, + tracing_service=SimpleNamespace(), + workflows_service=SimpleNamespace(), + evaluations_service=evaluations_service, + ) + + assert evaluations_service.edit_run.await_args.kwargs["run"].status == ( + EvaluationStatus.FAILURE + ) + + +@pytest.mark.asyncio +async def test_run_processor_routes_query_topologies_with_windowing(monkeypatch): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + query_revision_id = uuid4() + evaluator_revision_id = uuid4() + newest = object() + oldest = object() + process_query_source_run = AsyncMock() + monkeypatch.setattr( + run_tasks, + "process_query_source_run", + process_query_source_run, + ) + live_run = _run( + flags=EvaluationRunFlags(is_live=True), + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=query_revision_id)}, + ), + _step( + "evaluator-auto", + "annotation", + origin="auto", + references={"evaluator_revision": Reference(id=evaluator_revision_id)}, + ), + ], + ) + live_run.id = run_id + batch_run = live_run.model_copy(update={"flags": EvaluationRunFlags(is_live=False)}) + + for run, expected_use_windowing in [(live_run, False), (batch_run, True)]: + process_query_source_run.reset_mock() + + processed = await run_tasks.process_evaluation_run( + project_id=project_id, + user_id=user_id, + run_id=run_id, + newest=newest, # type: ignore[arg-type] + oldest=oldest, # type: ignore[arg-type] + tracing_service=object(), # type: ignore[arg-type] + testsets_service=object(), # type: ignore[arg-type] + queries_service=object(), # type: ignore[arg-type] + workflows_service=object(), # type: ignore[arg-type] + applications_service=object(), # type: ignore[arg-type] + evaluations_service=SimpleNamespace(fetch_run=AsyncMock(return_value=run)), + simple_evaluators_service=object(), # type: ignore[arg-type] + ) + + assert processed is True + kwargs = process_query_source_run.await_args.kwargs + assert kwargs["use_windowing"] is expected_use_windowing + if expected_use_windowing: + assert kwargs["newest"] is None + assert kwargs["oldest"] is None + else: + assert kwargs["newest"] is newest + assert kwargs["oldest"] is oldest + + +@pytest.mark.asyncio +async def test_run_processor_returns_false_for_missing_or_unsupported_run(): + project_id = uuid4() + user_id = uuid4() + run_id = uuid4() + unsupported_run = _run( + steps=[ + _step( + "query-main", + "input", + references={"query_revision": Reference(id=uuid4())}, + ), + _step( + "testset-main", + "input", + references={"testset_revision": Reference(id=uuid4())}, + ), + ], + ) + unsupported_run.id = run_id + + common_kwargs = dict( + project_id=project_id, + user_id=user_id, + run_id=run_id, + tracing_service=object(), + testsets_service=object(), + queries_service=object(), + workflows_service=object(), + applications_service=object(), + simple_evaluators_service=object(), + ) + + assert ( + await run_tasks.process_evaluation_run( + **common_kwargs, # type: ignore[arg-type] + evaluations_service=SimpleNamespace(fetch_run=AsyncMock(return_value=None)), + ) + is False + ) + assert ( + await run_tasks.process_evaluation_run( + **common_kwargs, # type: ignore[arg-type] + evaluations_service=SimpleNamespace( + fetch_run=AsyncMock(return_value=unsupported_run) + ), + ) + is False + ) diff --git a/api/oss/tests/pytest/unit/test_evaluation_runtime_locks.py b/api/oss/tests/pytest/unit/test_evaluation_runtime_locks.py index 8e739cc170..a1f77a75b2 100644 --- a/api/oss/tests/pytest/unit/test_evaluation_runtime_locks.py +++ b/api/oss/tests/pytest/unit/test_evaluation_runtime_locks.py @@ -85,8 +85,10 @@ async def _release_lock_for_tests( return False return bool(await client.delete(lock_key)) + cache_engine = pytest.importorskip("oss.src.dbs.redis.shared.engine") + with ( - patch("oss.src.utils.caching.r_lock", client), + patch.object(cache_engine._cache_engine, "get_r_lock", return_value=client), patch( "oss.src.utils.caching.renew_lock", _renew_lock_for_tests, @@ -116,22 +118,22 @@ def _job_id() -> str: def _genson_patch(): module = types.ModuleType("genson") - live_module = types.ModuleType("oss.src.core.evaluations.tasks.live") + query_module = types.ModuleType("oss.src.core.evaluations.tasks.query") class SchemaBuilder: ... - async def evaluate_live_query(*args, **kwargs): + async def process_query_source_run(*args, **kwargs): return None module.SchemaBuilder = SchemaBuilder - live_module.evaluate_live_query = evaluate_live_query + query_module.process_query_source_run = process_query_source_run stack = ExitStack() stack.enter_context( patch.dict( sys.modules, { "genson": module, - "oss.src.core.evaluations.tasks.live": live_module, + "oss.src.core.evaluations.tasks.query": query_module, }, ) ) @@ -473,7 +475,7 @@ async def _failing_coro(): @pytest.mark.asyncio async def test_refresh_worker_heartbeat_preserves_created_at_without_fakeredis(): - from oss.src.core.evaluations.runtime import locks + import oss.src.core.evaluations.runtime.locks as locks class DummyRedis: def __init__(self): @@ -487,9 +489,10 @@ async def set(self, key, value, ex=None): return True dummy = DummyRedis() + cache_engine = pytest.importorskip("oss.src.dbs.redis.shared.engine") with ( - patch("oss.src.utils.caching.r_lock", dummy), + patch.object(cache_engine._cache_engine, "get_r_lock", return_value=dummy), patch( "oss.src.core.evaluations.runtime.locks._now_iso", side_effect=["2026-03-25T10:00:00Z", "2026-03-25T10:01:00Z"], @@ -506,7 +509,7 @@ async def set(self, key, value, ex=None): @pytest.mark.asyncio async def test_run_job_heartbeat_fails_after_missing_renew_deadline(): - from oss.src.core.evaluations.runtime import locks + import oss.src.core.evaluations.runtime.locks as locks clock = {"now": 0.0} diff --git a/docs/designs/unified-eval-loops/findings.md b/docs/designs/unified-eval-loops/findings.md new file mode 100644 index 0000000000..333e175480 --- /dev/null +++ b/docs/designs/unified-eval-loops/findings.md @@ -0,0 +1,169 @@ +# Unified Eval Loops Findings + +Review scope: last two commits on the current `application` checkout: + +- `a114ab369 initial design` +- `747502df1 initial implementation` + +Sources: + +- Code scan of `a114ab369^..HEAD`, focused on unified evaluation runtime, worker dispatch, SDK preview runtime, engine initialization, and adjacent tests/docs. +- Staged-area scan of the current `application` checkout. +- User-provided validation output from `cd api && poetry run python run-tests.py`: `1040 passed, 11 skipped in 78.94s`. + +## Notes + +- No local test execution was performed for these findings. The full suite result above is recorded from user-provided validation. +- Findings were resolved through code review and focused patching; no end-to-end evaluation run was performed locally in this pass. + +## Open Findings + +No open findings recorded after the requested fix pass. + +## Closed Findings + +### [CLOSED] UEL-004: Runnable batch length mismatches can silently drop planned cells + +- ID: `UEL-004` +- Origin: `scan` +- Lens: `verification` +- Severity: `P1` +- Confidence: `high` +- Status: `fixed` +- Category: `Correctness` +- Summary: The shared source-slice loop zipped planned cells with runner results and did not verify that the runner returned one execution per requested cell. +- Files: + - `sdk/agenta/sdk/evaluations/runtime/source_slice.py` + - `sdk/tests/pytest/unit/test_evaluations_runtime.py` +- Resolution: + - Fixed by making `process_evaluation_source_slice` treat runner result-count mismatches as explicit scenario errors. + - Missing trailing planned cells are now logged as failed result cells with a contract-violation message instead of disappearing from persistence. + - Added focused SDK unit coverage for a two-repeat auto evaluator batch where the runner returns only one execution. + +### [CLOSED] UEL-005: Trace-backed queue slices do not load trace context before evaluator execution + +- ID: `UEL-005` +- Origin: `scan` +- Lens: `verification` +- Severity: `P1` +- Confidence: `high` +- Status: `fixed` +- Category: `Correctness` +- Summary: Direct trace batches entered the unified runtime as `ResolvedSourceItem(trace_id=...)` only, so auto evaluators could receive no source trace, inputs, outputs, or span link. +- Files: + - `api/oss/src/core/evaluations/runtime/sources.py` + - `api/oss/src/core/evaluations/tasks/source_slice.py` + - `api/oss/tests/pytest/unit/evaluations/test_runtime_topology_planner.py` +- Resolution: + - Fixed by hydrating direct trace source items through `tracing_service` before converting them to SDK source items. + - The resolver now populates `trace`, root `span_id`, `inputs`, and `outputs` from `ag.data` when the source trace is available. + - Added focused unit coverage for direct source resolution and for `process_evaluation_source_slice(trace_ids=[...])` forwarding hydrated source context to the SDK runtime. + +### [CLOSED] UEL-006: Source-trace links are hard-coded as `invocation` + +- ID: `UEL-006` +- Origin: `scan` +- Lens: `verification` +- Severity: `P2` +- Confidence: `medium` +- Status: `wontfix` +- Category: `Consistency` +- Summary: The SDK runtime emits upstream links under the key `invocation`. +- Files: + - `sdk/agenta/sdk/evaluations/runtime/source_slice.py` + - `api/oss/src/core/evaluations/runtime/adapters.py` +- Resolution: + - Wontfix per user decision: the invocation link key is the workflow contract, and the key for the invocation step should be `invocation`. + +### [CLOSED] UEL-003: Dict-revision regression test asserts fields that the request model drops + +- ID: `UEL-003` +- Origin: `mixed` +- Lens: `verification` +- Severity: `P1` +- Confidence: `high` +- Status: `fixed` +- Category: `Testing` +- Summary: The staged unit test for dict-shaped evaluator revisions fails because it asserts `workflow_request.interface` and `workflow_request.configuration`, but the active `WorkflowServiceRequest` alias is `WorkflowInvokeRequest`, whose declared payload surface is `data`. +- Impact: This previously made the full API pytest suite red and blocked using the regression test as validation for the staged adapter fix. The latest user-provided validation is now green. +- Evidence: + - User-provided validation fails at `oss/tests/pytest/unit/evaluations/test_runtime_topology_planner.py::test_backend_evaluator_runner_preserves_dict_revision_data` with `AttributeError: 'WorkflowInvokeRequest' object has no attribute 'interface'`. + - `api/oss/tests/pytest/unit/evaluations/test_runtime_topology_planner.py:1187-1191` asserts `workflow_request.interface.*` and `workflow_request.configuration.*`. + - `sdk/agenta/sdk/models/workflows.py:255-256` defines `WorkflowInvokeRequest` with `data: Optional[WorkflowRequestData] = None`; it does not declare `interface` or `configuration`. + - `api/oss/src/core/evaluations/runtime/adapters.py:355-362` still passes `interface=` and `configuration=` while constructing `WorkflowServiceRequest`, but those are not accessible as model attributes under the current request model. The preserved evaluator details should be verified through `workflow_request.data.revision["data"]` and `workflow_request.data.parameters`, or the request model should explicitly regain those top-level fields if that is the intended contract. +- Files: + - `api/oss/tests/pytest/unit/evaluations/test_runtime_topology_planner.py` + - `api/oss/src/core/evaluations/runtime/adapters.py` + - `sdk/agenta/sdk/models/workflows.py` +- Cause: The staged regression test was written against an older or assumed request shape instead of the current SDK `WorkflowServiceRequest` alias. +- Explanation: The adapter now reads nested dict data with `_read_field`, which addresses the original dict/DTO mismatch. However, the test checks `interface` and `configuration` directly on the Pydantic request object. Since the model only declares `data`, Pydantic does not expose those names, producing the exact `AttributeError` seen in the full-suite output before the test can assert the actual preserved revision data. +- Suggested Fix: + - Update the regression test to assert the current contract: `workflow_request.data.revision["data"]["uri"]`, `headers`, `schemas`, `script`, and `workflow_request.data.parameters`. + - If top-level `interface` and `configuration` are still required for downstream workflow-service invocation, add them to the SDK `WorkflowInvokeRequest` model and add assertions on `workflow_request.model_dump(mode="json", exclude_none=True)`. +- Alternatives: + - Remove the top-level `interface` and `configuration` constructor arguments from `BackendEvaluatorRunner` if they are intentionally obsolete, to avoid suggesting that they are part of the request contract. +- Sources: + - `api/oss/tests/pytest/unit/evaluations/test_runtime_topology_planner.py:1187` + - `api/oss/src/core/evaluations/runtime/adapters.py:355` + - `sdk/agenta/sdk/models/workflows.py:255` +- Resolution: + - Fixed by updating the regression test to assert preserved evaluator metadata through `workflow_request.data.revision["data"]` and `workflow_request.data.parameters`, matching the current SDK request model. + +### [CLOSED] UEL-001: Backend evaluator runner receives dumped revisions but reads them like DTOs + +- ID: `UEL-001` +- Origin: `scan` +- Lens: `verification` +- Severity: `P1` +- Confidence: `high` +- Status: `fixed` +- Category: `Correctness` +- Summary: Backend auto-annotation execution can invoke evaluators with an empty `interface` and `configuration` because the shared runtime dumps revisions to dictionaries before handing them to `BackendEvaluatorRunner`. +- Evidence: + - `sdk/agenta/sdk/evaluations/runtime/source_slice.py:274-279` builds `WorkflowExecutionRequest` with `revision=_dump_revision(revision)`, so Pydantic DTOs become plain dictionaries. + - `api/oss/src/core/evaluations/runtime/adapters.py:302-328` detects `dict` revisions but then reads `data` with `getattr(data, "uri")`, `getattr(data, "script")`, and `getattr(data, "parameters")`. When `data` is a dict, these all return `None`. + - The resulting `WorkflowServiceRequest` for backend evaluator steps loses the evaluator script, parameters, URI, URL, headers, and schemas. +- Files: + - `sdk/agenta/sdk/evaluations/runtime/source_slice.py` + - `api/oss/src/core/evaluations/runtime/adapters.py` +- Cause: The shared runtime normalizes revisions for transport with `model_dump`, but the backend evaluator adapter was written for object-shaped revision data and only partially handles dictionary-shaped revisions. +- Explanation: Auto evaluator steps are routed through `process_evaluation_source_slice`, which passes evaluator revisions into the SDK runtime. The SDK runtime dumps the revision before storing it on the request. On the backend side, `BackendEvaluatorRunner._execute_one` switches to dict handling for the top-level revision but not for nested `data`, so evaluator metadata becomes empty. This can make backend auto annotations fail or execute without the intended evaluator configuration while planner tests still pass. +- Suggested Fix: + - Preserve the revision object in `WorkflowExecutionRequest` when running in-process, or teach `BackendEvaluatorRunner` to read both dict and DTO shapes for `data`. + - Add a focused unit test where an evaluator revision dict with `data.script`, `data.parameters`, and interface fields reaches `workflows_service.invoke_workflow` intact. +- Alternatives: + - Move backend request construction before the SDK runtime boundary and pass a backend-native execution payload to the runner. +- Sources: + - `sdk/agenta/sdk/evaluations/runtime/source_slice.py:274` + - `api/oss/src/core/evaluations/runtime/adapters.py:302` +- Resolution: + - Fixed by making `BackendEvaluatorRunner` read revision, nested `data`, and `flags` from both dict-shaped and DTO-shaped objects. + - Added a focused unit case for dict-shaped evaluator revisions in `api/oss/tests/pytest/unit/evaluations/test_runtime_topology_planner.py`. + +### [CLOSED] UEL-002: Startup instrumentation uses raw prints in the FastAPI module + +- ID: `UEL-002` +- Origin: `scan` +- Lens: `verification` +- Severity: `P3` +- Confidence: `high` +- Status: `wontfix` +- Category: `Compatibility` +- Summary: `api/entrypoints/routers.py` now emits startup timing with top-level `print()` calls during module import. +- Evidence: + - `api/entrypoints/routers.py:158-171` prints SDK import and `ag.init()` timing at import time. + - `api/entrypoints/routers.py:176-180` prints EE import timing at import time. +- Files: + - `api/entrypoints/routers.py` +- Cause: Debug startup timing instrumentation was committed directly into the application entrypoint instead of using the structured logger or a debug-gated startup probe. +- Explanation: These prints run whenever the module is imported, including tests, scripts, workers, and production ASGI startup. That bypasses log formatting, severity controls, JSON log aggregation, and normal logger configuration. It is not blocking, but it adds noisy side effects to a central import path. +- Suggested Fix: + - Replace the `print()` calls with `log.debug` or `log.info` after logger initialization, gated by an explicit startup profiling flag if the timings are still needed. + - Keep import-time side effects limited to required initialization. +- Alternatives: + - Move startup timing into the FastAPI lifespan handler and emit one structured summary log. +- Sources: + - `api/entrypoints/routers.py:158` + - `api/entrypoints/routers.py:176` +- Resolution: + - Wontfix per user decision. diff --git a/docs/designs/unified-eval-loops/gap.md b/docs/designs/unified-eval-loops/gap.md new file mode 100644 index 0000000000..1e2977bc5e --- /dev/null +++ b/docs/designs/unified-eval-loops/gap.md @@ -0,0 +1,401 @@ +# Gap Analysis + +## Summary + +The current system has many pieces of a unified evaluation model, but they are split across setup surfaces, worker loops, SDK code, queue code, and frontend assumptions. + +The largest gaps are: + +- no first-class planner that turns run graph + sources + flags into tensor cells +- source resolution exists in several paths but not behind one resolver interface +- repeat-aware execution exists in current backend loops but not behind one execution planner +- pending/manual lifecycle exists in key loops but is still duplicated and topology-specific +- no slice-aware `process` operation +- no single slice-shaped operation boundary across process/probe/populate/prune +- source-family classification and simple-queue eligibility are still entangled through `is_queue` +- destructive step removal and archival step lifecycle are not yet separated + +## Already Implemented + +Current code already includes several capabilities that older design docs listed as missing or speculative: + +- `EvaluationRunFlags.is_cached` +- `EvaluationRunFlags.is_split` +- `EvaluationRunFlags.is_queue` +- `EvaluationRunData.repeats` +- `EvaluationResult.repeat_idx` +- `EvaluationResultQuery.repeat_idx` / `repeat_idxs` +- repeat helpers: `build_repeat_indices`, `required_traces_for_step`, `effective_is_split` +- cache helpers: `make_hash`, `fetch_traces_by_hash`, `select_traces_for_reuse`, `plan_missing_traces` +- source-aware queue creation from query/testset-backed sources +- source-backed queue dispatch to concrete trace/testcase batches +- human/custom evaluator pending behavior in live query and batch item paths +- repeat-aware input/evaluator result creation in live query +- repeat-aware input/application/evaluator result creation in batch testset +- repeat-aware input/application result creation in batch inference / batch invocation +- repeat-aware source/evaluator result creation in batch trace/testcase items + +## Setup Gaps + +Current setup is fragmented: + +- auto testset evaluation setup builds one specific graph shape +- human evaluation setup builds a related but separate testset shape +- live query setup has separate query semantics +- queue setup accepts direct trace/testcase IDs but not source revisions +- SDK/local setup has its own assumptions + +Still missing or incomplete: + +- canonical graph-oriented create request +- shared validation for input source combinations +- one canonical setup request model used by all wrappers +- wrapper-to-canonical translation for every existing setup API +- one place to enforce step origin semantics +- annotation queue convenience APIs that hide backing run/scenario/result setup while using the same canonical setup path + +## Source Resolution Gaps + +There is no shared abstraction for resolving source descriptors into concrete scenario items, even though source resolution now exists in several code paths. + +Resolver behavior that exists but should be extracted: + +- query revision -> trace refs for live windows +- query revision -> trace refs for source-backed queues +- testset revision -> testcase refs for source-backed queues +- testset revision -> testcase payloads for batch testset/invocation +- direct trace IDs -> trace refs +- direct testcase IDs -> testcase refs + +Current consequences: + +- scenario creation is repeated in each loop +- each loop owns part of source resolution itself +- live and batch query semantics are harder to compare +- unsupported mixed-source cases fail implicitly rather than through clear validation + + +## Source-Family Flag Gaps + +The current runtime still overloads `is_queue` in places where the concern is source family or queue-style ingestion. + +Missing from the target model: + +- explicit inferred `has_traces` +- explicit inferred `has_testcases` +- one place to prevent mixed source families using the source-family flags +- one topology contract that does not need synthetic step-name inspection to distinguish direct traces/testcases from query/testset sources + +## Default Queue Integration Gaps + +The queue layer should be a human-work view over the tensor, but the current runtime/docs still blur several concerns. + +Missing from the target model: + +- default queue as the canonical persisted view over active human work +- `queue.flags.is_default` +- queue lifecycle semantics that can drive simple-queue eligibility +- redefined persisted `run.flags.is_queue` as “active default queue + active human work” +- explicit separation between queue view semantics and source-family classification + +## Step Lifecycle Gaps + +The current mutation model still leans toward: + +```text +remove_step -> prune tensor cells +``` + +That is incomplete if product semantics require evaluator/step archival and retention of historical results. + +Needs explicit decision: + +- whether ordinary evaluator removal is archival/deactivation rather than destructive deletion +- whether active versus archived step state belongs in the graph model +- whether `process` defaults to active steps only +- whether queue eligibility is based on active human steps only +- when hard remove/prune is appropriate + +## Planner Gaps + +The system lacks an execution planner that can materialize these concepts once: + +- scenario cells +- input result cells +- auto executable cells +- human/custom pending cells +- repeat slots +- cache reuse plans +- upstream bindings between steps + +Current loops encode planning inline. This makes it difficult to support new combinations or change semantics consistently. + +- multiple input steps with consistent result slots +- repeat fan-out at different graph boundaries +- partial retries by tensor slice + +## Execution Gaps + +Current execution was specialized: + +- SDK preview evaluation had its own nested loop +- backend legacy batch testset had another loop +- backend live query had another loop +- queue batch evaluation had another loop + +Current backend implementation direction: + +- live query, batch query, direct trace queues, direct testcase queues, batch inference, and testset -> application -> evaluator resolve source items and call one backend source-slice processor +- batch inference is the application-only testset application graph shape +- API-internal task handlers have been collapsed to run and slice processors +- trace/testcase batch task helpers are no longer needed because the slice processor can call the source-slice processor directly +- specialized helper names may remain as wrappers while web/API compatibility is preserved + +Current SDK implementation direction: + +- SDK preview/local evaluation now routes through SDK-owned `process_evaluation_source_slice` +- SDK runner, result logging, trace loading, and metrics work are adapters around the shared SDK runtime contract +- backend execution now delegates to the SDK processor through backend-specific scenario, result, cache, status, trace, and workflow adapters + +Still missing: + +- unified `process(run, slice)` role exposed as a public API or service operation +- topological execution over planned cells +- idempotent probe-before-write behavior +- consistent error-as-result behavior +- shared metrics refresh policy after processing +- clear separation between execution and persistence adapters +- public API/service operation shape for invoking the SDK-owned source-slice processor by tensor slice + +## Runnable Execution Gaps + +The current worker loops still call step-specific invocation helpers directly. + +Known debt: + +- application execution still uses legacy batch invocation helper paths +- evaluator execution assembles workflow invocation requests separately +- cache lookup, trace validation, link/reference construction, and error-to-result conversion are repeated around those calls +- there is no single runnable-step executor that can handle application and evaluator steps through the same contract + +Needed: + +- `RunnableStepExecutor` interface for any auto runnable step +- application-step adapter that can initially wrap the current application invocation path +- evaluator-step adapter that can initially wrap workflow invocation +- shared request/context builder for references, links, inputs, trace, outputs, and parameters +- shared trace validation and result normalization +- migration path to deprecate legacy LLM app batch helper functions after parity is proven + +This should be treated as part of unification, not as a later cleanup. Otherwise the new loop would only centralize iteration while preserving the most brittle execution boundary. + +## Tensor Operation Gaps + +The intended tensor identity exists in storage and query models, but the operation model is incomplete. + +Missing or partial: + +- `TensorSlice` model across backend, SDK, and frontend +- slice-aware `probe` +- slice-aware `populate` +- slice-aware `prune` +- slice-aware `process` +- partial retry/fill-missing workflows +- repeat slot materialization as a shared planner primitive + +Existing APIs are mostly per-entity or full-run oriented. + +## Repeat And Fan-Out Gaps + +`repeat_idx` exists in result identity and current backend loops now expand it in the inspected paths. + +Current gaps: + +- repeat expansion is duplicated across loops +- queue repeats still also carry assignment semantics, so execution repeats and assignment lanes need an explicit shared contract +- `is_split` is enforced through helpers in some paths, but topology validation is still dispatch-specific +- no shared planner decides whether application or evaluator steps fan out + +Needed: + +- repeat-aware result-slot planner +- topology-specific fan-out validation +- deterministic repeat-slot binding for reused traces +- tests for full, partial, and zero cache hits under repeats + +## Cache Gaps + +Hash-based trace reuse is implemented in current backend worker paths, but not centralized. + +Still missing: + +- one shared cache-resolution stage used by every runnable step +- one explicit per-slot cache binding object +- parity tests proving all loops use the same cache rules +- a documented project-scoped cross-run reuse policy + +The current code uses `is_cached`; older docs that say `reuse_traces` should be treated as stale terminology unless an external compatibility need exists. + +## Origin Gaps + +`auto`, `human`, and `custom` origins exist in the current backend model, and human/custom pending behavior is present in several loops. + +Still missing or incomplete: + +- common pending/manual result lifecycle +- consistent frontend/backend origin naming +- external custom-populate contract for custom steps +- annotation queue progress/status semantics layered over evaluation results without duplicating task state + +Verify frontend/generated-client naming before changing UI code; backend type truth is `auto`. + +## Annotation Queue Layer Gaps + +Annotation Queue v2 identifies product/API gaps adjacent to unified loop execution: + +- convenience API for queue creation from traces and testsets +- annotator inbox/list view across queues +- per-item progress computed from evaluation results +- explicit export/write-back flow for testset-sourced annotation queues +- clear distinction between backing infrastructure status and consumer-facing task status +- UI that uses queue assignment instead of allowing annotators to annotate any scenario + +These should be built on top of the unified planner/source resolver/tensor result model, not as a separate runtime. + +## Graph Mutation Gaps + +Steps are stored in run data, but graph operations are not first-class enough. + +Missing: + +- `add_step` endpoint/service operation +- `remove_step` endpoint/service operation +- graph validation outside setup functions +- tensor pruning cascade when removing a step +- explicit immutable-reference policy in code paths +- UI/API affordances for managing steps after creation + +Without these, graph changes require specialized setup edits or recreation. + +## Flag Gaps + +Flags are consistently modeled in the current backend types, but old docs and possibly callers still use stale names. + +Current canonical backend flags include: + +- `is_live` +- `is_active` +- `is_cached` +- `is_split` +- `is_queue` +- `repeats` +- `is_closed` + +Compatibility names to reconcile: + +- `reuse_traces` vs `is_cached` +- `repeat_target` vs `is_split` +- `allow_decrease_repeats` if repeat count becomes mutable + +Still missing: + +- first-class constrained `set_flag` +- validation when flags conflict with topology +- end-to-end propagation through setup, run fetch, queue creation, SDK/local execution, and frontend state + +## Topology Gaps + +Supported topologies are implicit in dispatch logic. + +Missing: + +- explicit topology validation table in code +- structured error messages for unsupported combinations +- explicit rejection for not-planned shapes such as multiple application steps, mixed-source queues, and live testset runs +- future extension point for query -> application flows +- future extension point for testset -> evaluator flows + +Potentially useful future shapes: + +- `query -> application -> evaluator`, with query traces adapted as input data rather than application links +- `testset -> evaluator`, with an explicit evaluator testcase-only input contract + +Not planned for now: + +- multiple application steps in one worker-dispatched run +- mixed query/testset source families in one queue +- live testset evaluation + +The immediate goal should not be to support every theoretical graph. It should be to reject unsupported graphs through one planner, and make adding support localized. + +## API Gaps + +Missing or incomplete API surface: + +- graph-oriented create request +- `process(slice)` +- `probe(slice)` +- `prune(slice)` +- `populate(slice, results)` for bulk/slice writes +- `set_flag` +- response payloads that expose resolved source items and pending cells consistently + +Existing APIs should remain as compatibility wrappers while the canonical surface is introduced. + +## SDK Gaps + +The SDK should own the shared runtime contract. The preview loop can keep its +public setup API, but orchestration should move behind SDK runtime planning and +SDK-specific execution/persistence adapters. Backend workers should consume the +same SDK runtime models through backend-specific adapters. + +Missing: + +- remote API persistence adapter +- slice-aware processing +- probe-before-write +- cache parity with backend +- stable step key strategy aligned with backend graph steps +- removal of duplicate backend planner/topology logic once migration coverage is sufficient + +The desired state is not "SDK calls backend worker for everything." It is "SDK and backend share the same loop contract with different persistence/execution adapters." + +## Frontend Gaps + +Missing: + +- explicit graph builder/step management model +- TensorSlice UI concepts for retry, prune, and fill missing +- unified origin naming +- flag editing beyond current implicit flows +- display of pending human/custom cells across query, testset, and queue runs +- source-aware queue creation UI if that product path is enabled + +Frontend work can follow backend planner/API stabilization. + +## Testing Gaps + +Needed test coverage: + +- source resolver outputs for query, testset, trace IDs, testcase IDs +- topology validation success and failure cases +- repeat slot materialization for every supported topology +- `is_split=true` and `is_split=false` on testset -> application -> evaluator +- evaluator-only repeat fan-out for query and queue runs +- cache full hit, partial hit, and miss +- cross-run trace reuse +- human/custom pending cells in query-backed runs +- source-aware queues preserving query/testset revision references +- existing direct queue behavior remains unchanged +- SDK/backend parity for the same planned graph + +## Documentation Gaps + +Needed docs after implementation starts: + +- canonical source matrix +- topology validation matrix +- flag semantics and compatibility names +- manual/custom origin lifecycle +- cache and repeat behavior +- migration guide from specialized setup APIs to canonical graph creation diff --git a/docs/designs/unified-eval-loops/plan.md b/docs/designs/unified-eval-loops/plan.md new file mode 100644 index 0000000000..2ccb00ab5e --- /dev/null +++ b/docs/designs/unified-eval-loops/plan.md @@ -0,0 +1,199 @@ +# Plan + +## Goal + +Move from multiple specialized setup and execution functions to unified evaluation loop(s) built around: + +- source resolvers +- run graph steps +- tensor slices +- repeat-aware planning +- origin-aware execution +- runnable-step execution +- adapter-based persistence + +This plan describes required work, not phases or timeline. + +## Baseline Inventory + +1. Lock down the current behavior with code references and tests before changing semantics. +2. Treat these as implemented baseline behavior: + - `is_cached` + - `is_split` + - current `is_queue` + - `repeats` + - repeat-indexed result creation in current backend workers + - hash-based cache helpers and worker integration + - source-aware queue creation and source batch dispatch + - human/custom pending behavior in query/queue-related paths +3. Maintain a parity matrix from current tests: + - `test_cache_split_utils.py` + - `test_query_eval_loops.py` + - `test_run_flags.py` + - queue assignment and queue DAO tests + - acceptance tests for evaluation steps/runs/queues/results + +## Vocabulary And Flags + +1. Use current backend names as canonical: + - `is_cached` + - `is_split` + - `repeats` +2. Normalize origin values: + - `auto` + - `human` + - `custom` +3. Add explicit source-family flags: + - `has_queries` + - `has_testsets` + - `has_traces` + - `has_testcases` +4. Separate source-family classification from simple-queue eligibility. +5. Redefine target `run.flags.is_queue` as: + +```text +active default queue exists + active human evaluator work exists +``` + +6. Document topology validation rules in one table used by implementation and tests. + +## Shared Runtime Models + +Introduce or consolidate shared internal models: + +1. `InputSourceSpec` +2. `ResolvedSourceItem` +3. `ScenarioBinding` +4. `EvaluationStep` +5. `TensorSlice` +6. `PlannedCell` +7. `ExecutionPlan` +8. `ProcessSummary` + +The common runtime contract should live in the SDK so SDK-local evaluations and API workers share the same planner/topology/result-cell model. Backend code should keep API-specific source, DAO, workflow-service, and worker-dispatch adapters in backend modules. + +## Source Resolution + +Create resolver interfaces that cover: + +1. query revision -> trace refs for live windows +2. query revision -> trace refs for source-backed queues +3. testset revision -> testcase refs for source-backed queues +4. testset revision -> testcase payloads for batch testset/invocation +5. direct trace IDs -> trace refs +6. direct testcase IDs -> testcase refs + +Resolver requirements: + +- preserve existing source behavior +- own live query windowing +- preserve original source references in input steps +- reject unsupported mixed-source combinations explicitly +- expose source-family flags consistently + +## Tensor Slice Operations + +Add or adapt backend service operations around existing CRUD: + +1. `probe(slice)` +2. `populate(slice, results)` +3. `prune(slice)` +4. `process(slice)` + +Requirements: + +- slice dimensions support all/none/explicit selections +- `probe` identifies missing, success, failure, and any cells +- `populate` writes by `scenario_id + step_key + repeat_idx` +- `prune` deletes by slice and refreshes affected metrics + +## Planner + +Implement planner logic that produces result slots before execution. + +Planner responsibilities: + +1. validate topology +2. order steps +3. materialize scenario bindings +4. create input cells +5. expand repeat slots +6. decide fan-out boundary +7. bind upstream context +8. mark `human` and `custom` cells pending +9. select `auto` cells for execution + +Planner requirements: + +- one planned cell exists for every required repeat slot +- unsupported topologies fail with structured validation errors +- human/custom steps are planned as pending rather than silently skipped + +## Cache Resolution + +Reuse the existing cache helpers: + +1. `make_hash(...)` +2. `fetch_traces_by_hash(...)` +3. `select_traces_for_reuse(...)` +4. `plan_missing_traces(...)` +5. per-slot trace binding + +Requirements: + +- cache lookup is skipped when `is_cached=false` +- full cache hit invokes nothing +- partial cache hit invokes only missing slots +- misses invoke all required slots +- reused and newly generated traces populate identical tensor cells + +## Runnable-Step Execution + +Add a runnable execution boundary for any auto step whose type maps to a runnable. + +Initial adapters: + +1. SDK workflow-runner protocols for application/evaluator execution +2. SDK/local adapters wrapping decorator/service endpoint execution +3. API adapters wrapping the current backend workflow invocation path +4. API application adapter wrapping the current legacy batch invocation path + +The interface should own: + +- request construction from step references and upstream bindings +- cache resolver integration +- invocation +- trace fetch/validation +- normalized `StepExecutionResult` + +## Queue Integration + +1. Treat default queues as persisted human-work views over the tensor, not orchestration. +2. Add `queue.flags.is_default` to identify the canonical queue. +3. Keep default queues open over scenarios, steps, and assignments. +4. Let source-family flags describe where scenarios come from. +5. Let `run.flags.is_queue` describe simple-queue eligibility. +6. Ensure queue eligibility depends on active human steps and active default queue lifecycle. + +## Mutation Semantics + +1. Decide whether ordinary evaluator removal is archival/deactivation rather than destructive deletion. +2. If history must remain visible, represent active versus archived step state in the graph model. +3. Make planner defaults operate on active steps. +4. Reserve hard remove/prune for explicit destructive cleanup. +5. Keep queue eligibility tied to active human steps. + +## Verification + +Add or preserve coverage for: + +1. topology classification +2. resolver behavior +3. repeat slot creation +4. cache reuse +5. human/custom pending behavior +6. query/testset/direct trace/direct testcase source families +7. source-family validation +8. tensor slice probe/populate/prune/process behavior +9. queue/default-queue integration semantics +10. active-versus-archived step behavior once chosen diff --git a/docs/designs/unified-eval-loops/proposal.md b/docs/designs/unified-eval-loops/proposal.md new file mode 100644 index 0000000000..ac5e2bcca1 --- /dev/null +++ b/docs/designs/unified-eval-loops/proposal.md @@ -0,0 +1,476 @@ +# Proposal + +## Goal + +Introduce unified evaluation loop(s) that avoid separate setup and execution functions for every evaluation shape while preserving the capabilities already implemented in the current backend: + +- input steps +- application and evaluator origins +- evaluation run flags +- input sources +- evaluation graph steps +- repeat and cache behavior through `repeats`, `is_cached`, and `is_split` +- live, batch, queue, and SDK/local execution contexts + +This proposal does not require every source to become the same thing. It requires every source to enter the same planning and tensor execution contract. + +The current code has already implemented several parts that older docs described as missing: + +- `EvaluationRunFlags.is_cached` +- `EvaluationRunFlags.is_split` +- `EvaluationRunFlags.is_queue` +- `EvaluationRunData.repeats` +- repeat helper functions in `evaluations/utils.py` +- hash/cache helper functions in `evaluations/utils.py` +- source-aware queue creation from query/testset-backed sources +- live and batch query human/custom pending behavior +- repeat-aware result creation in the inspected backend worker loops + +The proposal is therefore a unification/refactor proposal, not a first implementation of those behaviors. + +## Design Principle + +Separate source resolution, graph planning, execution, and tensor persistence. + +```text +setup request + -> source resolver + -> run graph + -> scenario materializer + -> execution planner + -> step executor + -> tensor writer + -> metrics refresh +``` + +Each current loop family becomes a configuration of this pipeline instead of a separate handwritten loop. + +## Canonical Concepts + +### Run Graph + +A run graph is a list of immutable step definitions: + +```python +class EvaluationStep: + key: str + type: Literal["input", "invocation", "annotation"] + origin: Literal["auto", "human", "custom"] + references: dict + inputs: list[StepInput] | None +``` + +Step references point to concrete revisions or direct-source descriptors. Editing a step means removing it and adding a new step. + +### Tensor Cell + +Every produced or pending result targets: + +```text +run_id + scenario_id + step_key + repeat_idx +``` + +All execution, retry, prune, cache binding, and manual annotation work should address this coordinate. + +### Tensor Slice + +Use one slice model for read, write, delete, and processing operations: + +```python +class TensorSlice: + scenarios: Literal["all", "none"] | list[UUID] + steps: Literal["all", "none"] | list[str] + repeats: Literal["all", "none"] | list[int] +``` + +The same slice shape should power: + +- `probe` +- `populate` +- `prune` +- `process` +- retry failed cells +- fill missing cells +- re-run a single evaluator +- materialize new repeat slots + +## Source Resolver Layer + +Input sources should be modeled as descriptors that resolve into concrete source items. + +| Descriptor | Resolver output | Scenario source | +|---|---|---| +| query revision | trace refs | queried traces | +| testset revision | testcase refs | testcases | +| direct trace source | trace refs | queued traces | +| direct testcase source | testcase refs | queued testcases | + +The resolver is responsible for source-specific rules: + +- live query windows +- batch query snapshots +- testset revision loading +- direct item validation +- source-aware queue expansion +- preserving original source references in input steps + +The executor should not care whether a trace came from live query, batch query, or a queue. It should receive concrete scenario bindings. + +## Annotation Queue Convenience Layer + +Annotation Queue v2 should be treated as a consumer-facing layer over the same unified evaluation infrastructure. + +Principles: + +- `EvaluationRun`, `EvaluationScenario`, `EvaluationResult`, and `EvaluationQueue` remain the backing entities. +- The annotation queue API hides run/scenario/result setup for trace and testset annotation use cases. +- Queue assignment remains based on `EvaluationQueue.data.user_ids`, optional `scenario_ids`, optional `step_keys`, and result `repeat_idx`. +- Queue creation from traces/testsets should translate into canonical source specs and graph steps, then use the same source resolver and planner as evaluation runs. +- Annotation submission can continue to create annotation traces and link them to evaluation results. + +Unified eval loops should provide the infrastructure contract for this layer. They should not replace the annotation queue convenience API. + +## Planner Layer + +The planner converts a graph and concrete scenarios into execution slots. + +```python +class PlannedCell: + scenario_id: UUID + step_key: str + repeat_idx: int + action: Literal["bind_input", "invoke", "pending", "skip"] + upstream: dict +``` + +Planner responsibilities: + +- validate topology +- derive step order +- materialize input cells +- decide repeat fan-out point +- compute required result slots +- bind upstream trace/testcase/application output context +- mark human/custom annotation cells as pending +- skip cells that are already successful when requested + +The planner is where topology-specific behavior belongs. + +## Execution Layer + +`process(slice)` executes planned `auto` cells only. + +```text +process(run, slice): + resolve concrete scenarios for input steps + plan cells for the requested slice + probe existing cells if skip-success is enabled + for each executable auto cell in dependency order: + resolve cache/reuse if enabled + invoke only missing work + populate result cells + create pending cells for human/custom work + refresh metrics for affected scope +``` + +The same executor can run in different contexts with different adapters: + +| Context | Adapter | +|---|---| +| backend worker | API source/DAO/workflow-service adapters | +| SDK/local | local decorator or remote service adapters | +| tests | in-memory adapter | +| frontend human annotation | direct `populate` adapter for submitted cells | + +The planner, topology classifier, and result-cell models should be SDK-owned so +SDK-local evaluation and backend workers use the same runtime contract. API code +should not fork the runtime; it should translate backend DTOs into SDK runtime +models and keep only backend-specific adapters beside the worker/service code. + +The backend implementation should have one scenario execution loop. Source +wrappers may still differ because live queries, query snapshots, direct queue +items, and testset rows resolve differently, but after resolution they should +all call one source-slice processor. That processor owns input cell creation, +application invocation, evaluator invocation, pending manual/custom cells, +cache resolution, metrics refresh, and run/scenario status updates. Batch +inference is therefore just the application-only graph shape, not a separate +loop. Task-level trace/testcase batch helpers are unnecessary once the slice +worker calls the source-slice processor directly; service/API wrapper methods +can remain for compatibility. + +The SDK should own the generic source-slice contract. SDK preview/local +execution can run through SDK-owned `process_evaluation_source_slice` now using +local decorator runners, SDK result logging, and SDK trace loading. The backend +processor should use that same contract with backend adapters for scenario +creation, result persistence, cache reuse, status updates, trace loading, and +workflow service execution. + +## Runnable Step Executor + +The unified loop should introduce a new runnable-step execution boundary rather than directly preserving the current helper calls inside each loop. + +Current application execution is still routed through legacy helper paths such as batch LLM app invocation. Those paths have accumulated patches and are not the right long-term abstraction. Evaluator execution is also assembled separately even though it has the same core shape: prepare a runnable request, bind upstream context, invoke or reuse a trace, validate the trace, and produce a result cell. + +Proposed contract: + +```python +class WorkflowRunner: + async def execute( + self, + request: WorkflowExecutionRequest, + ) -> WorkflowExecutionResult: + ... +``` + +`WorkflowExecutionResult` should be independent of whether the runnable was an application or evaluator: + +```python +class WorkflowExecutionResult: + status: EvaluationStatus + trace_id: str | None + span_id: str | None + hash_id: str | None + error: dict | None + outputs: dict | None +``` + +Responsibilities: + +- build the service request from step references and upstream bindings +- apply cache lookup/reuse when enabled +- invoke missing work +- fetch and validate traces where required +- normalize failures into result payloads +- return enough context for downstream steps + +The first implementation can wrap existing application and workflow services. +The SDK should expose the runner protocol and shared models. The API should +provide backend workflow-service and legacy batch-invocation adapters. The SDK +should provide local decorator and remote service adapters. This makes those +wrappers replaceable so the legacy batch helpers can be deprecated without +changing the planner, tensor operations, or queue APIs. + +## Origin Semantics + +`origin` controls who can populate a step: + +| Origin | Planner behavior | Executor behavior | +|---|---|---| +| `auto` | create executable cells | invoke and populate | +| `human` | create pending cells | do not invoke | +| `custom` | create pending cells or external-awaiting cells | do not invoke | + +This gives query-backed, testset-backed, and queue-backed runs the same pending/manual semantics. Current behavior where query-backed runs lack human/custom pending branches should become a topology limitation only until the planner supports those cells. + +## Flag Model + +Use the current backend flag set as canonical. Bridge older design names only where old clients or docs still mention them. + +| Canonical flag | Legacy design name | Purpose | +|---|---|---| +| `is_live` | `is_live` | periodic windowed source resolution | +| `is_active` | `is_active` | pause/resume live processing | +| `is_cached` | `reuse_traces` in older docs | enable hash-based trace reuse | +| `is_split` | `repeat_target` in older docs | select fan-out location where meaningful | +| `repeats` | `repeats` | number of repeat slots | +| `is_closed` | `is_closed` | block structural and tensor mutations | + +Recommended compatibility mapping: + +```text +repeat_target = "application" <=> is_split = true +repeat_target = "evaluator" <=> is_split = false +reuse_traces = true <=> is_cached = true +``` + +Do not introduce `reuse_traces` or `repeat_target` as new model fields unless there is a compatibility requirement. The current code already uses `is_cached` and `is_split`. + +## Topology Validation + +The unified planner should support every valid topology explicitly and reject invalid combinations before execution. + +| Topology | Valid? | Notes | +|---|---:|---| +| query -> evaluator | yes | live or batch; evaluator fan-out only | +| query -> human/custom evaluator | yes target | creates pending cells | +| query -> application -> evaluator | potentially useful | Requires query trace to application input adapter. Do not pass query traces as application `links`; that can make application traces look like annotations rather than invocations. | +| testset -> application -> evaluator | yes | app or evaluator fan-out | +| testset -> application | yes | Batch inference / batch invocation. Application fan-out only; no evaluator execution or evaluator metrics. | +| testset -> evaluator | potentially useful | Requires evaluator testcase-only contract. | +| direct trace -> evaluator | yes | queue trace shape | +| direct testcase -> evaluator | yes | queue testcase shape | +| mixed query + testset in one queue | not planned | Keep queues single-source-family for now. | +| multiple application steps | not planned | Use separate evaluations for A/B comparison for the foreseeable future. | +| live testset | not planned | Static sources do not make sense for live periodic evaluation. | + +The key shift is that unsupported shapes should fail through planner validation, not because there is no matching handwritten function. Potentially useful shapes should be explicitly modeled when implemented; not-planned shapes should stay rejected with clear errors. + +## Repeat Semantics + +Repeats are always represented as `repeat_idx` result slots. Fan-out determines which runnable step produces multiple traces. + +Rules: + +- query/queue trace/testcase evaluator-only runs fan out at evaluator steps. +- application-only runs, also called batch inference or batch invocation, fan out at application steps. +- testset -> application -> evaluator runs use `is_split`: + - `true`: application produces one trace per repeat, evaluators consume matching repeat traces + - `false`: application produces one trace, evaluators produce one trace per repeat +- if a topology has no application/evaluator boundary, `is_split` is ignored or rejected according to validation policy. + +## Cache Semantics + +Cache reuse is explicit through `is_cached`. + +At each runnable step: + +1. Compute the expected hash from step references and upstream links. +2. Fetch all candidate traces by hash. +3. Select deterministic traces for requested repeat slots. +4. Invoke missing slots only. +5. Populate the same tensor cells whether the trace was reused or newly generated. + +Cache lookup is step-local and already exists in the current backend loops inspected: + +- application steps reuse application traces +- evaluator steps reuse evaluator traces + +Cross-run reuse is structurally supported by project-scoped trace lookup by hash. The unified planner should reuse the existing helper functions instead of reimplementing this per loop. + +The cache resolver should sit inside or immediately beside the runnable-step executor so applications and evaluators use the same reuse semantics. + +## Setup API Direction + +Consolidate specialized setup functions behind one graph-oriented creation path plus convenience wrappers. + +Canonical create request: + +```python +class EvaluationCreate: + inputs: list[InputSourceSpec] + steps: list[ExecutableStepSpec] + flags: EvaluationFlags +``` + +Convenience wrappers may remain: + +- create auto testset evaluation +- create live query evaluation +- create annotation queue from traces +- create annotation queue from testcases +- create source-aware queue from query/testset +- create Annotation Queue v2 convenience flows from traces or testsets + +Several wrappers already use `_make_evaluation_run_data()`. The next step is to make that builder and its validations explicit enough that wrappers only translate into canonical graph/source specs and do not own separate graph semantics. + +## Operation API Direction + +Expose or normalize first-class operations: + +- `add_step` +- `remove_step` +- `add_scenario` +- `remove_scenario` +- `probe(slice)` +- `populate(slice, results)` +- `prune(slice)` +- `process(slice)` +- `refresh_metrics(scope)` +- `set_flag` + +This lets setup, retry, queue assignment, manual annotation, live ticks, and SDK/local runs share the same tensor contract. + +Some CRUD operations already exist in service/router form (`create_results`, `query_results`, `delete_results`, `refresh_metrics`, run start/stop, queue creation). The missing piece is a slice-shaped operation boundary and a shared `process(slice)` planner/executor. + +## Migration Strategy + +Do not rewrite all loops at once. Introduce the unified planner and adapters beside existing loops, then move topologies one at a time while preserving current behavior. + +Recommended order: + +1. Inventory current behavior and lock it with parity tests. +2. Define shared models: source descriptor, scenario binding, tensor slice, planned cell. +3. Extract current source resolution behavior into resolver interfaces. +4. Extract current repeat/cache planning into shared planner functions. +5. Introduce a runnable-step executor that initially wraps existing invocation services. +6. Route one simple topology through the planner, likely batch query or queue traces. +7. Move pending human/custom planning into the shared planner. +8. Move batch testset after repeat, cache, and runnable-executor parity are proven. +9. Move live query once windowed source resolution and idempotency are stable. +10. Collapse API-internal worker handlers to run/slice processors. +11. Share one backend source-slice processor across live query, batch query, queue slices, batch inference, and testset application evaluation. +12. Route SDK preview/local evaluation through SDK-owned source-slice processing with SDK-specific adapters. +13. Move backend execution onto the SDK source-slice contract through backend adapters that preserve current cache/result/status behavior. +14. Treat batch inference as the application-only shape of the testset application graph. +15. Retire specialized setup/execution branches after parity tests pass, leaving compatibility wrappers around the canonical processor. + +## Success Criteria + +The design succeeds when adding a new valid combination requires: + +- adding or extending a source resolver if the source is new +- adding or extending a step executor if the runnable is new +- adding planner validation if the topology is new + +It should not require creating a new end-to-end setup function and a new end-to-end execution loop. + +## Default Queue Integration + +Unified eval loops should treat default queues as a consumer-facing layer over the tensor, not as part of orchestration. + +```text +default queue = canonical persisted human-work view over the run tensor +``` + +A default queue is open over the run by default: + +```text +scenario_ids = None +step_keys = None +user_ids = None +``` + +The runtime should continue to own: + +- source resolution +- topology validation +- planning +- auto-step execution +- tensor persistence + +The queue layer should own: + +- human-work visibility +- assignment +- queue lifecycle +- simple queue interaction + +### Run flags + +The unified model should distinguish source family from queue eligibility. + +Source-family flags: + +- `has_queries` +- `has_testsets` +- `has_traces` +- `has_testcases` + +Queue eligibility flag: + +```text +run.flags.is_queue = active default queue exists + active human evaluator work exists +``` + +That keeps query-backed, testset-backed, trace-backed, and testcase-backed runs expressible through the same planner while allowing any human-bearing run with an active default queue to participate in the simple queue surface. + +### Step lifecycle + +The mutation model should distinguish lifecycle changes from destructive cleanup. + +If the product needs historical evaluator results to remain visible, then: + +- archive/deactivate should be the normal operation for steps that stop participating in future work +- remove/prune should remain available only for explicit destructive cleanup +- planner defaults should target active steps +- default queue eligibility should depend on active human steps diff --git a/docs/designs/unified-eval-loops/research.md b/docs/designs/unified-eval-loops/research.md new file mode 100644 index 0000000000..44e834b97c --- /dev/null +++ b/docs/designs/unified-eval-loops/research.md @@ -0,0 +1,344 @@ +# Research + +## Scope + +This document consolidates the existing evaluation-loop design notes and the +current implementation state from: + +- `application/docs/designs/eval-loops` +- `application/docs/designs/loops` +- `application/docs/designs/query-eval-loops` +- `application/docs/design/annotation-queue-v2` +- `application/api/oss/src/core/evaluations/types.py` +- `application/api/oss/src/core/evaluations/utils.py` +- `application/api/oss/src/core/evaluations/service.py` +- `application/api/oss/src/core/evaluations/tasks/source_slice.py` +- `application/api/oss/src/core/evaluations/tasks/query.py` +- `application/api/oss/tests/pytest/unit/evaluations/*` + +The goal is to identify the common execution model behind the current loop families and the places where setup and execution still diverge. + +## Current Loop Families + +The runtime historically had several explicit evaluation loop families: + +| Loop family | Source unit | Input steps | Application steps | Evaluator steps | Scenario represents | +|---|---|---:|---:|---:|---| +| Live query | trace returned by query | `1..N` query | `0` | `1..N` | queried trace | +| Batch query | trace returned by query | `1..N` query | `0` | `1..N` | queried trace | +| Batch testset | testcase | `1..N` testset | `1` | `1..N` | testcase | +| Batch inference / batch invocation | testcase | `1..N` testset | `1` | `0` | testcase | +| Queue traces | trace ID | `1` synthetic source | `0` | `1..N` | provided trace | +| Queue testcases | testcase ID | `1` synthetic source | `0` | `1..N` | provided testcase | +| SDK/local | runner-defined | run-defined | run-defined | run-defined | runner-defined | + +## Related Design: Annotation Queue v2 + +`application/docs/design/annotation-queue-v2` matters because annotation queues are one of the main consumers of unified evaluation loop infrastructure. + +The durable direction from that design is: + +- keep `EvaluationRun`, `EvaluationScenario`, `EvaluationResult`, and `EvaluationQueue` as backing infrastructure +- expose a simpler annotation queue API/UI that hides backing run/scenario/result setup +- do not introduce a separate annotation task runtime unless the existing entities prove insufficient +- map assignment/repeats through `EvaluationQueue.data.user_ids` and `EvaluationResult.repeat_idx` +- support trace and testset annotation as consumer-facing queue creation flows + +Some current-state claims in that older design are stale. In current code, source-aware queue creation and human/custom pending behavior are already partially implemented. The useful takeaway is the layering principle: annotation queues are a convenience layer over evaluation entities, not a separate execution model. + +The current worker dispatch only supports a subset of possible graphs: + +- `query(1..N) -> evaluator(1..N)` +- `testset(1..N) -> application(1) -> evaluator(1..N)` +- `testset(1..N) -> application(1)` +- `queue source(1) -> evaluator(1..N)` + +Unsupported by the current simple-evaluation worker dispatch, with product priority: + +| Unsupported shape | Priority | Notes | +|---|---|---| +| multiple application steps in one worker-dispatched run | not planned | A/B comparison can remain separate evaluations for the foreseeable future. | +| query inputs followed by application steps | potentially useful | The planner must treat query traces as input data, not as invocation links for the application step. If query trace IDs are placed in application `links`, the resulting application traces may be classified as annotations rather than invocations. | +| testset inputs followed directly by evaluator steps in non-queue mode | potentially useful | Useful for evaluators that can score testcase payloads without first invoking an application. Requires an explicit evaluator input contract. | +| mixed query and testset source families in one queue | not planned | Keep queues single-source-family for now. | +| live testset evaluation | not planned | Static testsets do not make sense as periodic live sources. | + +## Shared Runtime Model + +All loop families can be described with the same conceptual entities: + +- input source descriptors +- materialized scenarios +- executable steps +- result cells +- repeat slots +- execution flags + +The intended result identity is already visible in the persistence model: + +```text +scenario_id + step_key + repeat_idx +``` + +That identity is the core tensor coordinate. A unified loop should treat every execution as filling, probing, or pruning cells in this coordinate system. + +The current code already models this directly: + +- `EvaluationResult` has `scenario_id`, `step_key`, and `repeat_idx`. +- `EvaluationResultQuery` can filter by `scenario_ids`, `step_keys`, and `repeat_idxs`. +- worker loops now create repeat-indexed result rows in the main batch, queue, and live paths. + +## Steps + +Current run data already carries step definitions with: + +- `key` +- `type` +- `origin` +- `references` +- optional input links + +The shared step types are: + +| Step type | Meaning | Typical references | +|---|---|---| +| `input` | Source materialization | query revision, testset revision, direct trace/testcase source | +| `invocation` | Application/workflow execution | application revision, variant, workflow revision | +| `annotation` | Evaluator/judge/manual annotation | evaluator revision, annotation task | + +The shared origins are: + +| Origin | Populated by | Execution behavior | +|---|---|---| +| `auto` | Backend/SDK runner | invoked by `process` | +| `human` | UI/user annotation | runner creates or leaves pending work | +| `custom` | External/programmatic actor | runner creates or leaves pending work | + +Backend types use `auto`, `human`, and `custom`. Any frontend or generated-client naming drift should be treated as compatibility debt and verified before changing. + +## Input Sources + +The current code distinguishes source descriptors from concrete execution items. + +| Source descriptor | Concrete item | Current usage | +|---|---|---| +| query revision | trace | live query, batch query | +| testset revision | testcase | batch testset, batch inference / batch invocation | +| direct trace IDs | trace | queue traces | +| direct testcase IDs | testcase | queue testcases | + +Source-aware queue creation has been partially implemented. `SimpleQueuesService.create()` can accept query/testset-backed queue sources, builds run data through `_make_evaluation_run_data()`, preserves source revision references in input steps, and dispatches concrete trace/testcase batches through `_dispatch_source_batches()`. Direct trace/testcase queue additions remain supported. + +Annotation Queue v2 frames this as a consumer-facing convenience layer: users should be able to create annotation queues from traces or testsets without manually constructing the backing evaluation run, scenarios, results, and queue. + +## Scenario Semantics + +A scenario is a concrete source item inside a run. + +Depending on the source family, a scenario may represent: + +- a trace returned by a query +- a testcase from a testset revision +- a direct trace queue item +- a direct testcase queue item + +Live query scenarios additionally need temporal metadata such as timestamp and interval. Testset-backed online evaluation is intentionally unsupported because the same static testcases would be reprocessed every interval. + +## Application And Evaluator Boundaries + +Application steps produce application traces and outputs. Evaluator steps consume either: + +- an existing source trace from query/queue trace inputs +- an application trace/output from an invocation step +- testcase payload where the evaluator supports testcase-only input + +The current loop families differ mostly in which upstream object exists before evaluator execution. + +| Shape | Evaluator input | +|---|---| +| query -> evaluator | source trace | +| queue trace -> evaluator | source trace | +| testset -> application | no evaluator; output is application trace/result | +| testset -> application -> evaluator | application trace and outputs | +| queue testcase -> evaluator | testcase item | + +This difference is real and should be modeled as planning data, not hidden in separate handwritten loops. + +## Repeats And Fan-Out + +Older docs used two naming schemes for the same underlying concern: + +| Older eval-loop name | Current code name | Meaning | +|---|---|---| +| `repeat_target = "application"` | `is_split = true` | fan out at the application step | +| `repeat_target = "evaluator"` | `is_split = false` | fan out at evaluator steps | +| `reuse_traces` | `is_cached` | enable hash-based trace reuse | + +The current backend model uses `is_cached`, `is_split`, and `repeats`. `EvaluationRunFlags` contains `is_cached` and `is_split`; `EvaluationRunData.repeats` defaults to `1`. + +The worker loops now expand repeat slots in the core paths inspected: + +- batch testset creates input, invocation, and evaluator results per `repeat_idx` +- batch inference / batch invocation creates input and invocation results per `repeat_idx` +- batch trace/testcase queue items create input/source and evaluator results per `repeat_idx` +- live query creates query and evaluator results per `repeat_idx` + +The remaining issue is not absence of repeat support. It is that repeat planning is still duplicated inside specialized loops rather than centralized in one planner. + +Fan-out validity depends on topology: + +| Topology | Valid fan-out | +|---|---| +| query -> evaluator | evaluator only | +| queue source -> evaluator | evaluator only | +| testset -> application -> evaluator | application or evaluator | +| testset -> application | application only; this is batch inference / batch invocation | + +## Trace Reuse + +Hash-based trace reuse is explicit through `is_cached`. + +Reuse flow: + +1. Compute a stable hash for the runnable node from canonical references and upstream links. +2. Fetch matching traces by hash at project scope. +3. Select deterministic reusable traces for the requested repeat slots. +4. Invoke only the missing slots. +5. Populate result cells with reused or newly produced trace IDs. + +The lookup is already plural in `fetch_traces_by_hash(...)`, and helper tests cover selection and missing-count behavior. Cache lookup is now wired into the inspected application and evaluator worker boundaries. The remaining issue is duplicated per-loop cache resolution logic. + +## Setup Fragmentation + +Current setup is not one universal flow. It is split across: + +- auto evaluation creation for app + variant + testset + evaluators +- human evaluation creation for testset + single variant + evaluators +- live evaluation setup for query-backed trace sampling +- queue creation from trace IDs or testcase IDs +- SDK/local setup +- annotation queue convenience setup from traces/testsets, backed by evaluation entities + +These setup paths build similar run-data concepts through `_make_evaluation_run_data()` in several paths, but they still apply different validation and dispatch rules. That is why new combinations still tend to require setup and execution changes in multiple places. + +## Execution Consolidation + +The SDK and backend now route concrete source items through the SDK-owned +source-slice processor. Backend task modules are source/dispatch shells: + +- `run.py` classifies a run and dispatches to the right source resolver +- `query.py` resolves live/batch query source traces +- `source_slice.py` resolves direct/testset source items and builds backend adapters + +The remaining backend-specific work lives behind adapters for scenario creation, +result persistence, metrics refresh, trace loading, cache reuse, and workflow +execution. +- application invocation +- evaluator invocation +- human/custom pending behavior +- repeat handling +- cache lookup +- metrics refresh + +This duplication is now the main remaining problem. Some formerly missing capabilities are implemented, but they are implemented repeatedly across specialized loops. + +## Runnable Execution Debt + +Unifying the loop should not mean preserving every current invocation helper as-is. + +The current application execution path is especially legacy. Batch testset and batch inference paths still rely on older application invocation helpers such as the LLM app service batch invocation path, which has been patched repeatedly over time. Evaluator execution uses workflow invocation paths with similar but not identical request assembly, links, reference handling, cache handling, trace fetch handling, and error handling. + +The repeated pattern is broader than "application vs evaluator": + +- build runnable request from step references, upstream bindings, inputs, trace, and outputs +- optionally compute hash and reuse existing traces +- invoke a runnable when cache does not satisfy the slot +- fetch/validate the resulting trace +- convert response or failure into an evaluation result cell + +That should become a shared runnable-step execution contract. Application and evaluator steps can then be two runnable kinds handled by the same boundary, rather than separate loop-local helper stacks. + +## Research Conclusion + +The product does not need one flattened source type, but it does need one loop contract. + +The common contract should be: + +```text +resolve sources -> materialize scenarios -> plan result slots -> execute auto steps -> leave human/custom slots pending -> populate tensor cells -> refresh metrics +``` + +Current code has many pieces of that contract, including flags, repeat helpers, cache helpers, source-aware queue dispatch, and pending human/custom behavior in key loops. The missing layer is a shared planner that owns these decisions once. + +Source-specific behavior should live in resolvers and planners. Step execution should be generic over: + +- scenario +- step +- repeat slot +- upstream bindings +- origin +- cache policy +- fan-out policy +- runnable invocation policy + +## Relationship To Default Queues + +The queue-unification work clarifies that annotation queues are not another execution runtime. They are a persisted human-work view over the same evaluation tensor described in this document. + +The useful separation is: + +```text +evaluation runtime + = graph + tensor + process(slice) + +default queue + = canonical persisted human-work view over that tensor +``` + +The queue dimensions align with tensor dimensions: + +- scenario selection maps to scenarios +- step selection maps to steps +- repeat assignment maps to scenario × repeat lanes + +A default queue leaves those dimensions open: + +```text +scenario_ids = None +step_keys = None +user_ids = None +``` + +The queue layer decides how human work is exposed and assigned. It does not decide how auto steps are planned or executed. + +### Source-family flags versus queue eligibility + +The current runtime still uses `is_queue` in places where it is really distinguishing queue-style source ingestion. The cleaner target model is to expose source family directly through inferred flags: + +- `has_queries` +- `has_testsets` +- `has_traces` +- `has_testcases` + +Those flags should describe where scenarios come from and should drive topology validation and mixed-source prevention. + +Separately, `run.flags.is_queue` should answer the product question: + +```text +active default queue exists +and active human evaluator work exists +``` + +That makes source classification and simple-queue eligibility separate facts rather than overloading one flag. + +### Shared mutation question + +The queue-unification work also exposes a question this design must answer explicitly: whether step removal is usually destructive or archival. + +If historical results should remain visible after an evaluator is no longer active, then the graph model needs an active-versus-archived distinction. In that world: + +- `archive_step` is the common lifecycle operation +- hard `remove_step` and `prune` are stronger cleanup operations +- queue eligibility should depend on active human steps, not merely historical human steps + +This needs to be settled before hardening the mutation contract around remove/prune behavior. diff --git a/docs/designs/unified-eval-loops/step-removal-semantics.md b/docs/designs/unified-eval-loops/step-removal-semantics.md new file mode 100644 index 0000000000..585e520f62 --- /dev/null +++ b/docs/designs/unified-eval-loops/step-removal-semantics.md @@ -0,0 +1,448 @@ +# Step Removal Semantics + +## Decision + +For now, evaluation step removal is **destructive**: + +```text +remove_step -> prune the removed step's tensor cells +``` + +Removing a step means: + +1. remove it from the active run graph +2. delete result cells for that step across scenarios and repeats +3. refresh/flush metrics that depended on that step +4. if the removed step is an input step, also remove scenarios that are sourced only from that step + +This keeps the stored graph and stored tensor aligned with the current evaluation definition. + +The alternative — archiving/deactivating steps while retaining historical cells — remains a valid future model, but it is **not** the model chosen for the current design. + +## Why This Decision Exists + +There are two coherent models for step lifecycle. + +### Model A — Destructive removal + +```text +stored graph = current active graph +stored tensor = cells for the current active graph +``` + +A removed step no longer exists in the graph, and its cells are pruned. + +### Model B — Archival lifecycle + +```text +stored graph = historical graph +active execution = projection over active steps +stored tensor = historical cells, including archived steps +``` + +An archived step remains historically present, but no longer participates in future work. + +Both models are internally coherent. The current design chooses **Model A** because it is simpler, cleaner, and matches the existing unified-loop operation model. + +## Existing Design Rationale For Remove + Prune + +The existing eval-loop documents already leaned toward destructive removal for good reasons. + +### Steps are immutable by reference + +A step points to a concrete referenced revision. Changing a reference should not mutate the step in place. + +Instead: + +```text +change evaluator revision = remove old step + add new step +``` + +That preserves step identity semantics and avoids silently rewriting what a historical step meant. + +### The graph defines tensor shape + +The design treats graph steps as tensor dimensions: + +- add a step -> add a tensor column dimension +- remove a step -> remove that tensor column's cells + +This creates a simple invariant: + +```text +current graph and current tensor have the same shape +``` + +### Remove + prune prevents stale state + +If a step disappears but its result cells remain: + +- cells exist for steps no longer in the graph +- metrics may still refer to retired steps +- UI needs to distinguish active from historical columns +- planner and topology logic need lifecycle-aware filtering + +Pruning avoids all of that in the default path. + +### The mutation model stays symmetric + +The lower-level operation model remains clean: + +```text +graph: add_step / remove_step +tensor: populate / prune +``` + +That symmetry is useful for reasoning, implementation, and testing. + +## Why Archival Was Considered + +Archival has one major product advantage: + +> it preserves auditability. + +If a human evaluator or automatic evaluator is no longer active, retaining the old step and its cells would preserve: + +- who evaluated what +- what outputs existed before the step was retired +- historical metric context +- a full explanation of past evaluation state + +That is especially attractive if evaluations are treated as long-lived collaborative records rather than disposable execution definitions. + +## Cost Of Destructive Removal + +The chosen model deliberately gives up some history. + +When a step is removed: + +- its result cells are deleted +- metrics derived from it disappear from the active run +- prior human work for that step is no longer represented in the run tensor +- the run no longer explains that the step ever existed + +If auditability becomes a product requirement later, destructive removal will not satisfy it by itself. + +## Cost Of Archival + +Archival avoids data loss, but it has broad implications across every layer of the system. + +The rest of this document records those implications so the tradeoff remains explicit. + +# Archival Implications + +## 1. Model implications + +Archival requires step lifecycle state, for example: + +```python +archived_at: datetime | None +archived_by_id: UUID | None +``` + +A run would then contain two conceptual graphs: + +```text +historical graph = all steps ever attached to the run +active graph = historical graph minus archived steps +``` + +Any presence-style flags would need explicit semantics: + +- `has_evaluators` +- `has_human` +- `has_auto` +- `has_custom` + +For most product behavior, they would likely need to mean **active presence**, not historical presence. + +If historical presence also matters, that would require separate query behavior or additional flags. + +## 2. Data implications + +Archived steps retain their tensor cells: + +```text +scenario_id + step_key + repeat_idx +``` + +That preserves history, but results now divide into: + +- active-step results +- archived-step results + +Queries and APIs would need to decide whether they default to: + +- active-only results +- all historical results +- or support explicit `include_archived_steps` + +If archived steps remain embedded in JSON run data, active/historical filtering is service-derived and less relationally natural. If steps become first-class rows, lifecycle handling becomes cleaner but requires a larger schema refactor. + +Archival also increases retained data volume over time because old cells remain instead of being pruned. + +## 3. Metrics implications + +Archival makes metric meaning more complex. + +At minimum, the system would need to distinguish: + +### Active metrics + +Metrics over the current active graph, used for: + +- current dashboards +- current summary views +- present-tense evaluation interpretation + +### Historical metrics + +Metrics including archived steps, used for: + +- audit +- history +- lineage + +Without that distinction, archived evaluators would continue to affect current dashboards. + +Metric refresh would need to know whether it is computing over active steps only or over historical steps as well. Run mappings may also need lifecycle awareness so archived step mappings do not keep contributing to current aggregates. + +## 4. Compute and planner implications + +The planner would need to operate on **active steps only** by default. + +Every execution path would need a shared helper such as: + +```python +active_steps(run) +has_active_human_steps(run) +``` + +If archived steps remained in `run.data.steps`, raw iteration over `run.data.steps` would become unsafe. + +`process(slice)` would need explicit semantics: + +- `steps="all"` likely means all **active** steps +- archived steps require explicit inclusion for any historical replay or audit operation + +Planner complexity would remain manageable if active filtering is centralized, but every planner, topology classifier, queue reconciler, and flag refresher would need to use the same lifecycle-aware projection. + +## 5. Queue implications + +Default queue eligibility would need to depend on **active** human steps: + +```text +active default queue exists +and active human evaluator work exists +``` + +For default queues: + +```text +step_keys=None +``` + +would need to mean all **active** queue-relevant steps, not all historical steps. + +Custom queues that explicitly reference later-archived steps would need a policy, such as: + +- retain the queue row +- stop generating active work for archived steps +- surface that the queue references inactive steps +- perhaps mark the queue degraded/inactive if all included steps are archived + +## 6. API implications + +Archival would require new lifecycle operations: + +- `archive_step` +- `unarchive_step` + +or equivalent run-mutation semantics. + +Any response that exposes step definitions would need archival metadata so clients can distinguish active from historical steps. + +The API would also need explicit lifecycle-aware query semantics, likely including some form of: + +- active-only default behavior +- optional archived inclusion for audit/history views + +Backward compatibility becomes non-trivial because older clients may assume every returned step is active. + +## 7. UI implications + +The UI would need an explicit active-versus-archived presentation model. + +Likely implications: + +- active steps shown normally +- archived steps grouped under a collapsed historical section +- current results tables show active columns by default +- archived result columns appear only in audit/history contexts or behind an explicit toggle +- current metric charts exclude archived steps by default +- historical metric views expose archived-step data intentionally +- queue screens show only active human work +- archived human work remains visible in evaluation history but not as new actionable queue work + +Action labels would also need to change: + +- ordinary user action: `Archive evaluator` +- stronger destructive action: `Delete step and results` + +Without UI support for archived state, archival would preserve data technically but create user confusion. + +## 8. Controller and service implications + +Archival requires centralized lifecycle orchestration. + +A step archive/unarchive transition would need to coordinate: + +- active graph projection +- run flag recomputation +- queue reconciliation +- metric refresh +- possible custom-queue invalidation/degradation + +Those changes should not be scattered across ad hoc call sites. They require one authoritative lifecycle path. + +## 9. Conceptual implication + +Archival changes the core invariant from: + +```text +stored tensor = current graph +``` + +to: + +```text +stored tensor = historical graph +active execution = projection over active graph +``` + +That is a richer but more expensive model. + +# Destructive Removal Implications + +## 1. Model implications + +No additional step lifecycle state is required. + +The run graph remains: + +```text +run.data.steps = active graph +``` + +Presence flags continue to reflect the graph directly. + +## 2. Data implications + +Removed step cells are deleted. + +This avoids: + +- stale cells +- historical-vs-active result interpretation +- extra retained data volume for removed steps + +But it sacrifices historical traceability inside the run. + +## 3. Metrics implications + +Metric handling stays simple: + +- prune step cells +- refresh/flush dependent metrics +- current metrics remain aligned with the current graph + +No separate active/historical metric families are required. + +## 4. Compute and planner implications + +Planner logic remains simpler: + +- every step in the graph is active +- `steps="all"` means literally every stored step +- topology validation does not need step lifecycle filtering + +## 5. Queue implications + +Queue eligibility can be computed from the current graph without active/historical distinction. + +A removed human step no longer contributes to queue eligibility because it no longer exists. + +## 6. API implications + +Only destructive graph operations are needed: + +- `add_step` +- `remove_step` + +No step archive/unarchive surface is required. + +## 7. UI implications + +The UI stays much simpler: + +- no archived-step sections +- no archived-result toggles +- no historical metric mode +- “remove” means the thing is gone + +The downside is that users cannot inspect retired step history through the run afterward. + +## 8. Controller implications + +Mutation side effects remain narrow: + +- remove step +- prune cells +- refresh metrics +- if needed, reconcile queue flags from the new active graph + +No long-lived archival state needs to remain synchronized. + +# Comparison + +| Concern | Destructive remove + prune | Archive/deactivate | +|---|---|---| +| Auditability | weak | strong | +| Current-state simplicity | strong | weaker | +| Storage growth | lower | higher | +| Planner complexity | lower | higher | +| Metric semantics | simple | active vs historical required | +| UI complexity | lower | higher | +| Queue semantics | simpler | must ignore archived steps | +| API lifecycle surface | smaller | larger | +| Graph/tensor invariant | identical current graph/tensor | historical storage + active projection | + +# Current Choice + +The current unified-eval-loop design chooses: + +```text +remove + prune +``` + +as the normal behavior. + +This is intentionally destructive, and the tradeoff is accepted for now because it provides: + +- a clean graph/tensor invariant +- simpler planning and topology logic +- simpler metrics +- simpler UI/API behavior +- direct alignment with the existing operation model + +If auditability becomes a product requirement later, the design should be revisited explicitly rather than approximated halfway. A future archival model would need full support across: + +- step lifecycle metadata +- active/historical result semantics +- metric semantics +- queue eligibility +- APIs +- UI +- planner defaults + +Until then, retaining removed-step cells without modeling archival everywhere is not acceptable because it would introduce ambiguity without delivering coherent auditability. diff --git a/docs/designs/unify-evals-and-queues/gap.md b/docs/designs/unify-evals-and-queues/gap.md new file mode 100644 index 0000000000..81b0efa296 --- /dev/null +++ b/docs/designs/unify-evals-and-queues/gap.md @@ -0,0 +1,92 @@ +# Gap Analysis + +## Queue Semantics + +Missing from current state: + +- no explicit meaning that `step_keys=None` is the open/default step scope +- current auto-created human queues snapshot human step keys instead of leaving step scope open +- no canonical default-queue marker distinct from arbitrary custom queues + +Already present: + +- `scenario_ids=None` already leaves scenario scope open over the run +- `user_ids=None` already means unassigned +- repeats are already run-owned rather than queue-owned + +## Default Queue Lifecycle + +Missing from current state: + +- no default-queue reconciliation tied to run creation/editing +- current helper is path-dependent and only reached from selected execution flows +- no two-policy model separating: + - human-step structural condition + - unconditional default-queue global setting +- no logic to archive/unarchive the default queue as human evaluator availability changes + +## Queue Archival + +Missing from current state: + +- no queue archive endpoint +- no queue unarchive endpoint +- no queue service/DAO archive lifecycle path +- no `include_archived` support on queue query/fetch surfaces +- default-queue lookup cannot currently search archived queues for restoration + +Present but underused: + +- queue DTOs already inherit lifecycle fields such as `deleted_at` and `deleted_by_id` + +## Queue Identity + +Missing from current state: + +- no reliable way to distinguish the canonical default queue from a custom queue with the same open shape +- current ensure logic stops if any queue exists for the run, which is insufficient once default and custom queues coexist + +## Run Semantics + +Needs clarification or adjustment: + +- `is_queue` currently distinguishes simple queue-created runs from simple evaluations +- linked default queues should not require ordinary evaluation runs to become queue-ingest runs +- the old meaning of `is_queue` must be replaced by persisted simple-queue eligibility + +## Configuration + +Missing from current state: + +- no global policy toggle for unconditional default queues +- no shared policy helper for deciding default-queue lifecycle mode + +## Tests + +Missing from current state: + +- open default queue behavior with `step_keys=None` +- default queue creation for simple evaluations under unconditional mode +- conditional creation when human evaluator steps exist +- no creation / archive when conditional mode has no active human evaluator steps +- unarchive of an existing archived default queue instead of duplicate creation +- coexistence of default and custom queues +- archived-inclusive queue query behavior +- regression tests for existing queue assignment and scenario selection behavior + +## API Surface + +Missing from current state: + +- archive/unarchive queue endpoints +- `include_archived` request/query support for queues +- response behavior that lets callers distinguish active from archived queues where relevant + +## UI Surface + +Not covered by this backend design: + +- whether auto-only evaluations show an empty queue +- whether users are nudged to add human evaluators +- how the default queue appears inside evaluation details versus Queues +- any migration of frontend terminology from “human evaluation” to “evaluation with human evaluators” diff --git a/docs/designs/unify-evals-and-queues/plan.md b/docs/designs/unify-evals-and-queues/plan.md new file mode 100644 index 0000000000..2f62a34e2c --- /dev/null +++ b/docs/designs/unify-evals-and-queues/plan.md @@ -0,0 +1,16 @@ +# Plan + +1. Define the canonical default queue shape with open filters: `scenario_ids=None`, `step_keys=None`, `user_ids=None`, and no default batching restrictions. +2. Add a durable default-queue identifier, preferably an explicit queue flag or role that distinguishes default queues from custom queues independently of shape. +3. Add queue archival support across DTOs, service methods, DAO methods, and API endpoints, including archive and unarchive operations. +4. Extend queue query/fetch paths with `include_archived` support and ensure archived default queues can be found during reconciliation. +5. Add global policy toggle for unconditional default queues, e.g. `EVALUATIONS_DEFAULT_QUEUES_FOR_ALL_RUNS`, as a module-level global. +6. Implement shared policy helpers for: + - whether a run has active human evaluator steps + - whether default queues are unconditional for all runs +7. Replace the current path-specific human-queue helper with a default-queue reconciliation operation that can create, unarchive, no-op, or archive according to the two-policy model. +8. Invoke default-queue reconciliation from simple evaluation run creation and run-editing flows so queue lifecycle follows evaluation lifecycle rather than dispatch timing. +9. Use source-family flags for ingestion semantics; persist `is_queue` as active default queue + active human evaluator work. +10. Update simple queue/default queue creation paths so the default queue leaves `step_keys` open instead of snapshotting human step keys. +11. Add backend tests for unconditional mode, conditional mode, archive/unarchive behavior, coexistence with custom queues, open step scope, and existing queue regressions. +12. Update design/API documentation to describe the default queue model, the two policies, queue archival semantics, and the frontend decisions intentionally left outside this backend work. diff --git a/docs/designs/unify-evals-and-queues/proposal.md b/docs/designs/unify-evals-and-queues/proposal.md new file mode 100644 index 0000000000..63666386a1 --- /dev/null +++ b/docs/designs/unify-evals-and-queues/proposal.md @@ -0,0 +1,144 @@ +# Proposal + +## Goal + +Unify human evaluation and annotation queues at the backend model level by making the queue a default companion of evaluation runs rather than a separately created product concept. + +This proposal covers API and service semantics only. Frontend behavior, copy, and product nudges can vary later without requiring a different backend model. + +## Proposed Model + +Keep the current evaluation substrate: + +- runs define evaluation structure and repeats +- scenarios are concrete work items +- results are step × repeat outputs +- queues overlay a run to expose and distribute human work + +Add one canonical **default queue** concept for evaluation runs. + +A default queue has: + +- `scenario_ids=None` +- `step_keys=None` +- `user_ids=None` +- no queue-specific batching restriction + +Those open fields mean: + +- all scenarios in the run are eligible +- all queue-relevant steps are eligible +- no users are assigned by default +- run repeats remain fully covered because repeats belong to the run + +## Queue Axes + +The queue has three independent axes: + +| Axis | Governs | +|---|---| +| scenario selection | which scenarios belong to the queue | +| repeat assignment | which scenario × repeat lanes a user receives | +| step selection | which steps must be completed for each assigned scenario × repeat | + +Default queues leave all three axes open except for the run boundary itself. + +## Default Queue Policies + +### Structural policy + +The structural condition is simple: + +```text +has_human_evaluator_steps(run) +``` + +This says whether a run warrants a default queue when queues are conditional. + +### Global lifecycle policy + +A configuration value controls whether default queues exist for all runs regardless of human steps, for example: + +```text +EVALUATIONS_DEFAULT_QUEUES_FOR_ALL_RUNS +``` + +When enabled: + +- every run gets a default queue at creation +- the default queue is never archived merely because no active human evaluator steps remain + +When disabled: + +- a default queue exists only while active human evaluator steps exist +- adding/restoring human evaluator work creates or unarchives the default queue +- removing/archiving the last active human evaluator archives the default queue + +## Default Queue Lifecycle + +Default queue reconciliation should use durable identity: + +- missing + required -> create +- archived + required -> unarchive +- active + required -> no-op +- active + not required -> archive + +Default queues should not be hard-deleted as part of normal reconciliation. + +## Queue Lifecycle Support + +Queues should gain product-level soft-delete support: + +- archive endpoint/service/DAO path +- unarchive endpoint/service/DAO path +- `include_archived` query support +- archived-inclusive lookup for default-queue reconciliation + +Hard delete may remain available for existing low-level semantics, but default-queue lifecycle should use archive/unarchive. + +## Canonical Queue Identity + +The system needs a reliable way to identify the default queue independently of shape. A custom queue may coincidentally have no scenario filter, no step filter, and no assignments. + +The proposal requires one of: + +- an explicit queue role/flag such as `is_default` +- or another canonical linkage that uniquely identifies the default queue for a run + +An explicit marker is the clearer fit. + +## Service Placement + +Default-queue reconciliation belongs with run creation and run mutation, not only dispatch flows. + +The current `_ensure_human_annotation_queue(...)` seam should evolve into a more general lifecycle operation such as: + +```text +reconcile_default_queue(run) +``` + +It should evaluate the global lifecycle policy and, when needed, the structural human-step policy. + +## Compatibility + +This proposal preserves: + +- existing evaluation-run primitives +- existing custom queue behavior +- existing queue-backed execution paths +- hard-delete support where still needed + +It changes the default composition: + +- default queue existence becomes managed by run lifecycle +- open `step_keys` become a supported queue shape rather than a snapshot omission +- simple evaluations can participate in the same queue model as simple queues + +## Product Boundary + +The backend supports both product postures: + +- default queues for every evaluation, including auto-only evaluations +- default queues only when human evaluator work exists + +The frontend can later decide whether to expose empty queues, nudge users toward adding human evaluators, or hide queues until human work appears. The API does not need to change again for that choice. diff --git a/docs/designs/unify-evals-and-queues/research.md b/docs/designs/unify-evals-and-queues/research.md new file mode 100644 index 0000000000..fcfbc2cf36 --- /dev/null +++ b/docs/designs/unify-evals-and-queues/research.md @@ -0,0 +1,500 @@ +# Research: Unifying Evaluations and Queues + +## Scope + +This note maps the evaluation/queue model that exists today and answers one concrete exploration question: + +> What would it mean, in the current architecture, for a regular evaluation to always have a default linked queue that behaves like a simple queue when human evaluators are present? + +The focus is backend behavior in: + +- `api/oss/src/core/evaluations/*` +- `api/oss/src/apis/fastapi/evaluations/*` +- `api/oss/src/dbs/postgres/evaluations/*` +- the neighboring annotations layer where it clarifies the boundary + +## Executive Summary + +The system already has a single low-level evaluation substrate: + +- an **evaluation run** defines the workflow graph and repeats +- **scenarios** are concrete work items within a run +- **results** are per-step, per-repeat outputs for a scenario +- **metrics** summarize results +- an **evaluation queue** is an overlay over a run that selects which scenarios and annotation steps are visible to which users + +The split users see today is mostly created by wrapper layers: + +- `SimpleEvaluationsService` wraps runs without queue-centric defaults +- `SimpleQueuesService` wraps runs plus a queue with queue-centric defaults + +That means the proposed product direction is structurally plausible: it does **not** require inventing a new primitive. It mostly requires deciding what the canonical/default queue attached to a run means and when it should be created or updated. + +The codebase is also already partway toward that direction: + +- `EvaluationsService._ensure_human_annotation_queue(...)` creates a queue for a run with human annotation steps when none exists. +- Human-bearing live runs call this during refresh. +- Queue-backed batch dispatch paths call it before processing traces/testcases. + +However, that helper currently creates a **narrow, snapshot-style** queue: + +- only when human steps exist +- only if the run has no queue at all +- with `step_keys` captured from the run at that moment +- with no assignments +- with no explicit scenario restriction + +It is a useful seam, but not yet the full default-queue model. + +The sharper target model is simpler than the current helper: + +- `scenario_ids=None` means all scenarios in the run +- `step_keys=None` means all steps included by the queue policy +- `user_ids=None` means unassigned +- repeats remain owned by the run, while assignments distribute scenario × repeat work + +## Current Domain Model + +### 1. Evaluation runs are the canonical execution object + +`EvaluationRun` stores the durable definition of an evaluation: + +- `data.steps`: input, invocation, and annotation steps +- `data.repeats`: repeat count for the run +- `data.mappings`: metric/result extraction mappings +- flags such as `is_live`, `is_queue`, `has_human`, `has_auto`, `has_testsets`, `has_queries` + +A run is therefore already capable of representing: + +- automatic-only evaluations +- human-only evaluations +- mixed human + automatic evaluations +- queue-backed and non-queue-backed flows + +The evaluator origin is not a separate resource type. It is encoded on annotation steps as `origin in {custom, human, auto}`. + +### 2. Queues are overlays over runs, not separate executions + +`EvaluationQueue` points to a `run_id` and stores queue-specific selection/distribution state in `EvaluationQueueData`: + +- `user_ids: List[List[UUID]] | None` +- `scenario_ids: List[UUID] | None` +- `step_keys: List[str] | None` +- optional batching controls + +The queue does **not** own scenarios or results. It derives visible scenarios from the underlying run and optionally filters them. + +This is important for the proposed default queue: + +- if `scenario_ids is None`, the queue automatically covers **all current scenarios in the run** +- if `user_ids is None`, the queue is effectively **unassigned** +- if the queue has no user filter, the scenario query path returns the run scenarios directly + +So the desired “default queue that follows future scenarios” already matches existing semantics **if** we leave `scenario_ids=None`. + +### 3. Scenario assignment is derived, not persisted per scenario + +Assignment behavior is computed from queue data at read time: + +- no `user_ids` -> everyone sees the run’s scenarios +- with `user_ids` -> `filter_scenario_ids(...)` deterministically partitions scenarios per repeat/user lane +- sequential vs randomized distribution is controlled by queue flags/settings + +That means the queue model already supports: + +- no assignees +- assignees per repeat lane +- repeated review lanes +- deterministic re-computation as scenarios are added later + +The subtle point is that **repeats live on the run**, while **assignment lanes live on the queue**. `SimpleQueuesService.create(...)` currently reconciles the two by setting run repeats to at least the number of assignment lanes. + +## Current Public/Service Surfaces + +### `SimpleEvaluationsService` + +This is a convenience wrapper over runs. It builds run steps from query/testset/application/evaluator revision IDs and exposes CRUD/lifecycle operations as “simple evaluations.” + +Notably: + +- evaluator inputs can be lists or explicit origin maps +- run flags are inferred from step origins +- it is run-first, not queue-first + +### `SimpleQueuesService` + +This is a different convenience wrapper over the same substrate. It: + +1. builds or reuses run data +2. creates a run with `is_queue=True` +3. creates one linked `EvaluationQueue` +4. stores queue-specific behavior such as assignments, step keys, and batching + +It is effectively a preset constructor for “evaluation run + annotation queue.” + +### Low-level evaluation endpoints + +The main evaluations API exposes separate resources for: + +- runs +- scenarios +- results +- metrics +- queues + +This exposes the true underlying shape more directly than either simple wrapper. + +### Annotations + +The annotations module is adjacent but distinct. It creates/edit annotations as trace-linked artifacts and may provision evaluators, but it is not the queue abstraction itself. The queue system is still implemented in evaluations. + +## How Simple Queues Work Today + +A simple queue is not a separate backend domain. It is a prescribed composition: + +1. Create an evaluation run whose input is either: + - direct traces/testcases, or + - source-backed queries/testsets +2. Add evaluator annotation steps. +3. Create one queue against that run. +4. Store only the annotation `step_keys` in the queue. +5. Optionally store assignments and batching settings. + +The queue then queries scenarios by: + +- starting from scenarios belonging to the run +- optionally applying `queue.data.scenario_ids` +- optionally applying user/repeat distribution + +That is why a queue with: + +- `scenario_ids=None` +- `step_keys=None` +- `user_ids=None` + +is the natural shape of the default queue. + +These are three independent axes: + +- scenario selection decides which scenarios are in the queue +- repeat assignment decides which scenario × repeat lanes a user gets +- step selection decides which steps must be completed for each assigned scenario × repeat + +`step_keys` do not participate in scenario or repeat selection, and they do not need to. Leaving them open is still the correct queue-level analogue of leaving `scenario_ids` open. + +## The Existing Proto-Unification Seam + +`EvaluationsService._ensure_human_annotation_queue(...)` currently does this: + +1. inspect run steps +2. collect human annotation step keys +3. if there are no human steps, do nothing +4. if any queue already exists for the run, do nothing +5. otherwise create an `EvaluationQueue` with: + - `run_id=run.id` + - `status=RUNNING` + - `step_keys=` + - no assignments + - no explicit scenario IDs + +This already gives the queue open scenario coverage and no default assignments, but it freezes step membership instead of leaving the queue open over the run’s steps. + +Today it is invoked from: + +- live run refresh before dispatch +- queue-backed trace/testcase batch evaluation dispatch + +That tells us two things: + +1. The architecture already treats queues as a natural companion to human annotation work. +2. The current behavior is still opportunistic and path-dependent, not a universal invariant of evaluation creation/editing. + +## What the Desired Default Queue Maps To in Current Terms + +| Desired behavior | Current primitive that already supports it | +|---|---| +| Queue linked to an evaluation | `EvaluationQueue.run_id` | +| No scenario selection; include all current/future scenarios | `queue.data.scenario_ids = None` | +| No assigned users by default | `queue.data.user_ids = None` | +| Cover all repeats | run-level `data.repeats`; queue assignment lanes can be absent | +| Step scope is not frozen | `queue.data.step_keys = None` | +| New scenarios added later become visible | queue scenario lookup derives from `run_id`, not a frozen list | + +So the cleanest first interpretation of a **default queue** is: + +```text +one canonical queue per evaluation run, +with no scenario restriction, +no step-key restriction, +and no assignees. +``` + +## Default Queue Policy + +The target model has two separate policies. + +### Structural policy + +This is the run-level condition: + +```text +has_human_evaluator_steps(run) +``` + +When default queues are conditional, this decides whether a run should currently have one. + +### Global lifecycle policy + +This is a configuration choice, for example: + +```text +EVALUATIONS_DEFAULT_QUEUES_FOR_ALL_RUNS +``` + +When enabled: + +- create a default queue for every run +- never archive it merely because the run has no active human evaluators + +When disabled: + +- create or unarchive the default queue when the run has human evaluator steps +- archive the default queue when the run has no active human evaluator steps + +These policies are related but not interchangeable. The global setting defines whether default queues are unconditional. The structural rule only governs lifecycle when default queues are conditional. + +## Default Queue Lifecycle + +The desired queue identity is durable: + +- if the default queue does not exist and policy requires one, create it +- if it exists and is archived, unarchive it +- if it exists and is active, leave it alone +- if policy no longer requires it, archive it rather than hard-delete it + +This fits the broader evaluator model if evaluators are archived rather than removed. A queue can disappear from normal views while retaining identity and later return if human evaluator work becomes active again. + +The current queue API is not yet aligned with that lifecycle: + +- queues have lifecycle fields +- queue endpoints currently expose hard deletion +- queue queries do not yet expose `include_archived` +- queue lookup does not currently distinguish active from archived queues + +If default queues become durable linked objects, queue archive/unarchive operations and archived-aware lookup become part of the needed foundation. + +## Remaining Model Gaps + +### 1. `step_keys` are currently stored as a snapshot + +The current helper captures the human step keys that exist at creation time. The default queue should instead leave step scope open with `step_keys=None`, so later step changes do not require queue rewrites. + +### 2. There is no first-class distinction between default and custom queues + +Currently, a queue is just a queue. `_ensure_human_annotation_queue(...)` only checks whether **any** queue exists for the run. + +That creates ambiguity: + +- if a custom filtered queue exists, should it suppress creation of the evaluation’s default queue? +- if multiple queues exist, which one is the queue shown inside the evaluation? +- which archived queue should be restored when the default-queue invariant becomes true again? + +A stable default-queue marker or equivalent canonical linkage is needed once default queues and custom queues can coexist. + +### 3. `is_queue` still encodes a product distinction on the run + +Simple queues create runs with `is_queue=True`; simple evaluations generally do not. This flag is used for querying and queue-specific dispatch guards. + +If ordinary evaluations can have linked queues, `is_queue` should remain a technical execution flag rather than the signal that a run has a queue companion. + +### 4. Queue lifecycle is currently path-dependent + +Today automatic queue creation is reached from execution paths, not from the run mutation paths that define whether human work exists. + +Default-queue reconciliation belongs next to run creation/editing so it can enforce either: + +- unconditional queue existence, or +- conditional existence based on active human evaluator steps. + +## Multiple Human Evaluators, Assignments, and Repeats + +### Multiple human evaluators + +The run model supports many human evaluator steps already. A queue can target multiple annotation steps through `step_keys`. + +So there is no fundamental blocker to one default queue covering multiple human evaluators. The real question is product semantics: + +- should one task card represent one scenario with multiple human fields? +- or one scenario × evaluator step as separate queue work? + +The current queue primitive points at multiple step keys but scenario listing is scenario-oriented, not step-oriented. That suggests today’s model is closer to “one scenario can carry several annotation steps” than “each evaluator creates a separate queue item.” + +### Assignments + +The queue model already supports repeat-lane assignments: + +```text +user_ids = [[repeat_0 users], [repeat_1 users], ...] +``` + +A default queue with `user_ids=None` naturally means “unassigned.” + +What still needs product clarification is how assignment should behave when: + +- the evaluation repeat count increases later +- a human evaluator is added after assignments exist +- different human evaluators should have different assignee pools + +The current queue data model has one assignment matrix for the whole queue, not per evaluator step. If evaluators need separate assignment rules, one shared default queue may be insufficient or the model must evolve. + +### Repeats + +Repeats are owned by `EvaluationRunData.repeats`, not by the queue. Simple queue creation enforces: + +```text +run.repeats >= number of assignment lanes +``` + +That is compatible with a default queue that “covers all existing repeats,” provided the queue derives from the run rather than freezing repeat-specific scope. + +The open design question is what happens if repeats are later edited downward/upward after a queue already has assignments. The current storage model permits temporary mismatch. + +## Likely Design Direction + +### Recommended conceptual model + +Treat the queue as a **view/controller over human annotation work for an evaluation run**, not as a sibling product object that users must create manually. + +A practical backend direction would be: + +1. Every evaluation run may have one **canonical/default queue**. +2. The default queue is created automatically when the run first contains human annotation steps. +3. The default queue has: + - no scenario filter + - no assignments + - derived human-step membership, or managed synchronization +4. Additional custom queues may still exist for advanced filtered/assigned workflows. +5. The evaluation API should expose the canonical queue link directly so the UI can render the same queue both inside the evaluation and on the Queues surface. + +### Why this fits the current code well + +It reuses what already exists: + +- run/scenario/result storage +- evaluation queue storage +- scenario derivation by `run_id` +- assignment logic by repeat lane +- the already-present `_ensure_human_annotation_queue(...)` seam + +The largest required addition is not storage volume; it is **semantics**: + +- how to mark the canonical queue +- how default queue step membership stays current +- where invariant enforcement lives + +## Concrete Gaps to Resolve Before Implementation + +### Product/behavior questions + +1. Is the default queue only for **human** steps, or for all annotation steps? + - The current helper chooses human-only. + - The product statement sounds human-focused. + +2. With multiple human evaluators, is assignment shared across them or evaluator-specific? + - Current queue data supports shared assignment only. + +3. When a run already has a custom queue, should the default queue still exist? + - Current helper says no because it stops if *any* queue exists. + +4. If human evaluators are removed, should the default queue remain, archive, or disappear? + +5. Should all evaluation runs have a default queue immediately, or only runs with human work? + - The latter better matches current semantics and avoids empty queues for automatic-only runs. + +### Technical questions + +1. Should default queue membership be derived with `step_keys=None`, or synchronized explicitly? +2. How is “default queue” represented? +3. Should queue creation/update happen in service-layer run mutation methods rather than dispatch flows? +4. Does `is_queue` remain meaningful once ordinary evaluations can have linked queues? +5. What migration/backfill is needed for existing runs with human steps but no queue? + +## Candidate Implementation Shapes + +### Option A — Minimal evolution + +Keep explicit `step_keys`, add a default queue marker, and update `_ensure_human_annotation_queue(...)` plus run-edit paths to keep the default queue synced. + +**Pros** + +- smallest delta from current code +- easiest to reason about with existing queue reads +- preserves explicit custom queue semantics + +**Cons** + +- sync bugs remain possible +- every run edit touching human steps must remember to update the queue + +### Option B — Derived default queues + +Define `step_keys=None` on a default queue as “all current human annotation steps for this run.” Custom queues continue to use explicit step keys. + +**Pros** + +- best match for “follows the evaluation as it changes” +- new human evaluators appear automatically +- fewer synchronization paths + +**Cons** + +- requires read-time logic to distinguish default/derived queues from unconstrained queues +- needs crisp semantics for historical behavior and custom queues + +### Option C — No persisted default queue; virtualize it + +Do not persist a canonical queue. Derive one virtually from the run whenever human steps exist. + +**Pros** + +- no sync issue +- cleanest conceptual model + +**Cons** + +- weaker fit with “visible in Queues” unless the Queues API/UI also supports virtual resources +- assignments and future edits become awkward because there is no row to mutate + +**Current recommendation:** Option B looks like the strongest long-term fit, with Option A as the lower-risk incremental step if we want a small migration first. + +## Key References + +- `api/oss/src/core/evaluations/types.py` + - `EvaluationRun*` + - `EvaluationQueue*` + - `SimpleEvaluation*` + - `SimpleQueue*` +- `api/oss/src/core/evaluations/service.py` + - `EvaluationsService._ensure_human_annotation_queue(...)` + - `EvaluationsService.fetch_queue_scenarios(...)` + - `SimpleEvaluationsService` + - `SimpleQueuesService` +- `api/oss/src/core/evaluations/utils.py` + - `filter_scenario_ids(...)` +- `api/oss/src/dbs/postgres/evaluations/dao.py` + - queue CRUD and user filtering +- `api/oss/src/apis/fastapi/evaluations/router.py` + - low-level evaluation endpoints + - simple evaluation endpoints + - simple queue endpoints +- `api/oss/src/core/annotations/service.py` + - neighboring annotation abstraction, distinct from queueing + +## Bottom Line + +The backend already has the right primitive split for unification: + +- **evaluation run** = what is being evaluated +- **queue** = how human annotation work over that run is exposed/distributed + +The exploration should therefore avoid introducing a new “human evaluation” abstraction. The more promising path is to make a linked default queue a first-class, automatically maintained aspect of evaluation runs that contain human steps, while keeping custom queues as an advanced overlay when users need narrower assignment or filtering behavior. diff --git a/docs/designs/unify-evals-and-queues/unify-evals-extension-synthesis.md b/docs/designs/unify-evals-and-queues/unify-evals-extension-synthesis.md new file mode 100644 index 0000000000..832b12eae5 --- /dev/null +++ b/docs/designs/unify-evals-and-queues/unify-evals-extension-synthesis.md @@ -0,0 +1,263 @@ +# Unify Evals Extension Synthesis + +## Purpose + +This note captures the refined model that emerged after relating the queue-unification work to the parallel eval-loop unification work. + +The central clarification is that several concepts currently overloaded into “queue” should be separated: + +- source family +- default queue identity +- simple-queue eligibility +- queue lifecycle + +## Final Vocabulary + +### Source-family flags on runs + +Runs should expose distinct inferred flags for each source family: + +- `has_queries` +- `has_testsets` +- `has_traces` +- `has_testcases` + +These flags answer where scenarios come from and should drive: + +- validation +- topology classification +- source-family filtering +- mixed-input prevention + +They should replace the current tendency to infer direct trace/testcase behavior indirectly through `is_queue` plus synthetic step-key inspection. + +### `run.flags.is_queue` + +`is_queue` should become the persisted derived flag that answers: + +> Can this evaluation currently be interacted with through the simple annotation queue surface? + +The intended condition is: + +```text +active default queue exists +and active human evaluator work exists +``` + +This aligns the name with the product meaning and makes the flag directly useful for querying. + +It should be maintained eagerly, like the other persisted run flags, whenever: + +- the default queue is created +- the default queue is archived or unarchived +- active human evaluator work appears or disappears + +### `queue.flags.is_default` + +Queues need an explicit canonical-default marker: + +```text +queue.flags.is_default = true +``` + +Shape alone is not enough to identify the canonical queue, because a custom queue may coincidentally have the same open filters. + +## Default Queue Model + +A default queue is the canonical persisted queue view for a run. + +Its invariant shape is: + +```text +scenario_ids = None +step_keys = None +user_ids = None +``` + +and no queue-specific batching constraints. + +Interpretation: + +- no scenario filter -> all run scenarios +- no step filter -> all included steps +- no user assignments -> unassigned +- repeat coverage remains run-owned + +### Uniqueness + +There must be at most one default queue per run, including archived queues. + +Because queue flags are persisted JSONB, this can be enforced with a partial unique index over the materialized JSONB flag: + +```sql +CREATE UNIQUE INDEX ux_evaluation_queues_default_per_run +ON evaluation_queues (project_id, run_id) +WHERE (flags ->> 'is_default')::boolean = true; +``` + +An archived default queue still occupies the uniqueness slot. Reconciliation should unarchive it rather than create a duplicate. + +### Edit restrictions + +When `is_default=true`, editing must not allow: + +- scenario filters +- step-key filters +- assignments +- batching settings + +Default queues are canonical open views, not user-customizable slices. + +## Default Queue Lifecycle + +There are two policies. + +### Structural policy + +```text +has_active_human_evaluator_steps(run) +``` + +This says whether a run warrants a default queue when queues are conditional. + +### Global lifecycle policy + +A global policy toggle determines whether default queues are unconditional for all runs, for example: + +```text +EVALUATIONS_DEFAULT_QUEUES_FOR_ALL_RUNS +``` + +When enabled: + +- every run gets a default queue +- the default queue is not archived merely because active human work disappears + +When disabled: + +- active human work requires a default queue +- absence of active human work archives the default queue + +### Reconciliation behavior + +```text +required + missing -> create +required + archived -> unarchive +required + active -> no-op +not required + active -> archive +``` + +Queue lifecycle should use soft deletion for this behavior. Hard deletion may remain available separately where still needed. + +Queue queries need archived-aware support so reconciliation can restore the existing canonical row. + +## Simple Queue Semantics + +A `SimpleQueue` should be understood as: + +```text +a simplified human-work projection of an evaluation's default queue +``` + +not as a wrapper around runs that happen to use a special ingestion mode. + +### Eligibility + +A run is simple-queue eligible when: + +```text +run.flags.is_queue == true +``` + +under the redefined meaning above. + +That means all of these can appear through the simple queue surface when they have active human work and an active default queue: + +- query-backed evaluations +- testset-backed evaluations +- direct trace-backed evaluations +- direct testcase-backed evaluations + +Auto-only evaluations with an eager but empty default queue are not simple-queue eligible unless product later decides otherwise. + +### Identifiers + +Simple queue endpoints should remain queue-ID based. + +If a caller starts from a run, add one small lookup endpoint: + +```http +GET /evaluations/runs/{run_id}/default-queue +``` + +This returns the canonical queue resource or ID, after which existing simple queue endpoints can continue using queue IDs. + +There is no need for run-scoped archive/unarchive endpoints. + +## Relationship to Unified Eval Loops + +The parallel eval-loop work formalizes: + +```text +evaluation = graph + tensor + process(slice) +``` + +The queue model fits above that as: + +```text +default queue = canonical persisted human-work view over the tensor +``` + +The queue axes align naturally with tensor dimensions: + +- scenarios +- steps +- repeats + +A default queue is the open/default view over those dimensions. + +### Boundary + +The default queue is not orchestration. + +- eval runtime owns planning, processing, and tensor population +- queues own visibility, assignment, queue lifecycle, and user workflow + +## Shared Design Tension: Step Lifecycle + +The unified-loop design currently leans toward: + +```text +remove_step -> prune tensor cells +``` + +But if product semantics require evaluators or steps to be archived rather than hard-removed, the graph model needs to support active versus archived steps. + +That affects queue logic directly: + +- queue eligibility should depend on active human steps +- archived human steps may remain visible historically without keeping the queue active +- old tensor cells may remain instead of being pruned + +This needs explicit alignment with the eval-loop mutation model. + +## Recommended Backend Changes + +1. Add inferred run flags: + - `has_traces` + - `has_testcases` +2. Redefine persisted `run.flags.is_queue` as simple-queue eligibility. +3. Add `queue.flags.is_default`. +4. Enforce one default queue per run with a partial unique index over `flags.is_default`. +5. Reject filters/assignments/batching edits on default queues. +6. Add queue archive/unarchive support and archived-aware queries. +7. Add default-queue reconciliation tied to run creation/editing and queue lifecycle changes. +8. Persist and eagerly refresh `run.flags.is_queue` when the default queue or active human-work state changes. +9. Keep simple queue endpoints queue-ID based. +10. Add a small run-scoped lookup endpoint for the default queue: + +```http +GET /evaluations/runs/{run_id}/default-queue +``` + +11. Align eval-loop mutation semantics around active vs archived steps before hardening remove/prune behavior. diff --git a/docs/designs/unify-evals-and-queues/unify-evals-extension-verbatim.md b/docs/designs/unify-evals-and-queues/unify-evals-extension-verbatim.md new file mode 100644 index 0000000000..66e41f27e7 --- /dev/null +++ b/docs/designs/unify-evals-and-queues/unify-evals-extension-verbatim.md @@ -0,0 +1,533 @@ +Yes — I think this is a better model than the one I was circling around before. + +It gives each fact its own place instead of forcing `is_queue` to carry a hidden runtime meaning that does not match its name. + +## The revised flag model + +### Add inferred source flags + +Alongside: + +- `has_queries` +- `has_testsets` +- `has_evaluators` + +add: + +- `has_traces` +- `has_testcases` + +Then the source family becomes explicit: + +| Source shape | Flags | +|---|---| +| query-backed | `has_queries=True` | +| testset-backed | `has_testsets=True` | +| externally supplied traces | `has_traces=True` | +| externally supplied testcases | `has_testcases=True` | + +That gives us the right vocabulary for: + +- validation +- topology classification +- query filtering +- avoiding mixed source families +- distinguishing declared sources from externally supplied concrete items + +And importantly: + +> that source distinction does not need to be smuggled through `is_queue`. + +I like that a lot. + +## Then redefine `is_queue` + +Under your proposal: + +```text +is_queue = there is an active default queue + and at least one active human evaluator step +``` + +or more plainly: + +```text +this evaluation is currently interactable as a simple annotation queue +``` + +That is much more intuitive. + +It answers the product-relevant question: + +> Can this run be interacted with through the simple queue / annotation queue surface? + +And it can turn off when: +- the default queue is archived +- there are no active human evaluator steps + +That means `is_queue` is no longer an execution/input-source flag. It becomes a derived **human-work availability** flag. + +I think that is actually the right use of the name. + +## The clean separation becomes + +| Fact | Where represented | +|---|---| +| declared query source | `has_queries` | +| declared testset source | `has_testsets` | +| external trace source | `has_traces` | +| external testcase source | `has_testcases` | +| has evaluators | `has_evaluators` | +| has active human evaluators | `has_human` or eventually `has_active_human` if step archival becomes explicit | +| can be used via simple queues | `is_queue` | +| this particular queue is canonical | `EvaluationQueue.flags.is_default` | + +That is much clearer. + +--- + +# A few important consequences + +## 1. `is_queue` can no longer be derived from the run alone + +Today run flags are derived only from run data. + +Under this model, `is_queue` depends on: + +- run graph state +- queue existence +- queue archival state + +So we should stop thinking of all run flags as purely “derived from steps.” + +Some are graph-derived: +- `has_queries` +- `has_testsets` +- `has_traces` +- `has_testcases` +- `has_evaluators` +- `has_human` +- `has_auto` + +But `is_queue` is relationship/lifecycle-derived: +- active default queue exists +- active human work exists + +That is fine, but it means the reconciliation path needs to update the run flag whenever the default queue is created, archived, or unarchived. + +## 2. `has_traces` / `has_testcases` solve the current synthetic-source hack + +Right now the code has this awkward logic: + +```python +if flags.is_queue and not _references: + inspect step_key text for "query" or "testset" +``` + +That exists because direct queues use synthetic input steps and the system still wants to classify them as trace/testcase-ish through the old flags. + +With explicit: + +- `has_traces` +- `has_testcases` + +we can stop overloading: +- `has_queries` to mean “trace-like” +- `has_testsets` to mean “testcase-like” +- `is_queue` to help infer source family from missing references + +That should make validation and topology much more honest. + +## 3. Source-backed human evaluations can naturally be simple queues + +Then your desired product rule becomes straightforward: + +```text +simple queue eligible +iff +run.is_queue == True +``` + +because `is_queue` now means exactly: + +```text +active default queue + active human evaluator work +``` + +So: + +| Run | `is_queue` | +|---|---:| +| query-backed + human evaluator + active default queue | true | +| testset-backed + human evaluator + active default queue | true | +| direct traces + human evaluator + active default queue | true | +| direct testcases + human evaluator + active default queue | true | +| auto-only eval with eager default queue | false | +| human eval whose default queue is archived | false | + +That is much more useful than the current meaning. + +--- + +# On the default queue itself + +Yes: I think we need an explicit queue flag: + +```python +EvaluationQueueFlags.is_default: bool +``` + +because shape alone is not enough once custom queues exist. A custom queue could happen to also have: + +- no scenario filter +- no step filter +- no assignments + +and still not be the canonical default queue. + +## Default queue invariants + +For `is_default=True`, enforce: + +```text +scenario_ids is None +step_keys is None +user_ids is None +``` + +You mentioned “no repeat index”; today queues do not have a repeat-index filter field, so in current schema that translates to **no assignments**, because assignment lanes are the queue-level way repeats are partitioned. If repeat-index filtering is later added explicitly, then yes, default queues should also forbid that. + +And I agree with the behavior: + +- default queue archived → it still exists +- creating another default queue for same run should fail +- reconciliation should unarchive the existing one rather than create a second row + +So the invariant is: + +```text +at most one default queue per run, including archived queues +``` + +That likely wants a real uniqueness guarantee, not just service logic. Since `is_default` inside JSONB flags is awkward for constraints, I would at least consider making `is_default` a real DB column or adding an explicit indexed field rather than burying it only in JSON flags. If we keep it only inside JSONB, enforcement is still possible but less pleasant and less obvious. + +--- + +# On simple queue identifiers + +Yes — today the simple queue endpoints are queue-ID-centric. + +If after unification the simple queue is fundamentally: + +```text +the simple projection of the run’s default queue +``` + +then using the **run ID** is probably the better canonical identifier for that surface. + +Because the user-facing thing is increasingly “the evaluation’s queue,” not “some arbitrary queue row.” + +## Compatibility shape + +Your compatibility idea is reasonable: + +- if the identifier resolves to a queue ID, use that queue +- otherwise, if it resolves to a run ID, fetch that run’s default queue +- eventually prefer run IDs in the new API/UI paths + +Since both are UUIDs, this is a pragmatic compatibility shim rather than a semantically clean contract. It works, but the implementation should be careful and deterministic: + +1. try queue ID lookup first +2. if no queue exists, try default queue by run ID +3. if both could theoretically exist, queue ID wins + +UUID collision across tables is extraordinarily unlikely, but the precedence should still be explicit. + +Longer term, I would probably prefer: +- queue-specific endpoints remain queue-ID based +- simple queue endpoints become run-ID based +- maybe add explicit routes like `/simple/queues/by-run/{run_id}` during transition + +But as an internal compatibility move, your suggestion is sound. + +--- + +# The resulting model is much stronger + +I would now summarize it this way: + +```text +Source flags describe where scenarios come from: + has_queries + has_testsets + has_traces + has_testcases + +Queue flags describe queue identity: + is_default + +Run.is_queue describes product eligibility: + active default queue + active human evaluator work + +SimpleQueue is the simplified API view over: + the run's default queue +``` + +That gives each concern one home. + +It also removes the awkwardness where a run could “have a queue” but `is_queue=False`, or have `is_queue=True` mostly because it accepts source batches. Under your model, the name and the meaning line up again. + +## One thing I would still decide explicitly + +Whether `is_queue` should be persisted at all or always recomputed. + +Because it depends on both: +- run state +- queue state + +If it is persisted, reconciliation must update it transactionally whenever: +- default queue created +- default queue archived/unarchived +- active human evaluator steps appear/disappear + +If it is recomputed, queries become more expensive or need joins. + +Given that you explicitly want it queryable, persisting it as a denormalized derived flag is reasonable — but then I would treat **default queue reconciliation + run flag refresh** as one logical operation. + +This is a much cleaner design than the earlier alternatives. + +--- + +Yes — on both points. + +## 1. Uniqueness for the default queue + +`scenario_ids`, `step_keys`, and `user_ids` all currently live inside `EvaluationQueueData`, which is stored in the queue’s JSON `data` column. So if we define the default queue only by the shape: + +```text +scenario_ids is None +step_keys is None +user_ids is None +``` + +then a normal relational unique index over those fields is not straightforward. + +You could technically build expression indexes over JSON paths, but that would be a poor foundation for a first-class invariant like: + +```text +at most one default queue per run +``` + +I think your instinct is right: if **defaultness** matters enough to enforce uniqueness, it wants a first-class indexed column. + +### Recommended shape + +Add something like: + +```python +is_default: bool +``` + +as a real queue column, even if it is also mirrored into DTO flags for API symmetry. + +Then enforce: + +```text +unique(project_id, run_id) where is_default = true +``` + +That gives you the exact invariant you want: + +- only one default queue per run +- archived default queue still occupies the slot +- trying to create another default queue fails +- reconciliation can find and unarchive the existing one + +And separately, enforce in service validation: + +```text +if is_default: + scenario_ids must be None + step_keys must be None + user_ids must be None + batch settings must be None +``` + +The uniqueness invariant should key on **identity** (`is_default`), not on a coincidental combination of empty filters. + +That is much safer than saying “whichever queue happens to have no filters is default.” + +## 2. Simple queues can stay queue-ID based if default queue lookup is easy + +I agree with your revised take. + +If we add a very small direct way to fetch the default queue for a run, then we do **not** need to overload simple queue identifiers with run IDs. + +Something like: + +```http +GET /evaluations/runs/{run_id}/queue +``` + +or more explicitly: + +```http +GET /evaluations/runs/{run_id}/default-queue +``` + +would be enough. + +Then the flow becomes: + +1. evaluation screen has `run_id` +2. fetch default queue in one cheap call +3. receive `queue_id` +4. use existing simple queue endpoints with `queue_id` + +That preserves the clean existing identity model: +- queue resources use queue IDs +- runs use run IDs +- no ambiguous UUID interpretation +- no compatibility shim needed unless product wants one for some other reason + +I’d favor: + +```http +GET /evaluations/runs/{run_id}/default-queue +``` + +because it is unambiguous and leaves room for other queue relationships later. + +Possibly also: + +```http +POST /evaluations/runs/{run_id}/default-queue/archive +POST /evaluations/runs/{run_id}/default-queue/unarchive +``` + +if we want run-scoped convenience operations, though the generic queue archive/unarchive endpoints may already be enough. + +## 3. Persisting `run.flags.is_queue` + +I agree with you. + +Because we want: + +- cheap filtering +- parity with the other persisted inferred run flags +- no joins for the common read paths +- and queue archival/unarchival is a meaningful state transition anyway + +then `is_queue` should be persisted and eagerly maintained. + +Under the new meaning: + +```text +run.flags.is_queue = + active default queue exists + and active human evaluator work exists +``` + +And yes, that means queue lifecycle operations must update the run flags eagerly: + +| Event | `run.flags.is_queue` effect | +|---|---| +| default queue created and human work exists | `true` | +| default queue unarchived and human work exists | `true` | +| default queue archived | `false` | +| last active human evaluator archived | `false` | +| human evaluator added/restored while default queue active | `true` | + +This is analogous to the other denormalized flags: +- they are persisted +- they are inferred/maintained eagerly +- query paths stay cheap + +The only thing I would be careful about is **where** that synchronization lives. I would want one narrow service-level operation that owns the invariant, rather than scattered manual edits: + +```text +reconcile_default_queue_and_run_flags(...) +``` + +or maybe two lower-level pieces used transactionally: + +```text +reconcile_default_queue(...) +refresh_run_flags(...) +``` + +But conceptually, yes: persist it and keep it in sync. + +## Updated model after your refinements + +```text +EvaluationRunFlags + has_queries + has_testsets + has_traces + has_testcases + has_evaluators + has_human + has_auto + is_queue # can be used via simple queue surface + +EvaluationQueue + is_default # first-class indexed identity field + +Default queue invariant + at most one per run + scenario_ids=None + step_keys=None + user_ids=None + no batching constraints + +Simple queue endpoints + remain queue-ID based + +Default queue lookup + add a direct run-scoped endpoint to get queue_id cheaply +``` + +That feels quite settled to me now. + +--- + +Yes — if `is_default` is materialized inside the queue `flags` JSONB, then Postgres can enforce uniqueness from that directly. We do **not** necessarily need a dedicated physical column just for the index. + +For example, conceptually: + +```sql +CREATE UNIQUE INDEX ux_evaluation_queues_default_per_run +ON evaluation_queues (project_id, run_id) +WHERE (flags ->> 'is_default')::boolean = true; +``` + +That would give us the important invariant: + +```text +at most one default queue per run, including archived rows +``` + +because archived rows still satisfy the partial-index predicate. + +So yes: if we are comfortable treating `flags.is_default` as a materialized persisted field, a partial unique index over the JSONB expression is enough. + +The tradeoff is mainly ergonomics: +- **JSONB expression index**: no extra column, keeps queue flags grouped together +- **real column**: easier to inspect/query/index conventionally + +But technically, your approach is sound, and in this codebase it may be the more consistent choice if other queue flags already live in JSONB. + +And agreed: **no** to run-scoped archive/unarchive endpoints. + +We only need: +- generic queue archive/unarchive operations +- plus a simple way to fetch the default queue for a run + +So the useful addition is something like: + +```http +GET /evaluations/runs/{run_id}/default-queue +``` + +but not special lifecycle endpoints hanging off the run. diff --git a/hosting/docker-compose/ee/docker-compose.dev.yml b/hosting/docker-compose/ee/docker-compose.dev.yml index d74d0126d3..3578a1155e 100644 --- a/hosting/docker-compose/ee/docker-compose.dev.yml +++ b/hosting/docker-compose/ee/docker-compose.dev.yml @@ -77,9 +77,17 @@ services: "8000", "--reload", "--reload-dir", + "/app/oss", + "--reload-dir", + "/app/ee", + "--reload-dir", + "/app/entrypoints", + "--reload-dir", "/sdks/python", "--reload-dir", "/clients/python", + "--reload-exclude", + "**/tests/**", "--root-path", "/api", "--loop", @@ -133,7 +141,7 @@ services: image: agenta-ee-dev-api:latest # === EXECUTION ============================================ # command: > - watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive -- + watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive --ignore-patterns=*/tests/* -- python -m entrypoints.worker_evaluations # === STORAGE ============================================== # volumes: @@ -170,7 +178,7 @@ services: image: agenta-ee-dev-api:latest # === EXECUTION ============================================ # command: > - watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive -- + watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive --ignore-patterns=*/tests/* -- python -m entrypoints.worker_tracing # === STORAGE ============================================== # volumes: @@ -207,7 +215,7 @@ services: image: agenta-ee-dev-api:latest # === EXECUTION ============================================ # command: > - watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive -- + watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive --ignore-patterns=*/tests/* -- python -m entrypoints.worker_webhooks # === STORAGE ============================================== # volumes: @@ -250,7 +258,7 @@ services: image: agenta-ee-dev-api:latest # === EXECUTION ============================================ # command: > - watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive -- + watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive --ignore-patterns=*/tests/* -- python -m entrypoints.worker_events # === STORAGE ============================================== # volumes: @@ -355,9 +363,13 @@ services: "8080", "--reload", "--reload-dir", + "/app", + "--reload-dir", "/sdks/python", "--reload-dir", "/clients/python", + "--reload-exclude", + "**/tests/**", "--root-path", "/services", "--loop", diff --git a/hosting/docker-compose/oss/docker-compose.dev.yml b/hosting/docker-compose/oss/docker-compose.dev.yml index 750ea6c8be..95c4aab381 100644 --- a/hosting/docker-compose/oss/docker-compose.dev.yml +++ b/hosting/docker-compose/oss/docker-compose.dev.yml @@ -77,9 +77,15 @@ services: "8000", "--reload", "--reload-dir", + "/app/oss", + "--reload-dir", + "/app/entrypoints", + "--reload-dir", "/sdks/python", "--reload-dir", "/clients/python", + "--reload-exclude", + "**/tests/**", "--root-path", "/api", "--loop", @@ -133,7 +139,7 @@ services: image: agenta-oss-dev-api:latest # === EXECUTION ============================================ # command: > - watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive -- + watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive --ignore-patterns=*/tests/* -- python -m entrypoints.worker_evaluations # === STORAGE ============================================== # volumes: @@ -170,7 +176,7 @@ services: image: agenta-oss-dev-api:latest # === EXECUTION ============================================ # command: > - watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive -- + watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive --ignore-patterns=*/tests/* -- python -m entrypoints.worker_tracing # === STORAGE ============================================== # volumes: @@ -207,7 +213,7 @@ services: image: agenta-oss-dev-api:latest # === EXECUTION ============================================ # command: > - watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive -- + watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive --ignore-patterns=*/tests/* -- python -m entrypoints.worker_webhooks # === STORAGE ============================================== # volumes: @@ -250,7 +256,7 @@ services: image: agenta-oss-dev-api:latest # === EXECUTION ============================================ # command: > - watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive -- + watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive --ignore-patterns=*/tests/* -- python -m entrypoints.worker_events # === STORAGE ============================================== # volumes: @@ -354,9 +360,13 @@ services: "8080", "--reload", "--reload-dir", + "/app", + "--reload-dir", "/sdks/python", "--reload-dir", "/clients/python", + "--reload-exclude", + "**/tests/**", "--root-path", "/services", "--loop", diff --git a/sdk/agenta/sdk/evaluations/runtime/__init__.py b/sdk/agenta/sdk/evaluations/runtime/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sdk/agenta/sdk/evaluations/runtime/adapters.py b/sdk/agenta/sdk/evaluations/runtime/adapters.py new file mode 100644 index 0000000000..60a0c868ad --- /dev/null +++ b/sdk/agenta/sdk/evaluations/runtime/adapters.py @@ -0,0 +1,129 @@ +from typing import Any, Dict, Optional + +from agenta.sdk.decorators.running import invoke_application, invoke_evaluator +from agenta.sdk.evaluations.preview.utils import fetch_trace_data +from agenta.sdk.evaluations.results import acreate as alog_result +from agenta.sdk.evaluations.runtime.models import ( + ResultLogRequest, + WorkflowExecutionRequest, + WorkflowExecutionResult, +) +from agenta.sdk.models.evaluations import EvaluationStatus +from agenta.sdk.models.workflows import ( + ApplicationServiceRequest, + EvaluatorServiceRequest, + WorkflowServiceRequestData, +) + + +class SdkLocalApplicationRunner: + """SDK adapter for executing application steps through local decorators.""" + + async def execute( + self, + request: WorkflowExecutionRequest, + ) -> WorkflowExecutionResult: + response = await invoke_application( + request=ApplicationServiceRequest( + data=WorkflowServiceRequestData( + revision=request.revision, + parameters=request.parameters, + testcase=request.source.testcase, + inputs=request.source.inputs, + trace=request.upstream_trace, + outputs=request.upstream_outputs, + ), + references=request.references, # type: ignore[arg-type] + links=request.links, # type: ignore[arg-type] + ) + ) + return _normalize_service_response(response) + + async def execute_batch( + self, + requests: list[WorkflowExecutionRequest], + ) -> list[WorkflowExecutionResult]: + return [await self.execute(request) for request in requests] + + +class SdkLocalEvaluatorRunner: + """SDK adapter for executing evaluator steps through local decorators.""" + + async def execute( + self, + request: WorkflowExecutionRequest, + ) -> WorkflowExecutionResult: + response = await invoke_evaluator( + request=EvaluatorServiceRequest( + version="2025.07.14", + data=WorkflowServiceRequestData( + revision=request.revision, + parameters=request.parameters, + testcase=request.source.testcase, + inputs=request.source.inputs, + trace=request.upstream_trace, + outputs=request.upstream_outputs, + ), + references=request.references, # type: ignore[arg-type] + links=request.links, # type: ignore[arg-type] + ) + ) + return _normalize_service_response(response) + + async def execute_batch( + self, + requests: list[WorkflowExecutionRequest], + ) -> list[WorkflowExecutionResult]: + return [await self.execute(request) for request in requests] + + +class SdkResultLogger: + """SDK adapter for persisting evaluation result cells.""" + + async def log(self, request: ResultLogRequest) -> Any: + cell = request.cell + return await alog_result( + run_id=cell.run_id, + scenario_id=cell.scenario_id, + step_key=cell.step_key, + repeat_idx=cell.repeat_idx, + trace_id=request.trace_id + if request.trace_id is not None + else cell.trace_id, + testcase_id=( + request.testcase_id + if request.testcase_id is not None + else cell.testcase_id + ), + error=request.error if request.error is not None else cell.error, + ) + + +class SdkTraceLoader: + """SDK adapter for loading traces after local workflow execution.""" + + def __init__(self, *, max_retries: int = 30, delay: float = 1.0): + self.max_retries = max_retries + self.delay = delay + + async def load(self, trace_id: str) -> Optional[Dict[str, Any]]: + return await fetch_trace_data( + trace_id, + max_retries=self.max_retries, + delay=self.delay, + ) + + +def _normalize_service_response(response: Any) -> WorkflowExecutionResult: + if not response or not getattr(response, "data", None) or not response.trace_id: + return WorkflowExecutionResult( + status=EvaluationStatus.FAILURE, + error={"message": "Missing or invalid workflow response"}, + ) + + return WorkflowExecutionResult( + status=EvaluationStatus.SUCCESS, + trace_id=response.trace_id, + span_id=getattr(response, "span_id", None), + outputs=getattr(response.data, "outputs", None), + ) diff --git a/sdk/agenta/sdk/evaluations/runtime/execution.py b/sdk/agenta/sdk/evaluations/runtime/execution.py new file mode 100644 index 0000000000..f49620ce2b --- /dev/null +++ b/sdk/agenta/sdk/evaluations/runtime/execution.py @@ -0,0 +1,152 @@ +from asyncio import Semaphore, gather +from datetime import datetime +from inspect import signature +from typing import Any, Awaitable, Callable, Dict, List, Optional, Protocol +from uuid import UUID + +from agenta.sdk.evaluations.runtime.models import ( + PlannedCell, + ResultLogRequest, + WorkflowExecutionRequest, + WorkflowExecutionResult, +) + + +class WorkflowRunner(Protocol): + """Adapter boundary for application/evaluator execution. + + SDK-local evaluation, API service execution, and backend-internal workflow + invocation should each implement this protocol instead of changing the + planner or topology classifier. + """ + + async def execute( + self, + request: WorkflowExecutionRequest, + ) -> WorkflowExecutionResult: ... + + +class WorkflowBatchRunner(WorkflowRunner, Protocol): + """Optional batch execution boundary for any runnable workflow step.""" + + async def execute_batch( + self, + requests: List[WorkflowExecutionRequest], + ) -> List[WorkflowExecutionResult]: ... + + +async def execute_workflow_batch( + *, + runner: WorkflowRunner, + requests: List[WorkflowExecutionRequest], + semaphore: Optional[Semaphore] = None, +) -> List[WorkflowExecutionResult]: + execute_batch = getattr(runner, "execute_batch", None) + + async def _guarded(request: WorkflowExecutionRequest) -> WorkflowExecutionResult: + if semaphore is not None: + async with semaphore: + return await runner.execute(request) + return await runner.execute(request) + + if execute_batch is not None: + try: + params = signature(execute_batch).parameters + accepts_semaphore = "semaphore" in params or any( + p.kind == p.VAR_KEYWORD for p in params.values() + ) + except (ValueError, TypeError): + accepts_semaphore = False + if accepts_semaphore: + return await execute_batch(requests, semaphore=semaphore) + return await execute_batch(requests) + + return list(await gather(*(_guarded(request) for request in requests))) + + +class EvaluationTaskRunner(Protocol): + """Generic evaluation task dispatch boundary. + + SDK/local code should use an in-process asyncio implementation. API code can + adapt this protocol to Taskiq without Taskiq leaking into SDK runtime code. + """ + + async def process_run( + self, + *, + project_id: UUID, + user_id: UUID, + run_id: UUID, + newest: Optional[datetime] = None, + oldest: Optional[datetime] = None, + ) -> Any: ... + + async def process_slice( + self, + *, + project_id: UUID, + user_id: UUID, + run_id: UUID, + source_kind: str, + trace_ids: Optional[List[str]] = None, + testcase_ids: Optional[List[UUID]] = None, + input_step_key: Optional[str] = None, + ) -> Any: ... + + +class AsyncioEvaluationTaskRunner: + """In-process task runner adapter for SDK/local evaluation execution.""" + + def __init__( + self, + *, + process_run: Optional[Callable[..., Awaitable[Any]]] = None, + process_slice: Optional[Callable[..., Awaitable[Any]]] = None, + ): + self._process_run = process_run + self._process_slice = process_slice + + async def process_run(self, **kwargs: Any) -> Any: + if self._process_run is None: + raise RuntimeError("process_run handler is not configured") + return await self._process_run(**kwargs) + + async def process_slice(self, **kwargs: Any) -> Any: + if self._process_slice is None: + raise RuntimeError("process_slice handler is not configured") + return await self._process_slice(**kwargs) + + +class ResultLogger(Protocol): + """Adapter boundary for persisting planned result cells.""" + + async def log(self, request: ResultLogRequest) -> Any: ... + + +class TraceLoader(Protocol): + """Adapter boundary for loading runner traces after a step executes.""" + + async def load(self, trace_id: str) -> Optional[Any]: ... + + +class RuntimeExecutionContext: + """Small mutable context shared by runner adapters while processing a scenario.""" + + def __init__(self) -> None: + self.results: Dict[str, Any] = {} + self.traces: Dict[str, Any] = {} + self.outputs: Dict[str, Any] = {} + + def remember_result(self, *, cell: PlannedCell, result: Any) -> None: + self.results[cell.step_key] = result + + def remember_execution( + self, + *, + cell: PlannedCell, + execution: WorkflowExecutionResult, + ) -> None: + if execution.trace is not None: + self.traces[cell.step_key] = execution.trace + if execution.outputs is not None: + self.outputs[cell.step_key] = execution.outputs diff --git a/sdk/agenta/sdk/evaluations/runtime/models.py b/sdk/agenta/sdk/evaluations/runtime/models.py new file mode 100644 index 0000000000..99a526dfaf --- /dev/null +++ b/sdk/agenta/sdk/evaluations/runtime/models.py @@ -0,0 +1,129 @@ +from typing import Any, Dict, List, Literal, Optional +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + +from agenta.sdk.models.evaluations import EvaluationStatus, Origin + +StepType = Literal["input", "invocation", "annotation"] +SourceKind = Literal["query", "testset", "trace", "testcase", "direct"] +TopologyStatus = Literal["supported", "potential", "not_planned", "unsupported"] +DispatchKind = Literal[ + "batch_query", + "batch_testset", + "batch_invocation", + "queue_traces", + "queue_testcases", + "live_query", +] + + +class EvaluationStep(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + key: str + type: StepType + origin: Origin = "custom" + references: Dict[str, Any] = Field(default_factory=dict) + inputs: List[str] = Field(default_factory=list) + + +class ResolvedSourceItem(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + kind: SourceKind + step_key: str + references: Dict[str, Any] = Field(default_factory=dict) + trace_id: Optional[str] = None + span_id: Optional[str] = None + testcase_id: Optional[UUID] = None + testcase: Optional[Any] = None + trace: Optional[Any] = None + inputs: Optional[Any] = None + outputs: Optional[Any] = None + + +class ScenarioBinding(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + scenario_id: UUID + source: ResolvedSourceItem + interval: Optional[int] = None + timestamp: Optional[Any] = None + + +class PlannedCell(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + run_id: UUID + scenario_id: UUID + step_key: str + step_type: StepType + origin: Origin + repeat_idx: int + status: EvaluationStatus + should_execute: bool = False + trace_id: Optional[str] = None + span_id: Optional[str] = None + testcase_id: Optional[UUID] = None + error: Optional[Dict[str, Any]] = None + + +class ExecutionPlan(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + run_id: UUID + cells: List[PlannedCell] + + @property + def executable_cells(self) -> List[PlannedCell]: + return [cell for cell in self.cells if cell.should_execute] + + +class TopologyDecision(BaseModel): + status: TopologyStatus + label: str + reason: str + dispatch: Optional[DispatchKind] = None + + +class WorkflowExecutionRequest(BaseModel): + """Runner-agnostic request for an application or evaluator step.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + step: EvaluationStep + cell: PlannedCell + source: ResolvedSourceItem + revision: Any + parameters: Optional[Any] = None + references: Dict[str, Any] = Field(default_factory=dict) + links: Optional[Dict[str, Any]] = None + upstream_trace: Optional[Any] = None + upstream_outputs: Optional[Any] = None + + +class WorkflowExecutionResult(BaseModel): + """Normalized result produced by any workflow runner adapter.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + status: EvaluationStatus + trace_id: Optional[str] = None + span_id: Optional[str] = None + hash_id: Optional[str] = None + outputs: Optional[Any] = None + trace: Optional[Any] = None + error: Optional[Dict[str, Any]] = None + + +class ResultLogRequest(BaseModel): + """Runner-agnostic request for persisting a planned result cell.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + cell: PlannedCell + trace_id: Optional[str] = None + span_id: Optional[str] = None + testcase_id: Optional[UUID] = None + error: Optional[Dict[str, Any]] = None diff --git a/sdk/agenta/sdk/evaluations/runtime/planner.py b/sdk/agenta/sdk/evaluations/runtime/planner.py new file mode 100644 index 0000000000..87d48867ca --- /dev/null +++ b/sdk/agenta/sdk/evaluations/runtime/planner.py @@ -0,0 +1,202 @@ +from typing import List, Optional +from uuid import UUID + +from agenta.sdk.evaluations.runtime.models import ( + EvaluationStep, + ExecutionPlan, + PlannedCell, + ResolvedSourceItem, + ScenarioBinding, +) +from agenta.sdk.models.evaluations import EvaluationStatus + + +def build_repeat_indices(repeats: Optional[int]) -> List[int]: + count = repeats or 1 + if count < 1: + count = 1 + return list(range(count)) + + +def effective_is_split( + *, + is_split: bool, + is_live: bool = False, + has_traces: bool = False, + has_testcases: bool = False, + has_application_steps: bool = False, + has_evaluator_steps: bool = False, +) -> bool: + if is_live or has_traces or has_testcases: + return False + if not has_application_steps or not has_evaluator_steps: + return False + return is_split + + +class EvaluationPlanner: + """Build the evaluation result tensor without knowing how steps execute.""" + + def plan( + self, + *, + run_id: UUID, + scenario_id: UUID, + source: ResolvedSourceItem, + steps: List[EvaluationStep], + repeats: Optional[int] = None, + is_split: bool = False, + is_live: bool = False, + has_traces: bool = False, + has_testcases: bool = False, + ) -> ExecutionPlan: + return self.plan_bindings( + run_id=run_id, + bindings=[ + ScenarioBinding( + scenario_id=scenario_id, + source=source, + ) + ], + steps=steps, + repeats=repeats, + is_split=is_split, + is_live=is_live, + has_traces=has_traces, + has_testcases=has_testcases, + ) + + def plan_bindings( + self, + *, + run_id: UUID, + bindings: List[ScenarioBinding], + steps: List[EvaluationStep], + repeats: Optional[int] = None, + is_split: bool = False, + is_live: bool = False, + has_traces: bool = False, + has_testcases: bool = False, + ) -> ExecutionPlan: + repeat_indices = build_repeat_indices(repeats) + + input_steps = [step for step in steps if step.type == "input"] + application_steps = [step for step in steps if step.type == "invocation"] + evaluator_steps = [step for step in steps if step.type == "annotation"] + app_repeat_indices = self._application_repeat_indices( + repeat_indices=repeat_indices, + is_split=is_split, + is_live=is_live, + has_traces=has_traces, + has_testcases=has_testcases, + has_application_steps=bool(application_steps), + has_evaluator_steps=bool(evaluator_steps), + ) + + cells: List[PlannedCell] = [] + + for binding in bindings: + source = binding.source + + for step in input_steps: + cells.extend( + PlannedCell( + run_id=run_id, + scenario_id=binding.scenario_id, + step_key=step.key, + step_type=step.type, + origin=step.origin, + repeat_idx=repeat_idx, + status=EvaluationStatus.SUCCESS, + trace_id=source.trace_id, + span_id=source.span_id, + testcase_id=source.testcase_id, + ) + for repeat_idx in repeat_indices + ) + + for step in application_steps: + cells.extend( + self._runnable_cells( + run_id=run_id, + scenario_id=binding.scenario_id, + source=source, + step=step, + repeat_indices=app_repeat_indices, + ) + ) + + for step in evaluator_steps: + cells.extend( + self._runnable_cells( + run_id=run_id, + scenario_id=binding.scenario_id, + source=source, + step=step, + repeat_indices=repeat_indices, + ) + ) + + return ExecutionPlan(run_id=run_id, cells=cells) + + def _application_repeat_indices( + self, + *, + repeat_indices: List[int], + is_split: bool, + is_live: bool, + has_traces: bool, + has_testcases: bool, + has_application_steps: bool, + has_evaluator_steps: bool, + ) -> List[int]: + split = effective_is_split( + is_split=is_split, + is_live=is_live, + has_traces=has_traces, + has_testcases=has_testcases, + has_application_steps=has_application_steps, + has_evaluator_steps=has_evaluator_steps, + ) + + if not has_application_steps: + return [] + if not has_evaluator_steps: + return repeat_indices + if split: + return repeat_indices + return [0] + + def _runnable_cells( + self, + *, + run_id: UUID, + scenario_id: UUID, + source: ResolvedSourceItem, + step: EvaluationStep, + repeat_indices: List[int], + ) -> List[PlannedCell]: + is_manual_annotation = step.type == "annotation" and step.origin in { + "human", + "custom", + } + status = ( + EvaluationStatus.PENDING + if is_manual_annotation + else EvaluationStatus.QUEUED + ) + + return [ + PlannedCell( + run_id=run_id, + scenario_id=scenario_id, + step_key=step.key, + step_type=step.type, + origin=step.origin, + repeat_idx=repeat_idx, + status=status, + should_execute=not is_manual_annotation, + testcase_id=source.testcase_id, + ) + for repeat_idx in repeat_indices + ] diff --git a/sdk/agenta/sdk/evaluations/runtime/source_slice.py b/sdk/agenta/sdk/evaluations/runtime/source_slice.py new file mode 100644 index 0000000000..731eacc8cb --- /dev/null +++ b/sdk/agenta/sdk/evaluations/runtime/source_slice.py @@ -0,0 +1,504 @@ +import asyncio +from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + +from agenta.sdk.evaluations.runtime.execution import ( + ResultLogger, + TraceLoader, + execute_workflow_batch, +) +from agenta.sdk.evaluations.runtime.models import ( + EvaluationStep, + PlannedCell, + ResolvedSourceItem, + ResultLogRequest, + WorkflowExecutionRequest, + WorkflowExecutionResult, +) +from agenta.sdk.evaluations.runtime.planner import EvaluationPlanner +from agenta.sdk.models.evaluations import EvaluationStatus +from agenta.sdk.utils.logging import get_logger + +logger = get_logger(__name__) + + +class ProcessedScenario(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + scenario: Any + results: Dict[str, Any] = Field(default_factory=dict) + metrics: Optional[Any] = None + has_pending: bool = False + has_errors: bool = False + auto_results_created: bool = False + + +CreateScenario = Callable[[UUID], Awaitable[Any]] +RefreshMetrics = Callable[[UUID, Optional[UUID]], Awaitable[Any]] + + +async def process_evaluation_source_slice( + *, + run_id: UUID, + source_items: List[ResolvedSourceItem], + steps: List[EvaluationStep], + repeats: Optional[int], + create_scenario: CreateScenario, + result_logger: ResultLogger, + refresh_metrics: RefreshMetrics, + runners: Mapping[str, Any], + revisions: Mapping[str, Any], + trace_loader: Optional[TraceLoader] = None, + is_split: bool = False, + log_pending: bool = True, + refresh_metrics_without_auto_results: bool = True, + batch_size: Optional[int] = None, + max_retries: Optional[int] = None, + retry_delay: Optional[float] = None, +) -> List[ProcessedScenario]: + """Process concrete source items through the SDK-owned runtime contract. + + The function is runner/persistence agnostic. SDK preview uses local + decorator runners and API result logging; backend code can move to this + shape by supplying backend DAO/workflow adapters. + + batch_size controls the maximum number of concurrent invoke_workflow calls + across all scenarios and repeats. A single asyncio.Semaphore is shared by + both the scenario-level gather and the per-step repeat batch so that peak + concurrency equals exactly batch_size regardless of how repeats are split. + """ + semaphore = asyncio.Semaphore(batch_size) if batch_size else None + processed_lock = asyncio.Lock() + processed: List[ProcessedScenario] = [] + + logger.info( + "[SLICE] Starting", + run_id=str(run_id), + scenarios=len(source_items), + batch_size=batch_size, + max_retries=max_retries, + retry_delay=retry_delay, + ) + + async def _process_one(source_item: ResolvedSourceItem) -> None: + scenario = await create_scenario(run_id) + scenario_id = scenario.id + + plan = EvaluationPlanner().plan( + run_id=run_id, + scenario_id=scenario_id, + source=source_item, + steps=steps, + repeats=repeats, + is_split=is_split, + ) + results: Dict[str, Any] = {} + context_by_repeat = _initial_context_by_repeat( + source_item=source_item, + repeats=repeats, + ) + scenario_has_pending = False + scenario_has_errors = False + scenario_auto_results_created = False + + idx = 0 + while idx < len(plan.cells): + cell = plan.cells[idx] + step = _step_by_key(steps, cell.step_key) + if step is None: + idx += 1 + continue + + if cell.step_type == "input": + results[cell.step_key] = await result_logger.log( + ResultLogRequest( + cell=cell, + testcase_id=source_item.testcase_id, + trace_id=source_item.trace_id, + ) + ) + idx += 1 + continue + + if not cell.should_execute: + scenario_has_pending = True + if log_pending: + results[cell.step_key] = await result_logger.log( + ResultLogRequest(cell=cell) + ) + idx += 1 + continue + + batch_cells = _next_runnable_batch( + cells=plan.cells, + start_idx=idx, + step_key=cell.step_key, + ) + runner = runners.get(cell.step_key) + revision = revisions.get(cell.step_key) + if runner is None or revision is None: + for batch_cell in batch_cells: + scenario_has_errors = True + results[batch_cell.step_key] = await result_logger.log( + ResultLogRequest( + cell=_failed_cell( + batch_cell, + message=( + f"Missing runner or revision for " + f"{batch_cell.step_key}" + ), + ), + error={ + "message": ( + f"Missing runner or revision for " + f"{batch_cell.step_key}" + ) + }, + ) + ) + idx += len(batch_cells) + continue + + requests = [ + _build_execution_request( + cell=batch_cell, + step=step, + source_item=source_item, + revision=revision, + context_by_repeat=context_by_repeat, + ) + for batch_cell in batch_cells + ] + + executions = await _execute_with_retry( + runner=runner, + requests=requests, + semaphore=semaphore, + max_retries=max_retries, + retry_delay=retry_delay, + ) + for batch_cell, execution in zip(batch_cells, executions): + if trace_loader and execution.trace_id and execution.trace is None: + execution.trace = await trace_loader.load(str(execution.trace_id)) + if execution.outputs is None and execution.trace is not None: + execution.outputs = _extract_outputs(execution.trace) + + results[batch_cell.step_key] = await result_logger.log( + ResultLogRequest( + cell=batch_cell, + trace_id=execution.trace_id, + span_id=execution.span_id, + testcase_id=source_item.testcase_id, + error=execution.error, + ) + ) + scenario_auto_results_created = True + if execution.error or str(execution.status) in { + "failure", + "EvaluationStatus.FAILURE", + "errors", + "EvaluationStatus.ERRORS", + }: + scenario_has_errors = True + + if execution.trace_id: + _remember_context( + cell=batch_cell, + context_by_repeat=context_by_repeat, + trace=execution.trace, + trace_id=str(execution.trace_id), + span_id=execution.span_id, + outputs=execution.outputs, + ) + + if len(executions) != len(batch_cells): + scenario_has_errors = True + message = ( + f"Runner for {cell.step_key} returned {len(executions)} " + f"execution(s) for {len(batch_cells)} planned cell(s)." + ) + for batch_cell in batch_cells[len(executions) :]: + results[batch_cell.step_key] = await result_logger.log( + ResultLogRequest( + cell=_failed_cell(batch_cell, message=message), + testcase_id=source_item.testcase_id, + error={"message": message}, + ) + ) + scenario_auto_results_created = True + + idx += len(batch_cells) + + metrics = None + if refresh_metrics_without_auto_results or scenario_auto_results_created: + metrics = await refresh_metrics(run_id, scenario_id) + + async with processed_lock: + processed.append( + ProcessedScenario( + scenario=scenario, + results=results, + metrics=metrics, + has_pending=scenario_has_pending, + has_errors=scenario_has_errors, + auto_results_created=scenario_auto_results_created, + ) + ) + + await asyncio.gather(*(_process_one(item) for item in source_items)) + + logger.info( + "[SLICE] Complete", + run_id=str(run_id), + processed=len(processed), + has_errors=any(item.has_errors for item in processed), + ) + + if processed and ( + refresh_metrics_without_auto_results + or any(item.auto_results_created for item in processed) + ): + await refresh_metrics(run_id, None) + + return processed + + +async def _execute_with_retry( + *, + runner: Any, + requests: List[WorkflowExecutionRequest], + semaphore: Optional[asyncio.Semaphore], + max_retries: Optional[int], + retry_delay: Optional[float], +) -> List[WorkflowExecutionResult]: + attempts = max(1, (max_retries or 0) + 1) + delay = retry_delay or 0.0 + results: List[WorkflowExecutionResult] = await execute_workflow_batch( + runner=runner, + requests=requests, + semaphore=semaphore, + ) + for attempt in range(attempts - 1): + failed_indices = [ + i + for i, r in enumerate(results) + if r.error + or str(r.status) + in { + "failure", + "EvaluationStatus.FAILURE", + "errors", + "EvaluationStatus.ERRORS", + } + ] + if not failed_indices: + break + logger.warning( + "[RETRY] Retrying failed requests", + attempt=attempt + 1, + failed=len(failed_indices), + total=len(requests), + delay=delay, + ) + if delay > 0: + await asyncio.sleep(delay) + retried = await execute_workflow_batch( + runner=runner, + requests=[requests[i] for i in failed_indices], + semaphore=semaphore, + ) + for idx, result in zip(failed_indices, retried): + results[idx] = result + return results + + +def _step_by_key( + steps: List[EvaluationStep], + step_key: str, +) -> Optional[EvaluationStep]: + for step in steps: + if step.key == step_key: + return step + return None + + +def _initial_context_by_repeat( + *, + source_item: ResolvedSourceItem, + repeats: Optional[int], +) -> Dict[int, Dict[str, Any]]: + if not source_item.trace and not source_item.trace_id: + return {} + + trace = source_item.trace + trace_id = source_item.trace_id or _get_trace_id(trace) + root_span = _extract_root_span(trace) + span_id = source_item.span_id or _get_span_id(root_span) + outputs = source_item.outputs or _extract_outputs(trace) + if not trace_id: + return {} + + context = { + "trace": trace, + "trace_id": str(trace_id), + "span_id": span_id, + "outputs": outputs, + } + count = repeats or 1 + return {repeat_idx: context for repeat_idx in range(max(count, 1))} + + +def _next_runnable_batch( + *, + cells: List[PlannedCell], + start_idx: int, + step_key: str, +) -> List[PlannedCell]: + batch = [] + for cell in cells[start_idx:]: + if not cell.should_execute or cell.step_key != step_key: + break + batch.append(cell) + return batch + + +def _build_execution_request( + *, + cell: PlannedCell, + step: EvaluationStep, + source_item: ResolvedSourceItem, + revision: Any, + context_by_repeat: Dict[int, Dict[str, Any]], +) -> WorkflowExecutionRequest: + upstream = _upstream_for_cell( + cell=cell, + context_by_repeat=context_by_repeat, + ) + return WorkflowExecutionRequest( + step=step, + cell=cell, + source=source_item, + revision=_dump_revision(revision), + parameters=_revision_parameters(revision), + references={ + **(source_item.references or {}), + **(step.references or {}), + }, + links=upstream.get("links"), + upstream_trace=upstream.get("trace"), + upstream_outputs=upstream.get("outputs"), + ) + + +def _failed_cell(cell: PlannedCell, *, message: str) -> PlannedCell: + return cell.model_copy( + update={ + "status": EvaluationStatus.FAILURE, + "error": {"message": message}, + } + ) + + +def _dump_revision(revision: Any) -> Any: + if hasattr(revision, "model_dump"): + return revision.model_dump(mode="json", exclude_none=True) + return revision + + +def _revision_parameters(revision: Any) -> Optional[Any]: + data = getattr(revision, "data", None) + return getattr(data, "parameters", None) if data else None + + +def _upstream_for_cell( + *, + cell: PlannedCell, + context_by_repeat: Dict[int, Dict[str, Any]], +) -> Dict[str, Any]: + context = context_by_repeat.get(cell.repeat_idx) or context_by_repeat.get(0) or {} + if not context: + return {} + + trace_id = context.get("trace_id") + span_id = context.get("span_id") + links = ( + { + "invocation": { + "trace_id": trace_id, + "span_id": span_id, + } + } + if trace_id and span_id + else None + ) + return { + "links": links, + "trace": context.get("trace"), + "outputs": context.get("outputs"), + } + + +def _remember_context( + *, + cell: PlannedCell, + context_by_repeat: Dict[int, Dict[str, Any]], + trace: Optional[Any], + trace_id: str, + span_id: Optional[str], + outputs: Optional[Any], +) -> None: + context = { + "trace": trace, + "trace_id": trace_id, + "span_id": span_id, + "outputs": outputs, + } + context_by_repeat[cell.repeat_idx] = context + if cell.step_type == "invocation" and 0 not in context_by_repeat: + context_by_repeat[0] = context + + +def _extract_outputs(trace: Any) -> Optional[Any]: + root_span = _extract_root_span(trace) + if root_span is None: + return None + attributes = ( + root_span.get("attributes", {}) + if isinstance(root_span, dict) + else getattr(root_span, "attributes", {}) + ) + if hasattr(attributes, "model_dump"): + attributes = attributes.model_dump(mode="json", exclude_none=True) + return attributes.get("ag", {}).get("data", {}).get("outputs") + + +def _extract_root_span(trace: Any) -> Optional[Any]: + spans = ( + trace.get("spans") if isinstance(trace, dict) else getattr(trace, "spans", None) + ) + if not spans: + return None + root_span = next(iter(spans.values()), None) if isinstance(spans, dict) else None + if isinstance(root_span, list): + return None + return root_span + + +def _get_trace_id(trace: Any) -> Optional[str]: + if isinstance(trace, dict): + return trace.get("trace_id") + trace_id = getattr(trace, "trace_id", None) + return str(trace_id) if trace_id else None + + +def _get_span_id(root_span: Any) -> Optional[str]: + if root_span is None: + return None + span_id = ( + root_span.get("span_id") + if isinstance(root_span, dict) + else getattr(root_span, "span_id", None) + ) + return str(span_id) if span_id else None diff --git a/sdk/agenta/sdk/evaluations/runtime/topology.py b/sdk/agenta/sdk/evaluations/runtime/topology.py new file mode 100644 index 0000000000..d06d08512e --- /dev/null +++ b/sdk/agenta/sdk/evaluations/runtime/topology.py @@ -0,0 +1,145 @@ +from typing import Iterable, List, Optional + +from agenta.sdk.evaluations.runtime.models import EvaluationStep, TopologyDecision + + +def _has_reference(step: EvaluationStep, token: str) -> bool: + if any(token in str(key).lower() for key in step.references.keys()): + return True + return token in step.key.lower() + + +def _input_family(step: EvaluationStep) -> Optional[str]: + if _has_reference(step, "query"): + return "query" + if _has_reference(step, "testset"): + return "testset" + if _has_reference(step, "trace"): + return "trace" + if _has_reference(step, "testcase"): + return "testcase" + return None + + +def _steps_of_type( + steps: Iterable[EvaluationStep], step_type: str +) -> List[EvaluationStep]: + return [step for step in steps if step.type == step_type] + + +def classify_steps_topology( + *, + steps: List[EvaluationStep], + is_live: bool = False, + has_queries: bool = False, + has_testsets: bool = False, + has_traces: bool = False, + has_testcases: bool = False, + has_evaluators: bool = False, +) -> TopologyDecision: + input_steps = _steps_of_type(steps, "input") + application_steps = _steps_of_type(steps, "invocation") + evaluator_steps = _steps_of_type(steps, "annotation") + + input_families = { + family for family in (_input_family(step) for step in input_steps) if family + } + has_queries = has_queries or "query" in input_families + has_testsets = has_testsets or "testset" in input_families + has_traces = has_traces or "trace" in input_families + has_testcases = has_testcases or "testcase" in input_families + has_applications = bool(application_steps) + has_evaluators = has_evaluators or bool(evaluator_steps) + + if has_queries and has_testsets: + return TopologyDecision( + status="not_planned", + label="mixed query and testset sources", + reason="mixed query and testset source families in one run are not planned", + ) + + if is_live and has_testsets: + return TopologyDecision( + status="not_planned", + label="live testset evaluation", + reason="live testset evaluation is not a meaningful product shape", + ) + + if len(application_steps) > 1: + return TopologyDecision( + status="not_planned", + label="multiple application steps", + reason="A/B application comparisons should use separate evaluations", + ) + + if is_live and has_queries and has_evaluators and not has_applications: + return TopologyDecision( + status="supported", + label="live query -> evaluator", + reason="live query evaluator runs keep scheduler/windowing behavior", + dispatch="live_query", + ) + + if has_evaluators and not has_applications: + if has_testcases: + return TopologyDecision( + status="supported", + label="direct testcases -> evaluator", + reason="direct testcase batches are worker-dispatched", + dispatch="queue_testcases", + ) + if has_traces: + return TopologyDecision( + status="supported", + label="direct traces -> evaluator", + reason="direct trace batches are worker-dispatched", + dispatch="queue_traces", + ) + + if has_queries and has_applications: + return TopologyDecision( + status="potential", + label="query -> application", + reason=( + "query traces can seed application calls, but source trace links must " + "not be attached as application links because that would classify the " + "new application traces as annotations" + ), + ) + + if has_testsets and has_evaluators and not has_applications: + return TopologyDecision( + status="potential", + label="testset -> evaluator", + reason="non-queue testcase-only evaluator execution needs an explicit evaluator contract", + ) + + if has_queries and has_evaluators and not has_applications: + return TopologyDecision( + status="supported", + label="batch query -> evaluator", + reason="batch query evaluator runs are worker-dispatched", + dispatch="batch_query", + ) + + if has_testsets and has_applications and has_evaluators: + return TopologyDecision( + status="supported", + label="testset -> application -> evaluator", + reason="batch testset evaluation is worker-dispatched", + dispatch="batch_testset", + ) + + if has_testsets and has_applications and not has_evaluators and not has_queries: + return TopologyDecision( + status="supported", + label="testset -> application", + reason="batch inference / batch invocation is worker-dispatched", + dispatch="batch_invocation", + ) + + return TopologyDecision( + status="unsupported", + label="unsupported evaluation topology", + reason="no current worker dispatch path matches this evaluation graph", + ) diff --git a/sdk/tests/pytest/unit/test_evaluations_runtime.py b/sdk/tests/pytest/unit/test_evaluations_runtime.py new file mode 100644 index 0000000000..6c43ddca23 --- /dev/null +++ b/sdk/tests/pytest/unit/test_evaluations_runtime.py @@ -0,0 +1,1102 @@ +from types import SimpleNamespace +from uuid import uuid4 + +import pytest + +import agenta.sdk.evaluations.preview.evaluate as preview_evaluate +import agenta.sdk.evaluations.runtime.adapters as runtime_adapters +from agenta.sdk.evaluations.runtime.execution import execute_workflow_batch +from agenta.sdk.evaluations.runtime.models import ( + EvaluationStep, + PlannedCell, + ResolvedSourceItem, + ResultLogRequest, + ScenarioBinding, +) +from agenta.sdk.evaluations.runtime.planner import EvaluationPlanner +from agenta.sdk.evaluations.runtime.source_slice import ( + process_evaluation_source_slice, +) +from agenta.sdk.evaluations.runtime.topology import classify_steps_topology +from agenta.sdk.evaluations.runtime.models import WorkflowExecutionResult +from agenta.sdk.models.evaluations import EvaluationStatus + + +def test_sdk_runtime_planner_matches_split_repeat_rules(): + run_id = uuid4() + scenario_id = uuid4() + plan = EvaluationPlanner().plan( + run_id=run_id, + scenario_id=scenario_id, + source=ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + ), + steps=[ + EvaluationStep(key="testset-main", type="input"), + EvaluationStep(key="application-main", type="invocation"), + EvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + EvaluationStep(key="evaluator-human", type="annotation", origin="human"), + ], + repeats=3, + is_split=False, + ) + + assert [ + cell.repeat_idx for cell in plan.cells if cell.step_key == "application-main" + ] == [0] + assert [ + cell.status for cell in plan.cells if cell.step_key == "evaluator-human" + ] == [ + EvaluationStatus.PENDING, + EvaluationStatus.PENDING, + EvaluationStatus.PENDING, + ] + assert {(cell.step_key, cell.repeat_idx) for cell in plan.executable_cells} == { + ("application-main", 0), + ("evaluator-auto", 0), + ("evaluator-auto", 1), + ("evaluator-auto", 2), + } + + +def test_sdk_runtime_planner_handles_multiple_scenario_bindings(): + run_id = uuid4() + first_scenario_id = uuid4() + second_scenario_id = uuid4() + + plan = EvaluationPlanner().plan_bindings( + run_id=run_id, + bindings=[ + ScenarioBinding( + scenario_id=first_scenario_id, + source=ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + ), + ), + ScenarioBinding( + scenario_id=second_scenario_id, + source=ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + ), + ), + ], + steps=[ + EvaluationStep(key="testset-main", type="input"), + EvaluationStep(key="application-main", type="invocation"), + ], + repeats=2, + ) + + assert [cell.scenario_id for cell in plan.cells] == [ + first_scenario_id, + first_scenario_id, + first_scenario_id, + first_scenario_id, + second_scenario_id, + second_scenario_id, + second_scenario_id, + second_scenario_id, + ] + assert [ + (cell.step_key, cell.repeat_idx) + for cell in plan.cells + if cell.scenario_id == first_scenario_id + ] == [ + ("testset-main", 0), + ("testset-main", 1), + ("application-main", 0), + ("application-main", 1), + ] + + +def test_sdk_runtime_topology_classifier_matches_batch_inference_shape(): + decision = classify_steps_topology( + steps=[ + EvaluationStep( + key="testset-main", + type="input", + references={"testset_revision": {"id": str(uuid4())}}, + ), + EvaluationStep( + key="application-main", + type="invocation", + references={"application_revision": {"id": str(uuid4())}}, + ), + ], + ) + + assert decision.status == "supported" + assert decision.dispatch == "batch_invocation" + + +def test_sdk_runtime_topology_classifier_distinguishes_direct_testcases_from_testsets(): + decision = classify_steps_topology( + steps=[ + EvaluationStep(key="testcases", type="input"), + EvaluationStep(key="evaluator-human", type="annotation", origin="human"), + ], + has_testcases=True, + has_evaluators=True, + ) + + assert decision.status == "supported" + assert decision.dispatch == "queue_testcases" + + +def test_sdk_runtime_topology_classifier_keeps_deferred_query_to_application_shape(): + decision = classify_steps_topology( + steps=[ + EvaluationStep( + key="query-main", + type="input", + references={"query_revision": {"id": str(uuid4())}}, + ), + EvaluationStep( + key="application-main", + type="invocation", + references={"application_revision": {"id": str(uuid4())}}, + ), + ], + ) + + assert decision.status == "potential" + + +@pytest.mark.asyncio +async def test_sdk_workflow_batch_falls_back_to_single_execute(): + calls = [] + + class SingleRunner: + async def execute(self, request): + calls.append(request.cell.repeat_idx) + return WorkflowExecutionResult( + status=EvaluationStatus.SUCCESS, + trace_id=f"trace-{request.cell.repeat_idx}", + ) + + requests = [ + SimpleNamespace( + cell=SimpleNamespace(repeat_idx=0), + ), + SimpleNamespace( + cell=SimpleNamespace(repeat_idx=1), + ), + ] + + results = await execute_workflow_batch( + runner=SingleRunner(), + requests=requests, + ) + + assert calls == [0, 1] + assert [result.trace_id for result in results] == ["trace-0", "trace-1"] + + +@pytest.mark.asyncio +async def test_sdk_source_slice_batches_runnable_cells(): + run_id = uuid4() + scenario_id = uuid4() + logged = [] + + class BatchRunner: + def __init__(self): + self.requests = [] + + async def execute_batch(self, requests): + self.requests.append(requests) + return [ + WorkflowExecutionResult( + status=EvaluationStatus.SUCCESS, + trace_id=f"trace-{request.cell.repeat_idx}", + span_id=f"span-{request.cell.repeat_idx}", + ) + for request in requests + ] + + class Logger: + async def log(self, request): + logged.append((request.cell.step_key, request.cell.repeat_idx)) + return SimpleNamespace(id=uuid4()) + + async def create_scenario(run_id): + return SimpleNamespace(id=scenario_id) + + async def refresh_metrics(run_id, scenario_id): + return SimpleNamespace(id=uuid4()) + + runner = BatchRunner() + + await process_evaluation_source_slice( + run_id=run_id, + source_items=[ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + inputs={"prompt": "hello"}, + ) + ], + steps=[ + EvaluationStep(key="testset-main", type="input"), + EvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + ], + repeats=3, + create_scenario=create_scenario, + result_logger=Logger(), + refresh_metrics=refresh_metrics, + runners={"evaluator-auto": runner}, + revisions={"evaluator-auto": {"id": "revision"}}, + ) + + assert len(runner.requests) == 1 + assert [request.cell.repeat_idx for request in runner.requests[0]] == [0, 1, 2] + assert logged == [ + ("testset-main", 0), + ("testset-main", 1), + ("testset-main", 2), + ("evaluator-auto", 0), + ("evaluator-auto", 1), + ("evaluator-auto", 2), + ] + + +@pytest.mark.asyncio +async def test_sdk_source_slice_marks_short_runner_batch_as_error(): + run_id = uuid4() + scenario_id = uuid4() + logged = [] + + class ShortRunner: + async def execute_batch(self, requests): + return [ + WorkflowExecutionResult( + status=EvaluationStatus.SUCCESS, + trace_id="trace-0", + span_id="span-0", + ) + ] + + class Logger: + async def log(self, request): + logged.append(request) + return SimpleNamespace(id=uuid4()) + + async def create_scenario(run_id): + return SimpleNamespace(id=scenario_id) + + async def refresh_metrics(run_id, scenario_id): + return SimpleNamespace(id=uuid4()) + + processed = await process_evaluation_source_slice( + run_id=run_id, + source_items=[ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + inputs={"prompt": "hello"}, + ) + ], + steps=[ + EvaluationStep(key="testset-main", type="input"), + EvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + ], + repeats=2, + create_scenario=create_scenario, + result_logger=Logger(), + refresh_metrics=refresh_metrics, + runners={"evaluator-auto": ShortRunner()}, + revisions={"evaluator-auto": {"id": "revision"}}, + ) + + assert processed[0].has_errors is True + failed_log = logged[-1] + assert failed_log.cell.step_key == "evaluator-auto" + assert failed_log.cell.repeat_idx == 1 + assert failed_log.cell.status == EvaluationStatus.FAILURE + assert failed_log.error == { + "message": ( + "Runner for evaluator-auto returned 1 execution(s) for 2 planned cell(s)." + ) + } + + +@pytest.mark.asyncio +async def test_sdk_source_slice_marks_missing_runner_as_error(): + run_id = uuid4() + scenario_id = uuid4() + logged = [] + + class Logger: + async def log(self, request): + logged.append(request) + return SimpleNamespace(id=uuid4()) + + async def create_scenario(run_id): + return SimpleNamespace(id=scenario_id) + + async def refresh_metrics(run_id, scenario_id): + return SimpleNamespace(id=uuid4()) + + processed = await process_evaluation_source_slice( + run_id=run_id, + source_items=[ + ResolvedSourceItem( + kind="trace", + step_key="query-main", + trace_id="query-trace", + ) + ], + steps=[ + EvaluationStep(key="query-main", type="input"), + EvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + ], + repeats=1, + create_scenario=create_scenario, + result_logger=Logger(), + refresh_metrics=refresh_metrics, + runners={}, + revisions={}, + ) + + assert processed[0].has_errors is True + assert [(item.cell.step_key, item.error) for item in logged] == [ + ("query-main", None), + ( + "evaluator-auto", + {"message": "Missing runner or revision for evaluator-auto"}, + ), + ] + + +@pytest.mark.asyncio +async def test_sdk_source_slice_can_defer_manual_results_without_metric_refresh(): + run_id = uuid4() + scenario_id = uuid4() + logged = [] + + class Logger: + async def log(self, request): + logged.append(request.cell.step_key) + return SimpleNamespace(id=uuid4()) + + refresh_metrics = pytest.fail + + async def create_scenario(run_id): + return SimpleNamespace(id=scenario_id) + + processed = await process_evaluation_source_slice( + run_id=run_id, + source_items=[ + ResolvedSourceItem( + kind="trace", + step_key="query-main", + trace_id="query-trace", + ) + ], + steps=[ + EvaluationStep(key="query-main", type="input"), + EvaluationStep(key="evaluator-human", type="annotation", origin="human"), + ], + repeats=1, + create_scenario=create_scenario, + result_logger=Logger(), + refresh_metrics=refresh_metrics, + runners={}, + revisions={}, + log_pending=False, + refresh_metrics_without_auto_results=False, + ) + + assert processed[0].has_pending is True + assert processed[0].auto_results_created is False + assert logged == ["query-main"] + + +@pytest.mark.asyncio +async def test_sdk_source_slice_links_evaluators_to_application_traces(): + run_id = uuid4() + scenario_id = uuid4() + evaluator_requests = [] + + class ApplicationRunner: + async def execute_batch(self, requests): + return [ + WorkflowExecutionResult( + status=EvaluationStatus.SUCCESS, + trace_id="app-trace", + span_id="app-span", + outputs={"answer": "hello"}, + trace={ + "trace_id": "app-trace", + "spans": { + "root": { + "span_id": "app-span", + "attributes": { + "ag": { + "data": { + "outputs": {"answer": "hello"}, + } + } + }, + } + }, + }, + ) + for _ in requests + ] + + class EvaluatorRunner: + async def execute_batch(self, requests): + evaluator_requests.extend(requests) + return [ + WorkflowExecutionResult( + status=EvaluationStatus.SUCCESS, + trace_id=f"eval-trace-{request.cell.repeat_idx}", + ) + for request in requests + ] + + class Logger: + async def log(self, request): + return SimpleNamespace(id=uuid4()) + + async def create_scenario(run_id): + return SimpleNamespace(id=scenario_id) + + async def refresh_metrics(run_id, scenario_id): + return SimpleNamespace(id=uuid4()) + + await process_evaluation_source_slice( + run_id=run_id, + source_items=[ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + inputs={"prompt": "hello"}, + ) + ], + steps=[ + EvaluationStep(key="testset-main", type="input"), + EvaluationStep(key="application-main", type="invocation"), + EvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + ], + repeats=2, + create_scenario=create_scenario, + result_logger=Logger(), + refresh_metrics=refresh_metrics, + runners={ + "application-main": ApplicationRunner(), + "evaluator-auto": EvaluatorRunner(), + }, + revisions={ + "application-main": {"id": "application-revision"}, + "evaluator-auto": {"id": "evaluator-revision"}, + }, + is_split=False, + ) + + assert [request.cell.repeat_idx for request in evaluator_requests] == [0, 1] + assert [request.links for request in evaluator_requests] == [ + {"invocation": {"trace_id": "app-trace", "span_id": "app-span"}}, + {"invocation": {"trace_id": "app-trace", "span_id": "app-span"}}, + ] + assert [request.upstream_outputs for request in evaluator_requests] == [ + {"answer": "hello"}, + {"answer": "hello"}, + ] + + +@pytest.mark.asyncio +async def test_sdk_result_logger_adapter_preserves_repeat_idx(monkeypatch): + calls = [] + + async def fake_log_result(**kwargs): + calls.append(kwargs) + return {"id": "result"} + + monkeypatch.setattr(runtime_adapters, "alog_result", fake_log_result) + cell = PlannedCell( + run_id=uuid4(), + scenario_id=uuid4(), + step_key="evaluator-auto", + step_type="annotation", + origin="auto", + repeat_idx=2, + status=EvaluationStatus.SUCCESS, + testcase_id=uuid4(), + ) + + result = await runtime_adapters.SdkResultLogger().log( + ResultLogRequest( + cell=cell, + trace_id="trace-repeat", + ) + ) + + assert result == {"id": "result"} + assert calls == [ + { + "run_id": cell.run_id, + "scenario_id": cell.scenario_id, + "step_key": "evaluator-auto", + "repeat_idx": 2, + "trace_id": "trace-repeat", + "testcase_id": cell.testcase_id, + "error": None, + } + ] + + +@pytest.mark.asyncio +async def test_sdk_preview_evaluate_logs_repeat_aware_results(monkeypatch): + run_id = uuid4() + scenario_id = uuid4() + testset_id = uuid4() + testset_variant_id = uuid4() + testset_revision_id = uuid4() + application_revision_id = uuid4() + evaluator_revision_id = uuid4() + testcase_id = uuid4() + + testcase = SimpleNamespace( + id=testcase_id, + data={"prompt": "hello"}, + model_dump=lambda **kwargs: { + "id": str(testcase_id), + "data": {"prompt": "hello"}, + }, + ) + testset_revision = SimpleNamespace( + id=testset_revision_id, + testset_id=testset_id, + testset_variant_id=testset_variant_id, + slug="main", + version="1", + data=SimpleNamespace(testcases=[testcase]), + ) + application_revision = SimpleNamespace( + id=application_revision_id, + application_id=uuid4(), + application_variant_id=uuid4(), + slug="app", + version="1", + data=SimpleNamespace(parameters={"temperature": 0}), + model_dump=lambda **kwargs: {"id": str(application_revision_id)}, + ) + evaluator_revision = SimpleNamespace( + id=evaluator_revision_id, + evaluator_id=uuid4(), + evaluator_variant_id=uuid4(), + slug="eval", + version="1", + data=SimpleNamespace(parameters={"threshold": 1}), + model_dump=lambda **kwargs: {"id": str(evaluator_revision_id)}, + ) + + async def fake_retrieve_testset(**kwargs): + return testset_revision + + async def fake_retrieve_application(**kwargs): + return application_revision + + async def fake_retrieve_evaluator(**kwargs): + return evaluator_revision + + async def fake_create_run(**kwargs): + return SimpleNamespace(id=run_id) + + async def fake_add_scenario(**kwargs): + return SimpleNamespace(id=scenario_id) + + logged_results = [] + + async def fake_log_result(**kwargs): + logged_results.append(kwargs) + return SimpleNamespace(id=uuid4()) + + async def fake_invoke_application(**kwargs): + return SimpleNamespace( + data=SimpleNamespace(), + trace_id="app-trace", + span_id="app-span", + ) + + evaluator_trace_ids = iter(["eval-trace-0", "eval-trace-1"]) + + async def fake_invoke_evaluator(**kwargs): + return SimpleNamespace( + data=SimpleNamespace(), + trace_id=next(evaluator_trace_ids), + span_id="eval-span", + ) + + async def fake_fetch_trace_data(trace_id, **kwargs): + return { + "spans": { + "root": { + "attributes": { + "ag": { + "data": { + "inputs": {"prompt": "hello"}, + "outputs": {"answer": trace_id}, + } + } + } + } + } + } + + async def fake_compute_metrics(**kwargs): + return SimpleNamespace(id=uuid4()) + + async def fake_close_run(**kwargs): + return SimpleNamespace(id=run_id) + + async def fake_get_url(**kwargs): + return "" + + monkeypatch.setattr(preview_evaluate, "aretrieve_testset", fake_retrieve_testset) + monkeypatch.setattr( + preview_evaluate, "aretrieve_application", fake_retrieve_application + ) + monkeypatch.setattr( + preview_evaluate, "aretrieve_evaluator", fake_retrieve_evaluator + ) + monkeypatch.setattr(preview_evaluate, "acreate_run", fake_create_run) + monkeypatch.setattr(preview_evaluate, "aadd_scenario", fake_add_scenario) + monkeypatch.setattr(runtime_adapters, "alog_result", fake_log_result) + monkeypatch.setattr(runtime_adapters, "invoke_application", fake_invoke_application) + monkeypatch.setattr(runtime_adapters, "invoke_evaluator", fake_invoke_evaluator) + monkeypatch.setattr(runtime_adapters, "fetch_trace_data", fake_fetch_trace_data) + monkeypatch.setattr(preview_evaluate, "acompute_metrics", fake_compute_metrics) + monkeypatch.setattr(preview_evaluate, "aclose_run", fake_close_run) + monkeypatch.setattr(preview_evaluate, "aget_url", fake_get_url) + + result = await preview_evaluate.aevaluate( + testsets={testset_revision_id: "custom"}, + applications={application_revision_id: "custom"}, + evaluators={evaluator_revision_id: "auto"}, + repeats=2, + ) + + assert result["run"].id == run_id + assert [ + (logged_result["step_key"], logged_result["repeat_idx"]) + for logged_result in logged_results + ] == [ + ("testset-main", 0), + ("testset-main", 1), + ("application-app", 0), + ("evaluator-eval", 0), + ("evaluator-eval", 1), + ] + assert [ + logged_result["trace_id"] + for logged_result in logged_results + if logged_result["step_key"] == "evaluator-eval" + ] == ["eval-trace-0", "eval-trace-1"] + + +# --------------------------------------------------------------------------- +# Concurrency and retry +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_sdk_source_slice_runs_scenarios_concurrently_up_to_batch_size(): + """batch_size=2 with 4 scenarios: at most 2 invoke_workflow calls in flight at once.""" + import asyncio + + run_id = uuid4() + in_flight = 0 + peak = 0 + + class ConcurrentRunner: + async def execute_batch(self, requests, semaphore=None): + results = [] + for request in requests: + + async def _one(req): + nonlocal in_flight, peak + in_flight += 1 + peak = max(peak, in_flight) + await asyncio.sleep(0) + in_flight -= 1 + return WorkflowExecutionResult( + status=EvaluationStatus.SUCCESS, + trace_id=f"trace-{req.cell.repeat_idx}", + ) + + if semaphore is not None: + async with semaphore: + results.append(await _one(request)) + else: + results.append(await _one(request)) + return results + + class Logger: + async def log(self, request): + return SimpleNamespace(id=uuid4()) + + scenarios_created = [] + + async def create_scenario(run_id): + sid = uuid4() + scenarios_created.append(sid) + return SimpleNamespace(id=sid) + + async def refresh_metrics(run_id, scenario_id): + return None + + source_items = [ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + inputs={"x": str(i)}, + ) + for i in range(4) + ] + + await process_evaluation_source_slice( + run_id=run_id, + source_items=source_items, + steps=[ + EvaluationStep(key="testset-main", type="input"), + EvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + ], + repeats=1, + create_scenario=create_scenario, + result_logger=Logger(), + refresh_metrics=refresh_metrics, + runners={"evaluator-auto": ConcurrentRunner()}, + revisions={"evaluator-auto": {"id": "rev"}}, + batch_size=2, + ) + + assert len(scenarios_created) == 4 + assert peak <= 2 + + +@pytest.mark.asyncio +async def test_sdk_source_slice_semaphore_shared_across_repeats(): + """batch_size=2 with 1 scenario and 4 repeats: peak concurrency stays ≤ 2.""" + import asyncio + + run_id = uuid4() + scenario_id = uuid4() + in_flight = 0 + peak = 0 + + class ConcurrentRunner: + async def execute_batch(self, requests, semaphore=None): + async def _one(req): + nonlocal in_flight, peak + in_flight += 1 + peak = max(peak, in_flight) + await asyncio.sleep(0) + in_flight -= 1 + return WorkflowExecutionResult( + status=EvaluationStatus.SUCCESS, + trace_id=f"trace-{req.cell.repeat_idx}", + ) + + if semaphore is not None: + results = [] + for req in requests: + async with semaphore: + results.append(await _one(req)) + return results + return [await _one(req) for req in requests] + + class Logger: + async def log(self, request): + return SimpleNamespace(id=uuid4()) + + async def create_scenario(run_id): + return SimpleNamespace(id=scenario_id) + + async def refresh_metrics(run_id, scenario_id): + return None + + await process_evaluation_source_slice( + run_id=run_id, + source_items=[ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + inputs={"x": "0"}, + ) + ], + steps=[ + EvaluationStep(key="testset-main", type="input"), + EvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + ], + repeats=4, + create_scenario=create_scenario, + result_logger=Logger(), + refresh_metrics=refresh_metrics, + runners={"evaluator-auto": ConcurrentRunner()}, + revisions={"evaluator-auto": {"id": "rev"}}, + batch_size=2, + ) + + assert peak <= 2 + + +@pytest.mark.asyncio +async def test_sdk_source_slice_no_batch_size_runs_all_concurrently(): + """When batch_size=None the semaphore is absent and all scenarios run freely.""" + run_id = uuid4() + scenarios_created = [] + + class Runner: + async def execute_batch(self, requests, semaphore=None): + return [ + WorkflowExecutionResult( + status=EvaluationStatus.SUCCESS, + trace_id="t", + ) + for _ in requests + ] + + class Logger: + async def log(self, request): + return SimpleNamespace(id=uuid4()) + + async def create_scenario(run_id): + sid = uuid4() + scenarios_created.append(sid) + return SimpleNamespace(id=sid) + + async def refresh_metrics(run_id, scenario_id): + return None + + source_items = [ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + ) + for _ in range(5) + ] + + processed = await process_evaluation_source_slice( + run_id=run_id, + source_items=source_items, + steps=[ + EvaluationStep(key="testset-main", type="input"), + EvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + ], + repeats=1, + create_scenario=create_scenario, + result_logger=Logger(), + refresh_metrics=refresh_metrics, + runners={"evaluator-auto": Runner()}, + revisions={"evaluator-auto": {"id": "rev"}}, + batch_size=None, + ) + + assert len(processed) == 5 + + +@pytest.mark.asyncio +async def test_sdk_source_slice_retries_failed_cells_and_succeeds(): + """max_retries=1: first attempt fails, retry succeeds; result is success.""" + run_id = uuid4() + scenario_id = uuid4() + call_count = 0 + + class FlakyRunner: + async def execute_batch(self, requests, semaphore=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [ + WorkflowExecutionResult( + status=EvaluationStatus.FAILURE, + error={"message": "transient"}, + ) + for _ in requests + ] + return [ + WorkflowExecutionResult( + status=EvaluationStatus.SUCCESS, + trace_id="recovered", + ) + for _ in requests + ] + + logged = [] + + class Logger: + async def log(self, request): + logged.append((request.cell.step_key, request.trace_id, request.error)) + return SimpleNamespace(id=uuid4()) + + async def create_scenario(run_id): + return SimpleNamespace(id=scenario_id) + + async def refresh_metrics(run_id, scenario_id): + return None + + processed = await process_evaluation_source_slice( + run_id=run_id, + source_items=[ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + inputs={"x": "1"}, + ) + ], + steps=[ + EvaluationStep(key="testset-main", type="input"), + EvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + ], + repeats=1, + create_scenario=create_scenario, + result_logger=Logger(), + refresh_metrics=refresh_metrics, + runners={"evaluator-auto": FlakyRunner()}, + revisions={"evaluator-auto": {"id": "rev"}}, + max_retries=1, + ) + + assert call_count == 2 + assert processed[0].has_errors is False + eval_log = next(entry for entry in logged if entry[0] == "evaluator-auto") + assert eval_log[1] == "recovered" + assert eval_log[2] is None + + +@pytest.mark.asyncio +async def test_sdk_source_slice_exhausts_retries_and_marks_error(): + """max_retries=1 with persistent failure: result is still an error.""" + run_id = uuid4() + scenario_id = uuid4() + call_count = 0 + + class AlwaysFailRunner: + async def execute_batch(self, requests, semaphore=None): + nonlocal call_count + call_count += 1 + return [ + WorkflowExecutionResult( + status=EvaluationStatus.FAILURE, + error={"message": "always fails"}, + ) + for _ in requests + ] + + class Logger: + async def log(self, request): + return SimpleNamespace(id=uuid4()) + + async def create_scenario(run_id): + return SimpleNamespace(id=scenario_id) + + async def refresh_metrics(run_id, scenario_id): + return None + + processed = await process_evaluation_source_slice( + run_id=run_id, + source_items=[ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + inputs={"x": "1"}, + ) + ], + steps=[ + EvaluationStep(key="testset-main", type="input"), + EvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + ], + repeats=1, + create_scenario=create_scenario, + result_logger=Logger(), + refresh_metrics=refresh_metrics, + runners={"evaluator-auto": AlwaysFailRunner()}, + revisions={"evaluator-auto": {"id": "rev"}}, + max_retries=1, + ) + + assert call_count == 2 + assert processed[0].has_errors is True + + +@pytest.mark.asyncio +async def test_sdk_source_slice_retries_only_failed_cells_in_batch(): + """With repeats=2, only the failing repeat is retried, not the successful one.""" + run_id = uuid4() + scenario_id = uuid4() + attempt_by_repeat: dict = {} + + class SelectiveFlakyRunner: + async def execute_batch(self, requests, semaphore=None): + results = [] + for req in requests: + idx = req.cell.repeat_idx + attempt_by_repeat[idx] = attempt_by_repeat.get(idx, 0) + 1 + if idx == 1 and attempt_by_repeat[idx] == 1: + results.append( + WorkflowExecutionResult( + status=EvaluationStatus.FAILURE, + error={"message": "fail repeat 1 first time"}, + ) + ) + else: + results.append( + WorkflowExecutionResult( + status=EvaluationStatus.SUCCESS, + trace_id=f"trace-{idx}", + ) + ) + return results + + class Logger: + async def log(self, request): + return SimpleNamespace(id=uuid4()) + + async def create_scenario(run_id): + return SimpleNamespace(id=scenario_id) + + async def refresh_metrics(run_id, scenario_id): + return None + + processed = await process_evaluation_source_slice( + run_id=run_id, + source_items=[ + ResolvedSourceItem( + kind="testcase", + step_key="testset-main", + testcase_id=uuid4(), + inputs={"x": "1"}, + ) + ], + steps=[ + EvaluationStep(key="testset-main", type="input"), + EvaluationStep(key="evaluator-auto", type="annotation", origin="auto"), + ], + repeats=2, + create_scenario=create_scenario, + result_logger=Logger(), + refresh_metrics=refresh_metrics, + runners={"evaluator-auto": SelectiveFlakyRunner()}, + revisions={"evaluator-auto": {"id": "rev"}}, + max_retries=1, + ) + + assert processed[0].has_errors is False + assert attempt_by_repeat[0] == 1 + assert attempt_by_repeat[1] == 2 diff --git a/sdks/python/agenta/sdk/engines/running/handlers.py b/sdks/python/agenta/sdk/engines/running/handlers.py index 70891ddbff..5bdb0d699c 100644 --- a/sdks/python/agenta/sdk/engines/running/handlers.py +++ b/sdks/python/agenta/sdk/engines/running/handlers.py @@ -316,13 +316,9 @@ def auto_exact_match_v0( Returns: Evaluation result with success flag (True for match, False for mismatch) """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) - - if "correct_answer_key" not in parameters: - raise MissingConfigurationParameterV0Error(path="correct_answer_key") + parameters = parameters or {} - correct_answer_key = str(parameters["correct_answer_key"]) + correct_answer_key = str(parameters.get("correct_answer_key", "correct_answer")) if inputs is None or not isinstance(inputs, dict): raise InvalidInputsV0Error(expected="dict", got=inputs) @@ -361,8 +357,7 @@ def auto_regex_test_v0( Returns: Evaluation result with success flag """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "regex_pattern" not in parameters: raise MissingConfigurationParameterV0Error(path="regex_pattern") @@ -419,18 +414,14 @@ def field_match_test_v0( Returns: Evaluation result with success flag """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "json_field" not in parameters: raise MissingConfigurationParameterV0Error(path="json_field") json_field = str(parameters["json_field"]) - if "correct_answer_key" not in parameters: - raise MissingConfigurationParameterV0Error(path="correct_answer_key") - - correct_answer_key = str(parameters["correct_answer_key"]) + correct_answer_key = str(parameters.get("correct_answer_key", "correct_answer")) if inputs is None or not isinstance(inputs, dict): raise InvalidInputsV0Error(expected="dict", got=inputs) @@ -515,8 +506,7 @@ def json_multi_field_match_v0( Dict with per-field scores and aggregate_score, e.g.: {"name": 1.0, "email": 0.0, "aggregate_score": 0.5} """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "fields" not in parameters: raise MissingConfigurationParameterV0Error(path="fields") @@ -530,10 +520,7 @@ def json_multi_field_match_v0( got=fields, ) - if "correct_answer_key" not in parameters: - raise MissingConfigurationParameterV0Error(path="correct_answer_key") - - correct_answer_key = str(parameters["correct_answer_key"]) + correct_answer_key = str(parameters.get("correct_answer_key", "correct_answer")) if inputs is None or not isinstance(inputs, dict): raise InvalidInputsV0Error(expected="dict", got=inputs) @@ -626,8 +613,7 @@ async def auto_webhook_test_v0( Returns: Evaluation result with score from the webhook """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "webhook_url" not in parameters: raise MissingConfigurationParameterV0Error(path="webhook_url") @@ -745,8 +731,7 @@ async def auto_custom_code_run_v0( Returns: Evaluation result with score from the custom code """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "code" not in parameters: raise MissingConfigurationParameterV0Error(path="code") @@ -812,10 +797,7 @@ def _run_v2() -> Any: ) from e def _run_v1() -> Any: - if "correct_answer_key" not in parameters: - raise MissingConfigurationParameterV0Error(path="correct_answer_key") - - correct_answer_key = str(parameters["correct_answer_key"]) + correct_answer_key = str(parameters.get("correct_answer_key", "correct_answer")) if inputs is None or not isinstance(inputs, dict): raise InvalidInputsV0Error(expected="dict", got=inputs) @@ -877,8 +859,7 @@ async def auto_ai_critique_v0( Returns: Evaluation result with score from the AI """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} correct_answer_key = parameters.get("correct_answer_key") @@ -1088,8 +1069,7 @@ def auto_starts_with_v0( Returns: Evaluation result with success flag """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "prefix" not in parameters: raise MissingConfigurationParameterV0Error(path="prefix") @@ -1137,8 +1117,7 @@ def auto_ends_with_v0( Returns: Evaluation result with success flag """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "suffix" not in parameters: raise MissingConfigurationParameterV0Error(path="suffix") @@ -1186,8 +1165,7 @@ def auto_contains_v0( Returns: Evaluation result with success flag """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "substring" not in parameters: raise MissingConfigurationParameterV0Error(path="substring") @@ -1235,8 +1213,7 @@ def auto_contains_any_v0( Returns: Evaluation result with success flag """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "substrings" not in parameters: raise MissingConfigurationParameterV0Error(path="substrings") @@ -1293,8 +1270,7 @@ def auto_contains_all_v0( Returns: Evaluation result with success flag """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "substrings" not in parameters: raise MissingConfigurationParameterV0Error(path="substrings") @@ -1393,13 +1369,9 @@ def auto_json_diff_v0( Returns: Evaluation result with score only (no diff explanation) """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) - - if "correct_answer_key" not in parameters: - raise MissingConfigurationParameterV0Error(path="correct_answer_key") + parameters = parameters or {} - correct_answer_key = str(parameters["correct_answer_key"]) + correct_answer_key = str(parameters.get("correct_answer_key", "correct_answer")) if inputs is None or not isinstance(inputs, dict): raise InvalidInputsV0Error(expected="dict", got=inputs) @@ -1485,13 +1457,9 @@ def auto_levenshtein_distance_v0( Dictionary with normalized similarity score (0 to 1), or error message if evaluation fails. """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) - - if "correct_answer_key" not in parameters: - raise MissingConfigurationParameterV0Error(path="correct_answer_key") + parameters = parameters or {} - correct_answer_key = str(parameters["correct_answer_key"]) + correct_answer_key = str(parameters.get("correct_answer_key", "correct_answer")) case_sensitive = parameters.get("case_sensitive", True) is True @@ -1590,13 +1558,9 @@ def auto_similarity_match_v0( Returns: Evaluation result with similarity score """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} - if "correct_answer_key" not in parameters: - raise MissingConfigurationParameterV0Error(path="correct_answer_key") - - correct_answer_key = str(parameters["correct_answer_key"]) + correct_answer_key = str(parameters.get("correct_answer_key", "correct_answer")) case_sensitive = parameters.get("case_sensitive", True) is True @@ -1683,13 +1647,9 @@ async def auto_semantic_similarity_v0( Returns: Evaluation result with cosine similarity score """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) - - if "correct_answer_key" not in parameters: - raise MissingConfigurationParameterV0Error(path="correct_answer_key") + parameters = parameters or {} - correct_answer_key = str(parameters["correct_answer_key"]) + correct_answer_key = str(parameters.get("correct_answer_key", "correct_answer")) embedding_model = parameters.get("embedding_model", "text-embedding-3-small") @@ -2117,7 +2077,7 @@ async def completion_v0( required_keys = set(config.prompt.input_keys) provided_keys = set(_variables.keys()) - if required_keys != provided_keys: + if not required_keys.issubset(provided_keys): raise InvalidInputsV0Error( expected=sorted(required_keys), got=sorted(provided_keys), @@ -2191,7 +2151,7 @@ async def chat_v0( required_keys = set(config.prompt.input_keys) - {"messages"} provided_keys = set(_variables.keys()) - if required_keys != provided_keys: + if not required_keys.issubset(provided_keys): raise InvalidInputsV0Error( expected=sorted(required_keys), got=sorted(provided_keys), @@ -2830,8 +2790,7 @@ async def code_v0( ``{"success": bool}`` when code returns a bool. The raw dict / str when code returns one of those. """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "code" not in parameters: raise MissingConfigurationParameterV0Error(path="code") @@ -2944,8 +2903,7 @@ async def match_v0( Returns: {key: result_node, ..., "score": float, "success": bool} — flat result dict """ - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} if "matchers" not in parameters: raise MissingConfigurationParameterV0Error(path="matchers") @@ -3434,8 +3392,7 @@ async def llm_v0( from agenta.sdk.engines.running.errors import LLMUnavailableV0Error # --- Validate parameters - if parameters is None or not isinstance(parameters, dict): - raise InvalidConfigurationParametersV0Error(expected="dict", got=parameters) + parameters = parameters or {} llms = parameters.get("llms") if not llms or not isinstance(llms, list): diff --git a/sdks/python/agenta/sdk/evaluations/preview/evaluate.py b/sdks/python/agenta/sdk/evaluations/preview/evaluate.py index 8747b294fa..28e885ac18 100644 --- a/sdks/python/agenta/sdk/evaluations/preview/evaluate.py +++ b/sdks/python/agenta/sdk/evaluations/preview/evaluate.py @@ -10,18 +10,13 @@ Target, SimpleEvaluationData, ) -from agenta.sdk.models.shared import Link, Reference +from agenta.sdk.models.shared import Reference from agenta.sdk.models.workflows import ( ApplicationRevision, EvaluatorRevision, - WorkflowServiceRequestData, - ApplicationServiceRequest, - EvaluatorServiceRequest, ) from agenta.sdk.models.testsets import TestsetRevision -from agenta.sdk.evaluations.preview.utils import fetch_trace_data - from agenta.sdk.managers.testsets import ( acreate as acreate_testset, aretrieve as aretrieve_testset, @@ -42,18 +37,19 @@ from agenta.sdk.evaluations.scenarios import ( acreate as aadd_scenario, ) -from agenta.sdk.evaluations.results import ( - acreate as alog_result, -) from agenta.sdk.evaluations.metrics import ( arefresh as acompute_metrics, ) +from agenta.sdk.evaluations.runtime.models import EvaluationStep, ResolvedSourceItem +from agenta.sdk.evaluations.runtime.source_slice import process_evaluation_source_slice +from agenta.sdk.evaluations.runtime.adapters import ( + SdkLocalApplicationRunner, + SdkLocalEvaluatorRunner, + SdkResultLogger, + SdkTraceLoader, +) -from agenta.sdk.decorators.running import ( - invoke_application, - invoke_evaluator, -) from agenta.sdk.utils.logging import get_module_logger @@ -456,14 +452,25 @@ async def aevaluate( ) scenarios = list() - metrics = dict() + async def create_scenario(run_id: UUID): + return await aadd_scenario(run_id=run_id) + + async def refresh_metrics(run_id: UUID, scenario_id: Optional[UUID]): + if scenario_id: + return await acompute_metrics(run_id=run_id, scenario_id=scenario_id) + return await acompute_metrics(run_id=run_id) + + result_logger = SdkResultLogger() + trace_loader = SdkTraceLoader(max_retries=30, delay=1.0) + for testset_revision in testset_revisions.values(): if not testset_revision.data or not testset_revision.data.testcases: continue testcases = testset_revision.data.testcases + input_step_key = "testset-" + testset_revision.slug # type: ignore print( f"{UNICODE['next']}" @@ -474,352 +481,144 @@ async def aevaluate( f" testset_id={str(testset_revision.testset_id)}", ) - for testcase_idx, testcase in enumerate(testcases): - print( - f"{UNICODE['pipe']}" - f"{UNICODE['pipe']}" - f"{UNICODE['skip']}" - f"{UNICODE['skip']}" - f"{UNICODE['skip']}" - "-----------------------" - "--------------------------------------" - ) - - print( - f"{UNICODE['pipe']}" - f"{UNICODE['next' if testcase_idx < len(testcases) - 1 else 'last']}" - f"{UNICODE['here']}" - f"{UNICODE['skip']}" - f"{UNICODE['skip']}" - f"testcase_id={str(testcase.id)}", - ) - - scenario = await aadd_scenario( - run_id=run.id, - ) - - print( - f"{UNICODE['pipe']}" - f"{UNICODE['pipe' if testcase_idx < len(testcases) - 1 else 'skip']}" - f"{UNICODE['next']}" - f"{UNICODE['here']}" - f"{UNICODE['skip']}" - f"scenario_id={str(scenario.id)}", - ) - - results = dict() - - result = await alog_result( - run_id=run.id, - scenario_id=scenario.id, - step_key="testset-" + testset_revision.slug, # type: ignore - testcase_id=testcase.id, - ) - - print( - f"{UNICODE['pipe']}" - f"{UNICODE['pipe' if testcase_idx < len(testcases) - 1 else 'skip']}" - f"{UNICODE['pipe']}" - f"{UNICODE['next']}" - f"{UNICODE['here']}" - f" result_id={str(result.id)} (testcase)", - ) - - results[testset_revision.slug] = result - - _testcase = testcase.model_dump( - mode="json", - exclude_none=True, - ) # type: ignore - inputs = testcase.data - if isinstance(inputs, dict): - if "testcase_dedup_id" in inputs: - del inputs["testcase_dedup_id"] - - for application_revision in application_revisions.values(): - if not application_revision or not application_revision.data: - print("Missing or invalid application revision") - if application_revision: - print(application_revision.model_dump(exclude_none=True)) - continue - - # print(f" Application {application_revision.model_dump(exclude_none=True)}") # type: ignore - - references = dict( - testset=Reference( - id=testset_revision.testset_id, + steps = [ + EvaluationStep( + key=input_step_key, + type="input", + origin="custom", + references={ + "testset": Reference(id=testset_revision.testset_id), + "testset_variant": Reference( + id=testset_revision.testset_variant_id ), - testset_variant=Reference( - id=testset_revision.testset_variant_id, - ), - testset_revision=Reference( + "testset_revision": Reference( id=testset_revision.id, slug=testset_revision.slug, version=testset_revision.version, ), - application=Reference( - id=application_revision.application_id, - ), - application_variant=Reference( - id=application_revision.application_variant_id, - ), - application_revision=Reference( - id=application_revision.id, - slug=application_revision.slug, - version=application_revision.version, - ), - ) - links = None - - _revision = application_revision.model_dump( - mode="json", - exclude_none=True, - ) - parameters = ( - application_revision.data.parameters - if application_revision.data - else None - ) - - _trace = None - outputs = None - - workflow_service_request_data = WorkflowServiceRequestData( - revision=_revision, - parameters=parameters, - # - testcase=_testcase, - inputs=inputs, - # - trace=_trace, - outputs=outputs, - ) - - application_request = ApplicationServiceRequest( - data=workflow_service_request_data, - # - references=references, # type: ignore - links=links, # type: ignore - ) - - application_response = await invoke_application( - request=application_request, - ) - - if ( - not application_response - or not application_response.data - or not application_response.trace_id - ): - print("Missing or invalid application response") - if application_response: - print(application_response.model_dump(exclude_none=True)) - continue - - trace_id = application_response.trace_id - - if not application_revision.slug: - print("Missing application revision slug") - continue - - application_slug = application_revision.slug - - trace = fetch_trace_data(trace_id, max_retries=30, delay=1.0) - - result = await alog_result( - run_id=run.id, - scenario_id=scenario.id, - step_key="application-" + application_slug, # type: ignore - trace_id=trace_id, - ) - - print( - f"{UNICODE['pipe']}" - f"{UNICODE['pipe' if testcase_idx < len(testcases) - 1 else 'skip']}" - f"{UNICODE['pipe']}" - f"{UNICODE['next']}" - f"{UNICODE['here']}" - f" result_id={str(result.id)} (invocation)", - ) - - results[application_slug] = result - - trace = await trace - - if not trace: - print("Failed to fetch trace data for application") - continue - - root_span = list(trace.get("spans", {}).values())[0] - trace_attributes: dict = root_span.get("attributes", {}) - trace_attributes_ag: dict = trace_attributes.get("ag", {}) - trace_attributes_ag_data: dict = trace_attributes_ag.get("data", {}) - outputs = trace_attributes_ag_data.get("outputs") - inputs = inputs or trace_attributes_ag_data.get("inputs") - - for i, evaluator_revision in enumerate(evaluator_revisions.values()): - if not evaluator_revision or not evaluator_revision.data: - print("Missing or invalid evaluator revision") - if evaluator_revision: - print(evaluator_revision.model_dump(exclude_none=True)) - continue + }, + ) + ] + runners: Dict[str, Any] = {} + revisions: Dict[str, Any] = {} + + for application_revision in application_revisions.values(): + if not application_revision or not application_revision.data: + print("Missing or invalid application revision") + if application_revision: + print(application_revision.model_dump(exclude_none=True)) + continue - references = dict( - testset=Reference( - id=testset_revision.testset_id, + application_step_key = "application-" + application_revision.slug # type: ignore + steps.append( + EvaluationStep( + key=application_step_key, + type="invocation", + origin="auto", + references={ + "application": Reference( + id=application_revision.application_id ), - testset_variant=Reference( - id=testset_revision.testset_variant_id, - ), - testset_revision=Reference( - id=testset_revision.id, - slug=testset_revision.slug, - version=testset_revision.version, + "application_variant": Reference( + id=application_revision.application_variant_id, ), - evaluator=Reference( - id=evaluator_revision.evaluator_id, + "application_revision": Reference( + id=application_revision.id, + slug=application_revision.slug, + version=application_revision.version, ), - evaluator_variant=Reference( + }, + ) + ) + runners[application_step_key] = SdkLocalApplicationRunner() + revisions[application_step_key] = application_revision + + for ( + evaluator_revision_id, + origin, + ) in simple_evaluation_data.evaluator_steps.items(): + evaluator_revision = evaluator_revisions.get( + evaluator_revision_id + ) or evaluator_revisions.get(UUID(str(evaluator_revision_id))) + if not evaluator_revision or not evaluator_revision.data: + print("Missing or invalid evaluator revision") + if evaluator_revision: + print(evaluator_revision.model_dump(exclude_none=True)) + continue + + evaluator_step_key = "evaluator-" + evaluator_revision.slug # type: ignore + steps.append( + EvaluationStep( + key=evaluator_step_key, + type="annotation", + origin=origin, + references={ + "evaluator": Reference(id=evaluator_revision.evaluator_id), + "evaluator_variant": Reference( id=evaluator_revision.evaluator_variant_id, ), - evaluator_revision=Reference( + "evaluator_revision": Reference( id=evaluator_revision.id, slug=evaluator_revision.slug, version=evaluator_revision.version, ), - ) - links = ( - dict( - invocation=Link( - trace_id=application_response.trace_id, - span_id=application_response.span_id, - ) - ) - if application_response.trace_id - and application_response.span_id - else None - ) - - _revision = evaluator_revision.model_dump( - mode="json", - exclude_none=True, - ) - parameters = ( - evaluator_revision.data.parameters - if evaluator_revision.data - else None - ) - - workflow_service_request_data = WorkflowServiceRequestData( - revision=_revision, - parameters=parameters, - # - testcase=_testcase, - inputs=inputs, - # - trace=trace, - outputs=outputs, - ) - - evaluator_request = EvaluatorServiceRequest( - version="2025.07.14", - # - data=workflow_service_request_data, - # - references=references, # type: ignore - links=links, # type: ignore - ) - - evaluator_response = await invoke_evaluator( - request=evaluator_request, - ) - - if ( - not evaluator_response - or not evaluator_response.data - or not evaluator_response.trace_id - ): - print("Missing or invalid evaluator response") - if evaluator_response: - print(evaluator_response.model_dump(exclude_none=True)) - continue - - trace_id = evaluator_response.trace_id - - trace = fetch_trace_data(trace_id, max_retries=30, delay=1.0) - - result = await alog_result( - run_id=run.id, - scenario_id=scenario.id, - step_key="evaluator-" + evaluator_revision.slug, # type: ignore - trace_id=trace_id, - ) - - print( - f"{UNICODE['pipe']}" - f"{UNICODE['pipe' if testcase_idx < len(testcases) - 1 else 'skip']}" - f"{UNICODE['pipe']}" - f"{UNICODE['last' if (i == len(evaluator_revisions) - 1) else 'next']}" - f"{UNICODE['here']}" - f" result_id={str(result.id)} (annotation)", - ) - - results[evaluator_revision.slug] = result - - trace = await trace - - if not trace: - print("Failed to fetch trace data for evaluator") - continue - - metrics = await acompute_metrics( - run_id=run.id, - scenario_id=scenario.id, - ) - - print( - f"{UNICODE['pipe']}" - f"{UNICODE['pipe' if testcase_idx < len(testcases) - 1 else 'skip']}" - f"{UNICODE['last']}" - f"{UNICODE['here']}" - f"{UNICODE['skip']}" - f" metrics_id={str(metrics.id)}", + }, + ) ) - - scenarios.append( - { - "scenario": scenario, - "results": results, - "metrics": metrics, - }, + if origin == "auto": + runners[evaluator_step_key] = SdkLocalEvaluatorRunner() + revisions[evaluator_step_key] = evaluator_revision + + source_items = [] + for testcase in testcases: + inputs = dict(testcase.data or {}) + inputs.pop("testcase_dedup_id", None) + source_items.append( + ResolvedSourceItem( + kind="testcase", + step_key=input_step_key, + references={ + "testcase": Reference(id=testcase.id), + "testset": Reference(id=testset_revision.testset_id), + "testset_variant": Reference( + id=testset_revision.testset_variant_id, + ), + "testset_revision": Reference( + id=testset_revision.id, + slug=testset_revision.slug, + version=testset_revision.version, + ), + }, + testcase_id=testcase.id, + testcase=testcase.model_dump(mode="json", exclude_none=True), + inputs=inputs, + ) ) - print( - f"{UNICODE['pipe']}" - f"{UNICODE['skip']}" - f"{UNICODE['skip']}" - f"{UNICODE['skip']}" - f"{UNICODE['skip']}" - "-----------------------" - "--------------------------------------" - ) - - metrics = dict() - - if len(scenarios) > 0: - metrics = await acompute_metrics( + processed = await process_evaluation_source_slice( run_id=run.id, + source_items=source_items, + steps=steps, + repeats=simple_evaluation_data.repeats, + create_scenario=create_scenario, + result_logger=result_logger, + refresh_metrics=refresh_metrics, + runners=runners, + revisions=revisions, + trace_loader=trace_loader, ) - - print( - f"{UNICODE['last']}" - f"{UNICODE['here']}" - f"{UNICODE['skip']}" - f"{UNICODE['skip']}" - f"{UNICODE['skip']}" - f" metrics_id={str(metrics.id)}", + scenarios.extend( + { + "scenario": item.scenario, + "results": item.results, + "metrics": item.metrics, + } + for item in processed ) + if len(scenarios) > 0: + metrics = await acompute_metrics(run_id=run.id) + run = await aclose_run( run_id=run.id, ) diff --git a/sdks/python/agenta/sdk/evaluations/results.py b/sdks/python/agenta/sdk/evaluations/results.py index ca0aebb170..34abb215f7 100644 --- a/sdks/python/agenta/sdk/evaluations/results.py +++ b/sdks/python/agenta/sdk/evaluations/results.py @@ -12,7 +12,7 @@ async def acreate( run_id: UUID, scenario_id: UUID, step_key: str, - # repeat_idx: str, + repeat_idx: Optional[int] = 0, # timestamp: datetime, # interval: float, # @@ -37,7 +37,7 @@ async def acreate( # # interval=interval, # timestamp=timestamp, - # repeat_idx=repeat_idx, + repeat_idx=repeat_idx, step_key=step_key, run_id=str(run_id), scenario_id=str(scenario_id), diff --git a/sdks/python/agenta/sdk/models/evaluations.py b/sdks/python/agenta/sdk/models/evaluations.py index 267f996fab..a25020580d 100644 --- a/sdks/python/agenta/sdk/models/evaluations.py +++ b/sdks/python/agenta/sdk/models/evaluations.py @@ -42,6 +42,9 @@ class EvaluationRunFlags(BaseModel): is_live: bool = False # Indicates if the run has live queries is_active: bool = False # Indicates if the run is currently active is_closed: bool = False # Indicates if the run is modifiable + is_queue: bool = False # Indicates this run belongs to an annotation queue + is_cached: bool = False # Indicates the run should reuse traces by hash + is_split: bool = False # Indicates repeats fan out at the application step # has_queries: bool = False # Indicates if the run has queries has_testsets: bool = False # Indicates if the run has testsets @@ -86,9 +89,10 @@ class EvaluationResult(BaseModel): run_id: UUID scenario_id: UUID step_key: str + repeat_idx: Optional[int] = 0 testcase_id: Optional[UUID] = None - trace_id: Optional[UUID] = None + trace_id: Optional[Union[UUID, str]] = None error: Optional[dict] = None flags: Optional[Dict[str, Any]] = None diff --git a/services/entrypoints/main.py b/services/entrypoints/main.py index 919698c5ae..29fc7309d6 100644 --- a/services/entrypoints/main.py +++ b/services/entrypoints/main.py @@ -177,4 +177,5 @@ async def health(): port=8080, reload=True, reload_dirs=[".", "/sdk"], + reload_excludes=["**/tests/**", "/sdk/tests"], ) diff --git a/web/ee/package.json b/web/ee/package.json index 3075819a19..22735484c3 100644 --- a/web/ee/package.json +++ b/web/ee/package.json @@ -43,7 +43,6 @@ "@types/react-dom": "^19.0.4", "@types/react-resizable": "^3.0.7", "@types/react-window": "^1.8.8", - "@types/recharts": "^2.0.1", "antd": "^6.1.3", "autoprefixer": "10.4.20", "axios": "1.16.0", diff --git a/web/oss/src/components/pages/evaluations/NewEvaluation/Components/AdvancedSettings.tsx b/web/oss/src/components/pages/evaluations/NewEvaluation/Components/AdvancedSettings.tsx index 4edb5403b0..c825e9af04 100644 --- a/web/oss/src/components/pages/evaluations/NewEvaluation/Components/AdvancedSettings.tsx +++ b/web/oss/src/components/pages/evaluations/NewEvaluation/Components/AdvancedSettings.tsx @@ -1,14 +1,26 @@ import {memo, useCallback, useMemo} from "react" import {QuestionCircleOutlined} from "@ant-design/icons" -import {Button, Col, Flex, Form, Input, InputNumber, Row, Tooltip, Typography} from "antd" +import {Button, Flex, Form, InputNumber, Tooltip, Typography} from "antd" import deepEqual from "fast-deep-equal" import {DEFAULT_ADVANCE_SETTINGS} from "../assets/constants" -import {AdvancedSettingsProps} from "../types" +import {AdvancedSettingsProps, EvaluationConcurrencySettings} from "../types" + +const FIELD_LABELS: Record = { + batch_size: "Batch Size", + max_retries: "Max Retries", + retry_delay: "Retry Delay (s)", +} + +const FIELD_TOOLTIPS: Record = { + batch_size: "Maximum number of concurrent invocations", + max_retries: "How many times to retry a failed invocation", + retry_delay: "Seconds to wait before retrying a failed invocation", +} const AdvancedSettings = ({advanceSettings, setAdvanceSettings}: AdvancedSettingsProps) => { - const handleChange = (key: string, value: any) => { + const handleChange = (key: keyof EvaluationConcurrencySettings, value: number | null) => { setAdvanceSettings((prev) => ({ ...prev, [key]: value, @@ -24,17 +36,14 @@ const AdvancedSettings = ({advanceSettings, setAdvanceSettings}: AdvancedSetting [advanceSettings], ) - const {correct_answer_column, ...rateLimitConfig} = advanceSettings - return (
- Rate Limit Configuration + Concurrency {isAdvancedSettingsChanged && (