Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions src/flyte_mcp/tools/runs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any, Literal
from collections.abc import Sequence
from typing import Annotated, Any, Literal, cast

from pydantic import Field

from flyte.remote import Run, TimeFilter
from flyte.remote._run import RunDetails
Expand All @@ -8,6 +11,9 @@
from flyte_mcp.models import run_status_payload, to_mcp_payload
from flyte_mcp.runtime import get_execution_domain, get_execution_project

SortDirection = Literal["asc", "desc"]
SortByParam = Annotated[list[str], Field(min_length=2, max_length=2)]


def _normalize_in_phase(
in_phase: list[str] | tuple[str, ...] | None,
Expand All @@ -23,6 +29,21 @@ def _normalize_in_phase(
return normalized or None


def _normalize_sort_by(sort_by: Sequence[str] | None) -> tuple[str, SortDirection]:
if sort_by is None:
return ("created_at", "desc")

if len(sort_by) != 2:
raise ValueError("sort_by must be [field, direction]")

field, direction = sort_by
normalized_direction = direction.strip().lower()
if normalized_direction not in {"asc", "desc"}:
raise ValueError("sort_by direction must be 'asc' or 'desc'")

return (field, cast(SortDirection, normalized_direction))


async def _get_run(
run_name: str,
*,
Expand Down Expand Up @@ -71,7 +92,7 @@ async def list_runs(
task_name: str | None = None,
task_version: str | None = None,
created_by_subject: str | None = None,
sort_by: tuple[str, Literal["asc", "desc"]] | None = None,
sort_by: SortByParam | None = None,
project: str | None = None,
domain: str | None = None,
created_at: TimeFilter | None = None,
Expand All @@ -92,8 +113,9 @@ async def list_runs(
Optional Flyte task version to filter by.
created_by_subject : str or None, default=None
Optional creator subject to filter by.
sort_by : tuple[str, Literal["asc", "desc"]] or None, default=None
Optional sort tuple. Defaults to ``("created_at", "desc")`` when omitted.
sort_by : list[str] or None, default=None
Optional two-item sort array ``[field, direction]``. Defaults to
``["created_at", "desc"]`` when omitted.
project : str or None, default=None
Optional Flyte project override. Falls back to the configured execution
project when omitted.
Expand All @@ -114,7 +136,7 @@ async def list_runs(
resolved_project = get_execution_project(project)
resolved_domain = get_execution_domain(domain)
normalized_in_phase = _normalize_in_phase(in_phase)
resolved_sort_by = sort_by or ("created_at", "desc")
resolved_sort_by = _normalize_sort_by(sort_by)
runs: list[dict[str, Any]] = []

async for run in Run.listall.aio(
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_server_core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import asyncio
import inspect

import flyte_mcp.server as server
from pydantic import create_model

from flyte_mcp.tools.runs import list_runs


def test_initialize_server_calls_flyte_init(monkeypatch) -> None:
Expand Down Expand Up @@ -28,3 +32,19 @@ def test_static_tools_are_registered() -> None:
assert "list_project_versions" in tool_names
assert "list_task_versions" in tool_names
assert "list_runs" in tool_names


def test_list_runs_schema_uses_array_items_for_sort_by() -> None:
fields = {}
for name, param in inspect.signature(list_runs).parameters.items():
default = ... if param.default is inspect._empty else param.default
fields[name] = (param.annotation, default)

schema = create_model("list_runs_params", **fields).model_json_schema()
sort_by_schema = schema["properties"]["sort_by"]["anyOf"][0]

assert sort_by_schema["type"] == "array"
assert sort_by_schema["items"] == {"type": "string"}
assert sort_by_schema["minItems"] == 2
assert sort_by_schema["maxItems"] == 2
assert "prefixItems" not in sort_by_schema
49 changes: 48 additions & 1 deletion tests/unit/tools/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ async def fake_listall(**kwargs):
task_name="root_env.task",
task_version="v123",
created_by_subject="user-123",
sort_by=("updated_at", "asc"),
sort_by=["updated_at", "asc"],
project="other-project",
domain="development",
created_at=created_at,
Expand All @@ -165,6 +165,53 @@ async def fake_listall(**kwargs):
}


def test_list_runs_normalizes_uppercase_sort_direction(monkeypatch):
async def fake_details_aio():
return SimpleNamespace(
action_details=SimpleNamespace(
task_name="root_env.task",
error_info=None,
)
)

fake_run = SimpleNamespace(
name="run-11",
phase="succeeded",
url="https://console/run-11",
done=lambda: True,
details=SimpleNamespace(aio=fake_details_aio),
)

async def fake_listall(**kwargs):
assert kwargs["sort_by"] == ("updated_at", "desc")
yield fake_run

monkeypatch.setattr(
run_tools,
"get_execution_project",
lambda project=None, task=None: "isolated-project",
)
monkeypatch.setattr(
run_tools,
"get_execution_domain",
lambda domain=None, task=None: "isolated-domain",
)
monkeypatch.setattr(run_tools.Run, "listall", SimpleNamespace(aio=fake_listall))

payload = asyncio.run(list_runs(sort_by=["updated_at", "DESC"]))

assert payload["sort_by"] == ["updated_at", "desc"]


def test_list_runs_rejects_invalid_sort_direction():
try:
run_tools._normalize_sort_by(["updated_at", "newest"])
except ValueError as exc:
assert str(exc) == "sort_by direction must be 'asc' or 'desc'"
else:
raise AssertionError("Expected ValueError")


def test_get_run_status_returns_phase_payload(monkeypatch):
async def fake_details():
return SimpleNamespace(
Expand Down
Loading