diff --git a/src/flyte_mcp/tools/runs.py b/src/flyte_mcp/tools/runs.py index ebef736..1875de5 100644 --- a/src/flyte_mcp/tools/runs.py +++ b/src/flyte_mcp/tools/runs.py @@ -1,4 +1,7 @@ -from typing import Any, Literal +from collections.abc import Sequence +from typing import Annotated, Any, Literal, cast + +from pydantic import Field from flyte.remote import Run, TimeFilter from flyte.remote._run import RunDetails @@ -8,6 +11,9 @@ from flyte_mcp.models import run_status_payload, to_mcp_payload from flyte_mcp.runtime import get_execution_domain, get_execution_project +SortDirection = Literal["asc", "desc"] +SortByParam = Annotated[list[str], Field(min_length=2, max_length=2)] + def _normalize_in_phase( in_phase: list[str] | tuple[str, ...] | None, @@ -23,6 +29,21 @@ def _normalize_in_phase( return normalized or None +def _normalize_sort_by(sort_by: Sequence[str] | None) -> tuple[str, SortDirection]: + if sort_by is None: + return ("created_at", "desc") + + if len(sort_by) != 2: + raise ValueError("sort_by must be [field, direction]") + + field, direction = sort_by + normalized_direction = direction.strip().lower() + if normalized_direction not in {"asc", "desc"}: + raise ValueError("sort_by direction must be 'asc' or 'desc'") + + return (field, cast(SortDirection, normalized_direction)) + + async def _get_run( run_name: str, *, @@ -71,7 +92,7 @@ async def list_runs( task_name: str | None = None, task_version: str | None = None, created_by_subject: str | None = None, - sort_by: tuple[str, Literal["asc", "desc"]] | None = None, + sort_by: SortByParam | None = None, project: str | None = None, domain: str | None = None, created_at: TimeFilter | None = None, @@ -92,8 +113,9 @@ async def list_runs( Optional Flyte task version to filter by. created_by_subject : str or None, default=None Optional creator subject to filter by. - sort_by : tuple[str, Literal["asc", "desc"]] or None, default=None - Optional sort tuple. Defaults to ``("created_at", "desc")`` when omitted. + sort_by : list[str] or None, default=None + Optional two-item sort array ``[field, direction]``. Defaults to + ``["created_at", "desc"]`` when omitted. project : str or None, default=None Optional Flyte project override. Falls back to the configured execution project when omitted. @@ -114,7 +136,7 @@ async def list_runs( resolved_project = get_execution_project(project) resolved_domain = get_execution_domain(domain) normalized_in_phase = _normalize_in_phase(in_phase) - resolved_sort_by = sort_by or ("created_at", "desc") + resolved_sort_by = _normalize_sort_by(sort_by) runs: list[dict[str, Any]] = [] async for run in Run.listall.aio( diff --git a/tests/unit/test_server_core.py b/tests/unit/test_server_core.py index 7c90bc6..7b24da9 100644 --- a/tests/unit/test_server_core.py +++ b/tests/unit/test_server_core.py @@ -1,6 +1,10 @@ import asyncio +import inspect import flyte_mcp.server as server +from pydantic import create_model + +from flyte_mcp.tools.runs import list_runs def test_initialize_server_calls_flyte_init(monkeypatch) -> None: @@ -28,3 +32,19 @@ def test_static_tools_are_registered() -> None: assert "list_project_versions" in tool_names assert "list_task_versions" in tool_names assert "list_runs" in tool_names + + +def test_list_runs_schema_uses_array_items_for_sort_by() -> None: + fields = {} + for name, param in inspect.signature(list_runs).parameters.items(): + default = ... if param.default is inspect._empty else param.default + fields[name] = (param.annotation, default) + + schema = create_model("list_runs_params", **fields).model_json_schema() + sort_by_schema = schema["properties"]["sort_by"]["anyOf"][0] + + assert sort_by_schema["type"] == "array" + assert sort_by_schema["items"] == {"type": "string"} + assert sort_by_schema["minItems"] == 2 + assert sort_by_schema["maxItems"] == 2 + assert "prefixItems" not in sort_by_schema diff --git a/tests/unit/tools/test_runs.py b/tests/unit/tools/test_runs.py index 8055554..ee906ab 100644 --- a/tests/unit/tools/test_runs.py +++ b/tests/unit/tools/test_runs.py @@ -144,7 +144,7 @@ async def fake_listall(**kwargs): task_name="root_env.task", task_version="v123", created_by_subject="user-123", - sort_by=("updated_at", "asc"), + sort_by=["updated_at", "asc"], project="other-project", domain="development", created_at=created_at, @@ -165,6 +165,53 @@ async def fake_listall(**kwargs): } +def test_list_runs_normalizes_uppercase_sort_direction(monkeypatch): + async def fake_details_aio(): + return SimpleNamespace( + action_details=SimpleNamespace( + task_name="root_env.task", + error_info=None, + ) + ) + + fake_run = SimpleNamespace( + name="run-11", + phase="succeeded", + url="https://console/run-11", + done=lambda: True, + details=SimpleNamespace(aio=fake_details_aio), + ) + + async def fake_listall(**kwargs): + assert kwargs["sort_by"] == ("updated_at", "desc") + yield fake_run + + monkeypatch.setattr( + run_tools, + "get_execution_project", + lambda project=None, task=None: "isolated-project", + ) + monkeypatch.setattr( + run_tools, + "get_execution_domain", + lambda domain=None, task=None: "isolated-domain", + ) + monkeypatch.setattr(run_tools.Run, "listall", SimpleNamespace(aio=fake_listall)) + + payload = asyncio.run(list_runs(sort_by=["updated_at", "DESC"])) + + assert payload["sort_by"] == ["updated_at", "desc"] + + +def test_list_runs_rejects_invalid_sort_direction(): + try: + run_tools._normalize_sort_by(["updated_at", "newest"]) + except ValueError as exc: + assert str(exc) == "sort_by direction must be 'asc' or 'desc'" + else: + raise AssertionError("Expected ValueError") + + def test_get_run_status_returns_phase_payload(monkeypatch): async def fake_details(): return SimpleNamespace(