diff --git a/src/flyte_mcp/runtime.py b/src/flyte_mcp/runtime.py index c0cd27e..cc20f2d 100644 --- a/src/flyte_mcp/runtime.py +++ b/src/flyte_mcp/runtime.py @@ -43,9 +43,11 @@ async def run_remote_task( *, execution_project: str | None = None, execution_domain: str | None = None, + overwrite_cache: bool = False, **inputs: Any, ) -> Any: return await flyte.with_runcontext( project=execution_project or get_settings().execution_project, domain=execution_domain or get_settings().execution_domain, + overwrite_cache=overwrite_cache, ).run.aio(task, **inputs) diff --git a/src/flyte_mcp/tools/tasks.py b/src/flyte_mcp/tools/tasks.py index 2a9a125..d981b59 100644 --- a/src/flyte_mcp/tools/tasks.py +++ b/src/flyte_mcp/tools/tasks.py @@ -460,6 +460,7 @@ async def run_task( execution_project: str | None = None, execution_domain: str | None = None, wait: bool = False, + overwrite_cache: bool = False, ) -> dict[str, Any]: """ Run one Flyte task with explicit inputs. @@ -487,6 +488,9 @@ async def run_task( execution domain is used, falling back to the runtime default resolution. wait : bool, default=False When ``True``, wait for the run to complete and include outputs when available. + overwrite_cache : bool, default=False + When ``True``, ignore existing cached results and force re-execution of the task. + The new results will overwrite any previously cached outputs. Returns ------- @@ -523,6 +527,7 @@ async def run_task( lazy_task, execution_project=resolved_exec_project, execution_domain=resolved_exec_domain, + overwrite_cache=overwrite_cache, **coerced_inputs, ) payload = { diff --git a/tests/unit/tools/test_tasks.py b/tests/unit/tools/test_tasks.py index 24dd15d..78b6fd5 100644 --- a/tests/unit/tools/test_tasks.py +++ b/tests/unit/tools/test_tasks.py @@ -417,11 +417,17 @@ async def _outputs(self) -> dict[str, Any]: captured: dict[str, Any] = {} async def fake_run_remote_task( - task_obj, *, execution_project, execution_domain, **inputs + task_obj, + *, + execution_project, + execution_domain, + overwrite_cache=False, + **inputs, ): captured["task"] = task_obj captured["execution_project"] = execution_project captured["execution_domain"] = execution_domain + captured["overwrite_cache"] = overwrite_cache captured["inputs"] = inputs return FakeRun() @@ -472,6 +478,7 @@ async def fake_run_remote_task( } assert captured["execution_project"] == "my-project" assert captured["execution_domain"] == "development" + assert captured["overwrite_cache"] is False assert isinstance(captured["inputs"]["e"], DataFrame) assert isinstance(captured["inputs"]["h"], File) assert captured["task"]["lazy_task"]["version"] == "v1" @@ -492,11 +499,17 @@ def done(self) -> bool: captured: dict[str, Any] = {} async def fake_run_remote_task( - task_obj, *, execution_project, execution_domain, **inputs + task_obj, + *, + execution_project, + execution_domain, + overwrite_cache=False, + **inputs, ): captured["task"] = task_obj captured["execution_project"] = execution_project captured["execution_domain"] = execution_domain + captured["overwrite_cache"] = overwrite_cache captured["inputs"] = inputs return FakeRun() @@ -525,11 +538,13 @@ async def fake_run_remote_task( ), execution_project="custom-project", execution_domain="staging", + overwrite_cache=True, ) ) assert captured["execution_project"] == "resolved:custom-project" assert captured["execution_domain"] == "resolved:staging" + assert captured["overwrite_cache"] is True assert payload["run_scope"] == { "project": "resolved:custom-project", "domain": "resolved:staging",