Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/flyte_mcp/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ async def run_remote_task(
*,
execution_project: str | None = None,
execution_domain: str | None = None,
overwrite_cache: bool = False,
**inputs: Any,
) -> Any:
return await flyte.with_runcontext(
project=execution_project or get_settings().execution_project,
domain=execution_domain or get_settings().execution_domain,
overwrite_cache=overwrite_cache,
).run.aio(task, **inputs)
5 changes: 5 additions & 0 deletions src/flyte_mcp/tools/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ async def run_task(
execution_project: str | None = None,
execution_domain: str | None = None,
wait: bool = False,
overwrite_cache: bool = False,
) -> dict[str, Any]:
"""
Run one Flyte task with explicit inputs.
Expand Down Expand Up @@ -487,6 +488,9 @@ async def run_task(
execution domain is used, falling back to the runtime default resolution.
wait : bool, default=False
When ``True``, wait for the run to complete and include outputs when available.
overwrite_cache : bool, default=False
When ``True``, ignore existing cached results and force re-execution of the task.
The new results will overwrite any previously cached outputs.

Returns
-------
Expand Down Expand Up @@ -523,6 +527,7 @@ async def run_task(
lazy_task,
execution_project=resolved_exec_project,
execution_domain=resolved_exec_domain,
overwrite_cache=overwrite_cache,
**coerced_inputs,
)
payload = {
Expand Down
19 changes: 17 additions & 2 deletions tests/unit/tools/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,11 +417,17 @@ async def _outputs(self) -> dict[str, Any]:
captured: dict[str, Any] = {}

async def fake_run_remote_task(
task_obj, *, execution_project, execution_domain, **inputs
task_obj,
*,
execution_project,
execution_domain,
overwrite_cache=False,
**inputs,
):
captured["task"] = task_obj
captured["execution_project"] = execution_project
captured["execution_domain"] = execution_domain
captured["overwrite_cache"] = overwrite_cache
captured["inputs"] = inputs
return FakeRun()

Expand Down Expand Up @@ -472,6 +478,7 @@ async def fake_run_remote_task(
}
assert captured["execution_project"] == "my-project"
assert captured["execution_domain"] == "development"
assert captured["overwrite_cache"] is False
assert isinstance(captured["inputs"]["e"], DataFrame)
assert isinstance(captured["inputs"]["h"], File)
assert captured["task"]["lazy_task"]["version"] == "v1"
Expand All @@ -492,11 +499,17 @@ def done(self) -> bool:
captured: dict[str, Any] = {}

async def fake_run_remote_task(
task_obj, *, execution_project, execution_domain, **inputs
task_obj,
*,
execution_project,
execution_domain,
overwrite_cache=False,
**inputs,
):
captured["task"] = task_obj
captured["execution_project"] = execution_project
captured["execution_domain"] = execution_domain
captured["overwrite_cache"] = overwrite_cache
captured["inputs"] = inputs
return FakeRun()

Expand Down Expand Up @@ -525,11 +538,13 @@ async def fake_run_remote_task(
),
execution_project="custom-project",
execution_domain="staging",
overwrite_cache=True,
)
)

assert captured["execution_project"] == "resolved:custom-project"
assert captured["execution_domain"] == "resolved:staging"
assert captured["overwrite_cache"] is True
assert payload["run_scope"] == {
"project": "resolved:custom-project",
"domain": "resolved:staging",
Expand Down
Loading