From 78fc95ecf38031fa7064eee67a837fab35f29b48 Mon Sep 17 00:00:00 2001 From: David <3dgiordano@gmail.com> Date: Tue, 28 Apr 2026 15:27:31 -0300 Subject: [PATCH 1/5] Initial support of Tasks and Dataframes --- formatters/test.py | 14 + main.py | 70 +- models/result.py | 77 ++ pyproject.toml | 1 + server.py | 2 + tests/test_async_task_manager_ids.py | 75 ++ tests/test_batch_controls.py | 16 +- tests/test_dataframe_manager.py | 367 +++++++++ tests/test_failure_criteria.py | 4 +- tests/test_required_args_tools.py | 128 ++- tests/test_skills_manager_security.py | 6 +- tests/test_test_formatter_minimal.py | 23 + tests/test_test_manager_list_project_list.py | 38 + tests/test_tool_result_wrapper.py | 20 + ...s_manager_dataframe_query_result_format.py | 66 ++ tests/test_tools_manager_dataframes_remove.py | 48 ++ tests/test_tools_manager_polling_format.py | 129 +++ .../test_tools_manager_schema_groups_info.py | 37 + tests/test_utils_batch.py | 118 +++ tests/test_utils_normalize_action_args.py | 57 ++ tests/test_utils_required_args.py | 15 + tests/test_utils_ttl_cache.py | 167 ++++ tools/account_manager.py | 40 +- tools/async_task_manager.py | 254 ++++++ tools/billing_manager.py | 25 +- tools/bridge.py | 11 +- tools/dataframe_manager.py | 777 ++++++++++++++++++ tools/execution_manager.py | 102 ++- tools/help_manager.py | 127 +-- tools/project_manager.py | 50 +- tools/report_manager.py | 21 +- tools/skills_manager.py | 168 ++-- tools/test_manager.py | 219 +++-- tools/tools_manager.py | 736 +++++++++++++++++ tools/user_manager.py | 14 +- tools/utils.py | 774 +++++++++++++++-- tools/workspace_manager.py | 71 +- uv.lock | 30 + 38 files changed, 4470 insertions(+), 427 deletions(-) create mode 100644 tests/test_async_task_manager_ids.py create mode 100644 tests/test_dataframe_manager.py create mode 100644 tests/test_test_formatter_minimal.py create mode 100644 tests/test_test_manager_list_project_list.py create mode 100644 tests/test_tool_result_wrapper.py create mode 100644 tests/test_tools_manager_dataframe_query_result_format.py create mode 100644 tests/test_tools_manager_dataframes_remove.py create mode 100644 tests/test_tools_manager_polling_format.py create mode 100644 tests/test_tools_manager_schema_groups_info.py create mode 100644 tests/test_utils_batch.py create mode 100644 tests/test_utils_normalize_action_args.py create mode 100644 tests/test_utils_required_args.py create mode 100644 tests/test_utils_ttl_cache.py create mode 100644 tools/async_task_manager.py create mode 100644 tools/dataframe_manager.py create mode 100644 tools/tools_manager.py diff --git a/formatters/test.py b/formatters/test.py index 89e334e..a850384 100644 --- a/formatters/test.py +++ b/formatters/test.py @@ -93,3 +93,17 @@ def format_tests(tests: List[Any], params: Optional[dict] = None) -> List[Test]: ) ) return formatted_tests + + +def format_tests_minimal(tests: List[Any], params: Optional[dict] = None) -> List[dict]: + formatted_tests = [] + for test in tests: + formatted_tests.append({ + "test_id": test.get("id"), + "test_name": test.get("name", "Unknown"), + "description": test.get("description", ""), + "created": get_date_time_iso(test.get("created")), + "updated": get_date_time_iso(test.get("updated")), + "project_id": test.get("projectId"), + }) + return formatted_tests diff --git a/main.py b/main.py index d9dab9b..45a8bbe 100644 --- a/main.py +++ b/main.py @@ -27,6 +27,25 @@ from pathlib import Path from typing import Literal, cast +# Patch MCP ArgModelBase so tools with an "arguments" param receive the full payload +# when the client sends {"action": "x", "key": "value"} instead of {"arguments": {...}} +from mcp.server.fastmcp.utilities import func_metadata +from pydantic import model_validator + +_OriginalArgModelBase = func_metadata.ArgModelBase + + +class _PatchedArgModelBase(_OriginalArgModelBase): + @model_validator(mode="before") + @classmethod + def _wrap_root_as_arguments(cls, data: object) -> object: + if isinstance(data, dict) and "arguments" not in data: + return {"arguments": data} + return data + + +func_metadata.ArgModelBase = _PatchedArgModelBase + from mcp.server.fastmcp import FastMCP from config.token import BzmToken, BzmTokenError @@ -372,7 +391,10 @@ def get_token(): return token -def run(log_level: str = "CRITICAL", confirm_mode: ConfirmMode = ConfirmMode.DELETE): +def run( + log_level: str = "CRITICAL", + confirm_mode: ConfirmMode = ConfirmMode.DELETE +): token = get_token() instructions = """ # BlazeMeter MCP Server @@ -384,6 +406,11 @@ def run(log_level: str = "CRITICAL", confirm_mode: ConfirmMode = ConfirmMode.DEL - **Read action always gets more information** about a particular item than the list action. List only displays minimal information. - **Read the current user information at startup** to learn the username, default account, workspace and project, and other important information. - **Links anchors**: Never invent or add anchors to links if they do not originally have them. +- **BlazeMeter tasks must use BlazeMeter MCP only**. +- **No direct BlazeMeter API access** (including tokens/keys/curl/scripts/custom HTTP); do not bypass **BlazeMeter MCP** for speed/complexity. +- If unsupported in **BlazeMeter MCP**, state limitation and stop: `I cannot use direct BlazeMeter API access; I must operate exclusively through BlazeMeter MCP.` +- **Maximize safe parallel execution** for independent BlazeMeter MCP actions; avoid unnecessary sequential steps. +- Prefer BlazeMeter MCP batch/concurrent calls when available. Serialize only when strict data dependencies require ordering. ## Hierarchy and Dependencies @@ -406,12 +433,14 @@ def run(log_level: str = "CRITICAL", confirm_mode: ConfirmMode = ConfirmMode.DEL - **Actions requiring confirmation**: Creating tests, configuring load/locations/failure criteria, uploading assets, starting executions, or any other write/modify operations. - **How to request**: Clearly state what action you're about to perform and on which workspace/project. Wait for user approval before proceeding. -## Proactive Knowledge Consultation +## Knowledge Consultation -- **ALWAYS consult BlazeMeter Skills and Help tools first** before answering questions, configuring tests, interpreting results, troubleshooting, or providing recommendations. -- **Use `blazemeter_skills`**: Access specialized knowledge about performance testing, best practices, troubleshooting, and official guides. -- **Use `blazemeter_help`**: Consult documentation, help categories, and specific guides. -- **Golden rule**: If you're not 100% certain about something related to BlazeMeter, consult Skills or Help first, and if you can't find it and need to search online, always prioritize the domain site blazemeter.com . +- **Plan from the information gap**: use the minimum tool calls needed to resolve unknowns. +- **Use Skills/Help only when needed**: ambiguity, conflict, uncertainty, or complex interpretation not covered by tool definitions. +- **Avoid redundant calls**: skip Skills/Help when current outputs are sufficient. +- **Use `blazemeter_skills`** for best practices, troubleshooting, and interpretation support. +- **Use `blazemeter_help`** for official docs/guides when a precise reference is required. +- If uncertainty remains after Skills/Help and web search is needed, prioritize the `blazemeter.com` domain. ## Capability Discovery @@ -421,13 +450,38 @@ def run(log_level: str = "CRITICAL", confirm_mode: ConfirmMode = ConfirmMode.DEL ## Important Guidelines - **Batch Operations**: When making multiple calls to the same tool, check if that tool supports a `batch` action and use it instead of separate calls. -- **Don't assume**: If you don't know a parameter, capability, or best practice, consult available tools (especially Skills or Help). -- **Don't invent**: If something is unclear, consult Skills/Help before responding. +- **Task tracking rule**: Use `blazemeter_tools` `tasks_status` (or `tasks_list`) for polling/progress checks. +- **Task result rule**: Use `blazemeter_tools` `tasks_get` only when you need the final payload (`task_result`) or input-required details. +- **Don't assume**: when details are missing, run targeted discovery with the most relevant tools. +- **Don't invent**: if outputs are insufficient, escalate to Skills/Help. +- **IMPORTANT**: For schema-dependent, multi-step, or constraint-heavy operations (e.g., dataframe SQL, nested fields, schema variations), reason step-by-step before acting. Design your approach, verify it against the rules, then execute. Do not try-fast and retry on failure. +- **Data processing hint**: If you plan joins, filtering, sorting, grouping, or multi-step analysis across results, request `result_format=dataframe` and run SQL via `blazemeter_tools` `dataframes_query` instead of combining large inline results in AI context. +- **Dataframe loading hint**: For `result_format=dataframe`, prefer one initial fetch with the maximum allowed tool limit and avoid list pagination unless required; then filter/sort/join in `dataframes_query`. +- **Dataframe usage rules**: + - Use `result_format=dataframe` for any source that will be processed with `dataframes_query`. + - Do not use `auto` for datasets that will be joined, filtered, grouped, ranked, or aggregated. + - Use `auto` only for lookup-style reads that will not enter dataframe SQL analysis. + - Keep format consistency per analytical dataset: all source calls should use `result_format=dataframe`. +- **IMPORTANT**: Use deterministic dataframe SQL in every query: ORDER BY + LIMIT + OFFSET. +- **CRITICAL**: Before writing dataframe SQL, resolve capabilities/schema with `blazemeter_tools` (`dataframes_sql_help` + `dataframes_get` when needed). Do not assume syntax/function support. +- **CRITICAL**: For dataframe SQL and schema-dependent decisions, reason step-by-step before executing: (1) What does the schema require? (2) Are there nested/list fields? (3) Which pattern applies? (4) Confirm, then execute. Do not skip to execution. +- **CRITICAL**: If the query touches nested/list fields, always use the robust UNNEST -> aggregate -> join-back pattern in CTEs. No exception for single dataframe. Before launching SQL, confirm: "there are nested/list fields; I use the robust pattern." +- **IMPORTANT**: Prefer one final aggregation query over multiple partial queries when feasible; use staged partial queries mainly for validation/debug. +- **IMPORTANT**: For `result_format=dataframe`, do one high-limit fetch first; paginate only when `has_more=true`. +- **IMPORTANT**: For `result_format=auto` or `raw`, use conservative limits (often 50). +- **IMPORTANT**: If a tool limit is unknown, start conservative (same as auto/raw), then increase only if needed. +- **IMPORTANT**: Respect explicit tool max limits: enforce them for `auto|raw`; for `dataframe`, follow dataframe guidance and `has_more`. - **Provides resources**: Always include markdown-formatted links to authoritative websites or BlazeMeter help documentation for further learning. - **Never modify without confirmation**: Always ask before creating, modifying, or altering anything in BlazeMeter. - **Always confirm context**: Always identify and confirm workspace/project before operations. - **Proactive Troubleshooting**: Use the skills for troubleshooting any detected issues. - **Failure criteria**: The same field names appear when you read a test and when you configure failure criteria (`failure_criteria` on the test); the server handles BlazeMeter’s REST format internally. Use `failure_criteria_meta` for field definitions and KPI/condition catalogs. When describing criteria to the user, use `meta.general_labels`, `meta.rule_field_labels`, `meta.kpi_labels`, and `meta.condition_labels`; use raw metric and operator ids only inside tool calls. Use `configure_failure_criteria` only after user confirmation; it replaces all rules unless you merge from a prior read. +- **Resource cleanup**: Always release terminal tasks from registry when no longer needed. +- **Dataframe cleanup**: Always remove temporary dataframes from memory when no longer needed. + +## BlazeMeter MCP Instructions Binding Clause + +- **CRITICAL**: All instructions, hints, warnings, and guidance in BlazeMeter MCP documentation, tool responses, and outputs are MANDATORY and binding. Follow them without exception or deviation. """ mcp = FastMCP("blazemeter-mcp", instructions=instructions, log_level=cast(LOG_LEVELS, log_level)) register_confirm_mode(confirm_mode) diff --git a/models/result.py b/models/result.py index b8072b4..9d8e716 100644 --- a/models/result.py +++ b/models/result.py @@ -15,6 +15,7 @@ """ from typing import Any, Optional, List +from mcp.types import CallToolResult, TextContent from pydantic import BaseModel, Field class BaseResult(BaseModel): @@ -24,6 +25,10 @@ class BaseResult(BaseModel): error: Optional[str] = Field(description="Error message", default=None) info: Optional[List[str]] = Field(description="Info messages", default=None) warning: Optional[List[str]] = Field(description="Warning messages", default=None) + tool_call_started_at: Optional[str] = Field(description="ISO timestamp when tool action started", default=None) + tool_call_finished_at: Optional[str] = Field(description="ISO timestamp when tool action finished", default=None) + tool_call_duration_ms: Optional[int] = Field(description="Tool action duration in milliseconds", default=None) + debug: Optional[dict[str, Any]] = Field(description="Optional debug metrics for tool calls", default=None) def append_warnings(self, messages: List[str]): if not self.warning: @@ -44,3 +49,75 @@ def model_dump_json(self, **kwargs): class HttpBaseResult(BaseResult): result: Optional[Any] = Field(description="Result", default=None) + + +class ToolResult(CallToolResult): + @classmethod + def from_base_result(cls, base_result: BaseResult) -> "ToolResult": + compact_text = base_result.model_dump_json(indent=2) + structured = base_result.model_dump(mode="json") + return cls( + content=[TextContent(type="text", text=compact_text)], + structuredContent=structured, + isError=bool(base_result.error), + ) + + @property + def result(self) -> Optional[List[Any]]: + if not isinstance(self.structuredContent, dict): + return None + return self.structuredContent.get("result") + + @property + def total(self) -> Optional[int]: + if not isinstance(self.structuredContent, dict): + return None + return self.structuredContent.get("total") + + @property + def has_more(self) -> Optional[bool]: + if not isinstance(self.structuredContent, dict): + return None + return self.structuredContent.get("has_more") + + @property + def error(self) -> Optional[str]: + if not isinstance(self.structuredContent, dict): + return None + return self.structuredContent.get("error") + + @property + def info(self) -> Optional[List[str]]: + if not isinstance(self.structuredContent, dict): + return None + return self.structuredContent.get("info") + + @property + def warning(self) -> Optional[List[str]]: + if not isinstance(self.structuredContent, dict): + return None + return self.structuredContent.get("warning") + + @property + def tool_call_started_at(self) -> Optional[str]: + if not isinstance(self.structuredContent, dict): + return None + return self.structuredContent.get("tool_call_started_at") + + @property + def tool_call_finished_at(self) -> Optional[str]: + if not isinstance(self.structuredContent, dict): + return None + return self.structuredContent.get("tool_call_finished_at") + + @property + def tool_call_duration_ms(self) -> Optional[int]: + if not isinstance(self.structuredContent, dict): + return None + return self.structuredContent.get("tool_call_duration_ms") + + @property + def debug(self) -> Optional[dict[str, Any]]: + if not isinstance(self.structuredContent, dict): + return None + return self.structuredContent.get("debug") diff --git a/pyproject.toml b/pyproject.toml index 20f68bd..3b55896 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "pydantic-core>=2.33.2", "pydantic-settings>=2.10.1", "lxml>=5.3.0", + "polars>=1.40.1", ] [project.scripts] diff --git a/server.py b/server.py index 104272b..9d230b5 100644 --- a/server.py +++ b/server.py @@ -23,6 +23,7 @@ from tools.project_manager import register as register_project_manager from tools.skills_manager import register as register_skills_manager from tools.test_manager import register as register_test_manager +from tools.tools_manager import register as register_tools_manager from tools.user_manager import register as register_user_manager from tools.workspace_manager import register as register_workspace_manager @@ -44,3 +45,4 @@ def register_tools(mcp, token: Optional[BzmToken]): register_billing_manager(mcp, token) register_help_manager(mcp, token) register_skills_manager(mcp, token) + register_tools_manager(mcp, token) diff --git a/tests/test_async_task_manager_ids.py b/tests/test_async_task_manager_ids.py new file mode 100644 index 0000000..b60a106 --- /dev/null +++ b/tests/test_async_task_manager_ids.py @@ -0,0 +1,75 @@ +import asyncio + +import pytest + +import tools.async_task_manager as task_manager +from models.result import BaseResult + + +def _clear_tasks(): + task_manager._tasks.clear() + + +def test_submit_task_uses_crockford_base32_id(): + _clear_tasks() + + async def scenario(): + async def action(): + return BaseResult(result=[{"ok": True}]) + + task_id = task_manager.submit_task( + action={"manager": "TestManager", "method": "read"}, + coro_factory=action, + ) + record = task_manager.get_task_record(task_id) + assert record is not None + assert len(task_id) == 8 + assert all(ch in task_manager.TASK_ID_ALPHABET for ch in task_id) + + while True: + record = task_manager.get_task_record(task_id) + if record and record.status in {"completed", "failed", "cancelled"}: + break + await asyncio.sleep(0.01) + + assert task_manager.remove_task(task_id) is True + + asyncio.run(scenario()) + + +def test_collision_policy_fails_after_ten_attempts(monkeypatch): + _clear_tasks() + task_manager._tasks["deadbeef"] = task_manager.TaskRecord( + task_id="deadbeef", + action={"manager": "TestManager", "method": "read"}, + created_at=0.0, + last_updated_at=0.0, + time_to_live_ms=None, + status=task_manager.STATUS_PARKING, + status_message="seed", + status_info="seed", + ) + + monkeypatch.setattr(task_manager, "_generate_task_id", lambda: "deadbeef") + + with pytest.raises(RuntimeError, match="Unable to allocate unique 8-char task id after 10 attempts."): + task_manager._allocate_task_id() + + +def test_task_lookup_is_case_insensitive(): + _clear_tasks() + now = 0.0 + task_manager._tasks["7k2p9m4q"] = task_manager.TaskRecord( + task_id="7k2p9m4q", + action={"manager": "ExecutionManager", "method": "list"}, + created_at=now, + last_updated_at=now, + time_to_live_ms=None, + status=task_manager.STATUS_WORKING, + status_message="running", + status_info="running", + ) + + assert task_manager.get_task_record("7K2P9M4Q") is not None + assert task_manager.remove_task("7K2P9M4Q") is True + assert task_manager.get_task_record("7k2p9m4q") is None diff --git a/tests/test_batch_controls.py b/tests/test_batch_controls.py index ff5f1cf..ad75ba7 100644 --- a/tests/test_batch_controls.py +++ b/tests/test_batch_controls.py @@ -40,11 +40,11 @@ def decorator(func): class TestBatchControls: def test_help_batch_respects_concurrency_limit(self, monkeypatch): + monkeypatch.setattr("tools.utils.MAX_BATCH_CONCURRENCY", 2) mcp = FakeMcp() register_help_tool(mcp, token=None) help_tool = mcp.tools[f"{TOOLS_PREFIX}_help"] HelpManager.help_tree = {} - monkeypatch.setattr(HelpManager, "MAX_BATCH_CONCURRENCY", 2) active_calls = {"current": 0, "max": 0} @@ -60,20 +60,22 @@ async def slow_list_help_categories(self): monkeypatch.setattr(HelpManager, "list_help_categories", slow_list_help_categories) batch_calls = [{"action": "list_help_categories", "args": {}} for _ in range(6)] - result = asyncio.run(help_tool("batch", {"batch_calls": batch_calls}, ctx=None)) + result = asyncio.run( + help_tool({"action": "batch", "batch_calls": batch_calls}, ctx=None) + ) assert result.error is None assert active_calls["max"] <= 2 def test_skills_batch_respects_concurrency_limit(self, monkeypatch): + monkeypatch.setattr("tools.utils.MAX_BATCH_CONCURRENCY", 2) mcp = FakeMcp() register_skills_tool(mcp, token=None) skills_tool = mcp.tools[f"{TOOLS_PREFIX}_skills"] - monkeypatch.setattr(SkillsManager, "MAX_BATCH_CONCURRENCY", 2) active_calls = {"current": 0, "max": 0} - async def slow_list_skills(): + async def slow_list_skills(self): active_calls["current"] += 1 active_calls["max"] = max(active_calls["max"], active_calls["current"]) try: @@ -82,10 +84,12 @@ async def slow_list_skills(): finally: active_calls["current"] -= 1 - monkeypatch.setattr(SkillsManager, "list_skills", staticmethod(slow_list_skills)) + monkeypatch.setattr(SkillsManager, "list_skills", slow_list_skills) batch_calls = [{"action": "list_skills", "args": {}} for _ in range(6)] - result = asyncio.run(skills_tool("batch", {"batch_calls": batch_calls}, ctx=None)) + result = asyncio.run( + skills_tool({"action": "batch", "batch_calls": batch_calls}, ctx=None) + ) assert result.error is None assert active_calls["max"] <= 2 diff --git a/tests/test_dataframe_manager.py b/tests/test_dataframe_manager.py new file mode 100644 index 0000000..f017269 --- /dev/null +++ b/tests/test_dataframe_manager.py @@ -0,0 +1,367 @@ +import asyncio + +import polars as pl + +from models.result import BaseResult +from tools.dataframe_manager import ( + auto_flatten_wide, + build_dataframe_from_result, + clear_dataframes, + get_dataframe_metadata, + group_dataframe_schemas, + list_dataframes_metadata, + materialize_large_result_if_needed, + query_dataframes, + register_dataframe, + serialize_result_to_compact_json, +) +from tools.async_task_manager import get_task_record, submit_task +from tools.utils import tool_result + + +def _clear_all(): + asyncio.run(clear_dataframes()) + + +def test_register_list_get_remove_clear_lifecycle(): + _clear_all() + metadata = asyncio.run( + register_dataframe( + result=[{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}], + origin_manager="tests", + origin_action="seed", + json_size_chars=9001, + ) + ) + + assert metadata["rows"] == 2 + assert metadata["columns"] == 2 + + listed = asyncio.run(list_dataframes_metadata()) + assert len(listed) == 1 + assert listed[0]["dataframe_id"] == metadata["dataframe_id"] + assert "schema" not in listed[0] + assert "schema_hash" in listed[0] + + fetched = asyncio.run(get_dataframe_metadata(metadata["dataframe_id"])) + assert fetched is not None + assert fetched["table_name"] == metadata["table_name"] + + removed = asyncio.run(clear_dataframes()) + assert removed == 1 + assert asyncio.run(list_dataframes_metadata()) == [] + + +def test_query_supports_join_and_union(): + _clear_all() + left = asyncio.run( + register_dataframe( + result=[{"id": 1, "v": "a"}, {"id": 2, "v": "b"}], + origin_manager="tests", + origin_action="left", + json_size_chars=9001, + ) + ) + right = asyncio.run( + register_dataframe( + result=[{"id": 1, "w": "x"}, {"id": 3, "w": "y"}], + origin_manager="tests", + origin_action="right", + json_size_chars=9001, + ) + ) + + join_sql = ( + f"SELECT l.id, l.v, r.w FROM {left['table_name']} l " + f"JOIN {right['table_name']} r ON l.id = r.id ORDER BY l.id LIMIT 100 OFFSET 0" + ) + join_response = query_dataframes(join_sql) + assert "error" not in join_response + assert join_response["rows"] == 1 + assert join_response["result"][0]["columns"] == ["id", "v", "w"] + assert join_response["result"][0]["rows"][0][0] == 1 + + union_sql = ( + f"SELECT id FROM {left['table_name']} " + f"UNION SELECT id FROM {right['table_name']} ORDER BY id LIMIT 100 OFFSET 0" + ) + union_response = query_dataframes(union_sql) + assert "error" not in union_response + id_idx = union_response["result"][0]["columns"].index("id") + assert sorted([row[id_idx] for row in union_response["result"][0]["rows"]]) == [1, 2, 3] + + +def test_group_dataframe_schemas_hierarchical_top_level_and_column_variants(): + _clear_all() + # Use flatten=False to preserve nested schema for grouping logic (configuration struct) + first = asyncio.run( + register_dataframe( + result=[{"id": 1, "configuration": {"threads": 10}}], + origin_manager="tests", + origin_action="a", + json_size_chars=9001, + flatten=False, + ) + ) + second = asyncio.run( + register_dataframe( + result=[{"id": 2, "configuration": {"threads": "10"}}], + origin_manager="tests", + origin_action="b", + json_size_chars=9001, + flatten=False, + ) + ) + third = asyncio.run( + register_dataframe( + result=[{"id": 3, "value": 10}], + origin_manager="tests", + origin_action="c", + json_size_chars=9001, + flatten=False, + ) + ) + + grouped = asyncio.run(group_dataframe_schemas()) + assert "groups" in grouped + assert len(grouped["groups"]) == 2 + + filtered = asyncio.run(group_dataframe_schemas([first["dataframe_id"], second["dataframe_id"], "missing-id"])) + assert len(filtered["groups"]) == 1 + assert filtered["missing_df_ids"] == "missing-id" + assert "df_sets" in filtered + group = filtered["groups"][0] + assert "varying_columns" in group + assert "configuration" in group["varying_columns"].split(",") + group_ids = set(filtered["df_sets"][group["df_ref"]].split(",")) + assert group_ids == {first["dataframe_id"], second["dataframe_id"]} + assert third["dataframe_id"] not in group_ids + configuration_column = next(col for col in group["columns"] if col["name"] == "configuration") + assert "dtype" not in configuration_column + assert len(configuration_column["variations"]) == 2 + for version in configuration_column["variations"]: + assert "hash" not in version + assert "column_schema" in version + assert isinstance(version["column_schema"], str) + assert "df_ref" in version + assert all(isinstance(df_id, str) for df_id in filtered["df_sets"][version["df_ref"]].split(",")) + + +def test_query_blocks_non_read_only_sql(): + _clear_all() + response = query_dataframes("DELETE FROM some_table") + assert "error" in response + assert "read-only sql is allowed" in response["error"].lower() + + +def test_query_allows_literals_with_blocked_keywords(): + _clear_all() + response = query_dataframes("SELECT 'delete' AS word ORDER BY word LIMIT 1 OFFSET 0") + assert "error" not in response + assert response["result"][0]["columns"] == ["word"] + assert response["result"][0]["rows"][0][0] == "delete" + + +def test_query_requires_order_by_limit_offset(): + _clear_all() + no_order = query_dataframes("SELECT 1 AS value LIMIT 1 OFFSET 0") + assert "error" in no_order + assert "order by is mandatory" in no_order["error"].lower() + + no_limit = query_dataframes("SELECT 1 AS value ORDER BY value OFFSET 0") + assert "error" in no_limit + assert "limit is mandatory" in no_limit["error"].lower() + + no_offset = query_dataframes("SELECT 1 AS value ORDER BY value LIMIT 1") + assert "error" in no_offset + assert "offset is mandatory" in no_offset["error"].lower() + + +def test_tool_result_threshold_keeps_small_payload(): + _clear_all() + + @tool_result() + async def tool_handler(action: str) -> BaseResult: + return BaseResult(result=[{"value": "ok"}]) + + response = asyncio.run(tool_handler("read")) + assert response.result == [{"value": "ok"}] + assert asyncio.run(list_dataframes_metadata()) == [] + + +def test_tool_result_threshold_boundary_8000(): + _clear_all() + payload_size = 1 + while True: + candidate = [{"payload": "x" * payload_size}] + serialized_len = len(serialize_result_to_compact_json(candidate)) + if serialized_len >= 8000: + break + payload_size += 1 + if serialized_len > 8000: + payload_size -= 1 + candidate = [{"payload": "x" * payload_size}] + serialized_len = len(serialize_result_to_compact_json(candidate)) + + assert serialized_len <= 8000 + + @tool_result() + async def tool_handler(action: str) -> BaseResult: + return BaseResult(result=candidate) + + response = asyncio.run(tool_handler("read")) + assert response.result == candidate + assert asyncio.run(list_dataframes_metadata()) == [] + + +def test_tool_result_threshold_materializes_large_payload(): + _clear_all() + large_value = "x" * 8100 + + @tool_result() + async def tool_handler(action: str) -> BaseResult: + return BaseResult(result=[{"id": 1, "payload": large_value}]) + + response = asyncio.run(tool_handler("read")) + assert response.result is not None + assert response.result[0]["stored_as_dataframe"] is True + assert response.result[0]["json_size_chars"] > 8000 + assert len(asyncio.run(list_dataframes_metadata())) == 1 + + +def test_tool_result_forces_dataframe_when_requested(): + _clear_all() + + @tool_result() + async def tool_handler(action: str, args: dict) -> BaseResult: + return BaseResult(result=[{"value": "small"}]) + + response = asyncio.run(tool_handler("read", {"result_format": "dataframe"})) + assert response.result is not None + assert response.result[0]["stored_as_dataframe"] is True + assert len(asyncio.run(list_dataframes_metadata())) == 1 + + +def test_tool_result_force_dataframe_skips_empty_result(): + _clear_all() + + @tool_result() + async def tool_handler(action: str, args: dict) -> BaseResult: + return BaseResult(result=[]) + + response = asyncio.run(tool_handler("read", {"result_format": "dataframe"})) + assert response.result == [] + assert response.info is not None + assert any("contains no rows" in msg for msg in response.info) + assert asyncio.run(list_dataframes_metadata()) == [] + + +def test_tool_result_raw_skips_dataframe_even_if_large(): + _clear_all() + large_value = "x" * 9000 + + @tool_result() + async def tool_handler(action: str, args: dict) -> BaseResult: + return BaseResult(result=[{"payload": large_value}]) + + response = asyncio.run(tool_handler("read", {"result_format": "raw"})) + assert response.result is not None + assert isinstance(response.result[0], dict) + assert response.result[0].get("stored_as_dataframe") is not True + assert asyncio.run(list_dataframes_metadata()) == [] + + +def test_materialize_large_result_if_needed_for_task_result(): + _clear_all() + + async def scenario(): + async def action(): + return BaseResult(result=[{"payload": "x" * 8200}]) + + task_id = submit_task( + action={"manager": "TestManager", "method": "slow_action"}, + coro_factory=action, + ) + while True: + record = get_task_record(task_id) + if record and record.status in {"completed", "failed", "cancelled"}: + return record + await asyncio.sleep(0.01) + + record = asyncio.run(scenario()) + assert record.result is not None + assert record.result.result is not None + assert record.result.result[0]["stored_as_dataframe"] is True + assert len(asyncio.run(list_dataframes_metadata())) == 1 + + +def test_materialize_helper_keeps_small_result_unchanged(): + _clear_all() + + async def _run(): + original = BaseResult(result=[{"ok": True}]) + return await materialize_large_result_if_needed( + base_result=original, + origin_manager="demo", + origin_action="read", + ) + + result = asyncio.run(_run()) + assert result.result == [{"ok": True}] + assert asyncio.run(list_dataframes_metadata()) == [] + + +def test_auto_flatten_wide_expands_structs_and_flattens_lists(): + """auto_flatten_wide expands nested structs and flattens list columns to scalar.""" + # Struct column: configuration.threads + df = pl.DataFrame([ + {"id": 1, "configuration": {"threads": 10, "ramp_up": 60}}, + {"id": 2, "configuration": {"threads": 20, "ramp_up": 120}}, + ]) + flattened = auto_flatten_wide(df) + assert flattened.width >= 3 # id + configuration__threads + configuration__ramp_up + assert flattened.height == 2 + assert not any(isinstance(dt, pl.Struct) for dt in flattened.schema.values()) + # Path format preserves nesting in column names + assert "configuration__threads" in flattened.schema.names() + + # List of scalars: take first element + df_list = pl.DataFrame([ + {"id": 1, "tags": ["a", "b", "c"]}, + {"id": 2, "tags": ["x", "y"]}, + ]) + flattened_list = auto_flatten_wide(df_list) + assert not any(isinstance(dt, pl.List) for dt in flattened_list.schema.values()) + assert flattened_list.height == 2 + + # Deeply nested struct: path accumulates (config__inner__b) + df_nested = pl.DataFrame([ + {"id": 1, "config": {"a": 1, "inner": {"b": 2, "c": 3}}}, + ]) + flattened_nested = auto_flatten_wide(df_nested) + assert "config__a" in flattened_nested.schema.names() + assert "config__inner__b" in flattened_nested.schema.names() + assert "config__inner__c" in flattened_nested.schema.names() + + +def test_register_dataframe_flattens_by_default_and_is_queryable(): + """Registration with flatten=True (default) produces path-style flat columns queryable via SQL.""" + _clear_all() + metadata = asyncio.run( + register_dataframe( + result=[ + {"id": 1, "configuration": {"threads": 10}}, + {"id": 2, "configuration": {"threads": 20}}, + ], + origin_manager="tests", + origin_action="flatten_check", + json_size_chars=9001, + ) + ) + table = metadata["table_name"] + # Flattened schema uses path format: configuration__threads (preserves nesting in name) + response = query_dataframes( + f'SELECT id, "configuration__threads" FROM {table} ORDER BY id LIMIT 10 OFFSET 0' + ) + assert "error" not in response + assert response["rows"] == 2 diff --git a/tests/test_failure_criteria.py b/tests/test_failure_criteria.py index 10e206d..8d1aa04 100644 --- a/tests/test_failure_criteria.py +++ b/tests/test_failure_criteria.py @@ -279,7 +279,9 @@ def test_tool_returns_catalog_without_api(self): mcp = _FakeMcpForTests() register_tests_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_tests"] - result = asyncio.run(tool("failure_criteria_meta", {}, ctx=None)) + result = asyncio.run( + tool({"action": "failure_criteria_meta", "args": {}}, ctx=None) + ) assert result.error is None payload = result.result[0] assert "top_level_tool_args" in payload diff --git a/tests/test_required_args_tools.py b/tests/test_required_args_tools.py index 265abcc..2de0cc3 100644 --- a/tests/test_required_args_tools.py +++ b/tests/test_required_args_tools.py @@ -18,11 +18,13 @@ from config.blazemeter import TOOLS_PREFIX from tools.account_manager import register as register_account_tool +from tools.billing_manager import register as register_billing_tool from tools.execution_manager import register as register_execution_tool from tools.help_manager import register as register_help_tool from tools.project_manager import register as register_project_tool from tools.skills_manager import register as register_skills_tool from tools.test_manager import register as register_tests_tool +from tools.tools_manager import register as register_tools_tool from tools.workspace_manager import register as register_workspaces_tool @@ -48,7 +50,7 @@ def test_account_read_requires_account_id(self): register_account_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_account"] - result = asyncio.run(tool("read", {}, ctx=None)) + result = asyncio.run(tool({"action": "read", "args": {}}, ctx=None)) assert result.error is not None assert "account_id" in result.error @@ -57,7 +59,7 @@ def test_workspace_list_requires_account_id(self): register_workspaces_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_workspaces"] - result = asyncio.run(tool("list", {}, ctx=None)) + result = asyncio.run(tool({"action": "list", "args": {}}, ctx=None)) assert result.error is not None assert "account_id" in result.error @@ -66,7 +68,7 @@ def test_project_read_requires_project_id(self): register_project_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_project"] - result = asyncio.run(tool("read", {}, ctx=None)) + result = asyncio.run(tool({"action": "read", "args": {}}, ctx=None)) assert result.error is not None assert "project_id" in result.error @@ -75,7 +77,9 @@ def test_tests_create_requires_test_name(self): register_tests_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_tests"] - result = asyncio.run(tool("create", {"project_id": 123}, ctx=None)) + result = asyncio.run( + tool({"action": "create", "args": {"project_id": 123}}, ctx=None) + ) assert result.error is not None assert "test_name" in result.error @@ -84,7 +88,23 @@ def test_tests_upload_assets_requires_file_paths(self): register_tests_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_tests"] - result = asyncio.run(tool("upload_assets", {"test_id": 123}, ctx=None)) + result = asyncio.run( + tool({"action": "upload_assets", "args": {"test_id": 123}}, ctx=None) + ) + assert result.error is not None + assert "file_paths" in result.error + + def test_tests_upload_assets_requires_non_empty_file_paths(self): + mcp = FakeMcp() + register_tests_tool(mcp, token=None) + tool = mcp.tools[f"{TOOLS_PREFIX}_tests"] + + result = asyncio.run( + tool( + {"action": "upload_assets", "args": {"test_id": 123, "file_paths": []}}, + ctx=None, + ) + ) assert result.error is not None assert "file_paths" in result.error @@ -93,12 +113,23 @@ def test_tests_configure_failure_criteria_requires_enabled_and_rules(self): register_tests_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_tests"] - result = asyncio.run(tool("configure_failure_criteria", {"test_id": 123}, ctx=None)) + result = asyncio.run( + tool( + {"action": "configure_failure_criteria", "args": {"test_id": 123}}, + ctx=None, + ) + ) assert result.error is not None assert "enabled" in result.error result = asyncio.run( - tool("configure_failure_criteria", {"test_id": 123, "enabled": True}, ctx=None) + tool( + { + "action": "configure_failure_criteria", + "args": {"test_id": 123, "enabled": True}, + }, + ctx=None, + ) ) assert result.error is not None assert "rules" in result.error @@ -108,7 +139,7 @@ def test_execution_read_requires_execution_id(self): register_execution_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_execution"] - result = asyncio.run(tool("read", {}, ctx=None)) + result = asyncio.run(tool({"action": "read", "args": {}}, ctx=None)) assert result.error is not None assert "execution_id" in result.error @@ -117,7 +148,7 @@ def test_execution_read_summary_requires_execution_id(self): register_execution_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_execution"] - result = asyncio.run(tool("read_summary", {}, ctx=None)) + result = asyncio.run(tool({"action": "read_summary", "args": {}}, ctx=None)) assert result.error is not None assert "execution_id" in result.error @@ -126,7 +157,18 @@ def test_skills_read_skill_requires_skill_name(self): register_skills_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_skills"] - result = asyncio.run(tool("read_skill", {}, ctx=None)) + result = asyncio.run(tool({"action": "read_skill", "args": {}}, ctx=None)) + assert result.error is not None + assert "skill_name" in result.error + + def test_skills_list_skill_resources_requires_skill_name(self): + mcp = FakeMcp() + register_skills_tool(mcp, token=None) + tool = mcp.tools[f"{TOOLS_PREFIX}_skills"] + + result = asyncio.run( + tool({"action": "list_skill_resources", "args": {}}, ctx=None) + ) assert result.error is not None assert "skill_name" in result.error @@ -135,7 +177,9 @@ def test_skills_read_skill_resource_uri_requires_uri(self): register_skills_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_skills"] - result = asyncio.run(tool("read_skill_resource_uri", {}, ctx=None)) + result = asyncio.run( + tool({"action": "read_skill_resource_uri", "args": {}}, ctx=None) + ) assert result.error is not None assert "skill_resource_uri" in result.error @@ -144,16 +188,41 @@ def test_skills_read_skill_resource_uri_list_requires_non_empty_list(self): register_skills_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_skills"] - result = asyncio.run(tool("read_skill_resource_uri_list", {}, ctx=None)) + result = asyncio.run( + tool({"action": "read_skill_resource_uri_list", "args": {}}, ctx=None) + ) assert result.error is not None assert "skill_resource_uri_list" in result.error - def test_help_read_help_info_requires_help_id_list(self): + def test_help_read_help_info_requires_category_subcategory_and_help_ids(self): mcp = FakeMcp() register_help_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_help"] - result = asyncio.run(tool("read_help_info", {}, ctx=None)) + result = asyncio.run(tool({"action": "read_help_info", "args": {}}, ctx=None)) + assert result.error is not None + assert "category_id" in result.error + assert "subcategory_id" in result.error + assert "help_id_list" in result.error + + def test_help_read_help_info_requires_non_empty_help_id_list(self): + mcp = FakeMcp() + register_help_tool(mcp, token=None) + tool = mcp.tools[f"{TOOLS_PREFIX}_help"] + + result = asyncio.run( + tool( + { + "action": "read_help_info", + "args": { + "category_id": "root_category", + "subcategory_id": "guide", + "help_id_list": [], + }, + }, + ctx=None, + ) + ) assert result.error is not None assert "help_id_list" in result.error @@ -162,6 +231,35 @@ def test_help_list_help_category_content_requires_subcategory_list(self): register_help_tool(mcp, token=None) tool = mcp.tools[f"{TOOLS_PREFIX}_help"] - result = asyncio.run(tool("list_help_category_content", {}, ctx=None)) + result = asyncio.run( + tool({"action": "list_help_category_content", "args": {}}, ctx=None) + ) assert result.error is not None assert "subcategory_id_list" in result.error + + +class TestRequiredArgumentsBillingAndTools: + def test_billing_calculate_cost_requires_allowance_type(self): + mcp = FakeMcp() + register_billing_tool(mcp, token=None) + tool = mcp.tools[f"{TOOLS_PREFIX}_billing"] + + result = asyncio.run( + tool({"action": "calculate_cost_from_config", "args": {}}, ctx=None) + ) + assert result.error is not None + assert "allowance_type" in result.error + + def test_tools_dataframes_query_requires_non_empty_sql(self): + mcp = FakeMcp() + register_tools_tool(mcp, token=None) + tool = mcp.tools[f"{TOOLS_PREFIX}_tools"] + + result = asyncio.run( + tool( + {"action": "dataframes_query", "args": {"sql": " "}}, + ctx=None, + ) + ) + assert result.error is not None + assert "sql" in result.error diff --git a/tests/test_skills_manager_security.py b/tests/test_skills_manager_security.py index 013c12e..201a857 100644 --- a/tests/test_skills_manager_security.py +++ b/tests/test_skills_manager_security.py @@ -44,14 +44,16 @@ def isolated_skills_resources(tmp_path, monkeypatch): class TestSkillsManagerListResourcesErrors: def test_list_skill_resources_returns_controlled_error_for_invalid_skill_name(self, isolated_skills_resources): - result = asyncio.run(SkillsManager.list_skill_resources("../safe-skill")) + manager = SkillsManager(token=None, ctx=None) + result = asyncio.run(manager.list_skill_resources("../safe-skill")) assert result.error is not None assert "Invalid skill name" in result.error assert result.result is None def test_list_skill_resources_returns_controlled_error_for_missing_skill(self, isolated_skills_resources): - result = asyncio.run(SkillsManager.list_skill_resources("unknown-skill")) + manager = SkillsManager(token=None, ctx=None) + result = asyncio.run(manager.list_skill_resources("unknown-skill")) assert result.error is not None assert "Skill folder not found" in result.error diff --git a/tests/test_test_formatter_minimal.py b/tests/test_test_formatter_minimal.py new file mode 100644 index 0000000..04cff73 --- /dev/null +++ b/tests/test_test_formatter_minimal.py @@ -0,0 +1,23 @@ +from formatters.test import format_tests_minimal + + +def test_format_tests_minimal_drops_heavy_configuration_fields(): + raw = [ + { + "id": 1, + "name": "T1", + "description": "desc", + "created": 0, + "updated": 0, + "projectId": 10, + "configuration": {"huge": {"nested": [1, 2, 3]}}, + "overrideExecutions": [{"k": "v"}], + } + ] + + formatted = format_tests_minimal(raw) + assert len(formatted) == 1 + assert formatted[0]["test_id"] == 1 + assert formatted[0]["project_id"] == 10 + assert "configuration" not in formatted[0] + assert "override_executions" not in formatted[0] diff --git a/tests/test_test_manager_list_project_list.py b/tests/test_test_manager_list_project_list.py new file mode 100644 index 0000000..7b8ccf2 --- /dev/null +++ b/tests/test_test_manager_list_project_list.py @@ -0,0 +1,38 @@ +import asyncio + +from models.result import BaseResult +from tools.test_manager import TestManager + + +def test_list_merges_multiple_projects(monkeypatch): + async def fake_read_project(token, ctx, project_id): + return BaseResult(result=[{"project_id": project_id}]) + + async def fake_api_request(token, method, endpoint, result_formatter=None, params=None, **kwargs): + project_id = params.get("projectId") + return BaseResult( + result=[{"test_id": project_id * 100, "project_id": project_id}], + total=1, + has_more=False, + ) + + monkeypatch.setattr("tools.test_manager.bridge.read_project", fake_read_project) + monkeypatch.setattr("tools.test_manager.api_request", fake_api_request) + + manager = TestManager(token=None, ctx=None) + response = asyncio.run(manager.list(project_id_list=[10, 20], limit=5, offset=0)) + + assert response.error is None + assert response.result is not None + assert len(response.result) == 2 + assert {item["project_id"] for item in response.result} == {10, 20} + assert response.total == 2 + assert response.info is not None + assert "Merged tests list from 2 projects" in response.info[0] + + +def test_list_requires_project_id_list(): + manager = TestManager(token=None, ctx=None) + response = asyncio.run(manager.list(project_id_list=[])) + assert response.error is not None + assert "project_id_list" in response.error diff --git a/tests/test_tool_result_wrapper.py b/tests/test_tool_result_wrapper.py new file mode 100644 index 0000000..cb84129 --- /dev/null +++ b/tests/test_tool_result_wrapper.py @@ -0,0 +1,20 @@ +from models.result import BaseResult, ToolResult + + +def test_tool_result_from_base_result_builds_pretty_text_and_structured_content(): + base = BaseResult(result=[{"ok": True}], info=["done"]) + wrapped = ToolResult.from_base_result(base) + + assert wrapped.isError is False + assert wrapped.structuredContent == base.model_dump(mode="json") + assert wrapped.content[0].type == "text" + assert wrapped.content[0].text == base.model_dump_json(indent=2) + assert "\n" in wrapped.content[0].text + + +def test_tool_result_from_base_result_marks_error(): + base = BaseResult(error="boom") + wrapped = ToolResult.from_base_result(base) + + assert wrapped.isError is True + assert wrapped.error == "boom" diff --git a/tests/test_tools_manager_dataframe_query_result_format.py b/tests/test_tools_manager_dataframe_query_result_format.py new file mode 100644 index 0000000..17c2489 --- /dev/null +++ b/tests/test_tools_manager_dataframe_query_result_format.py @@ -0,0 +1,66 @@ +import asyncio + +from tools.tools_manager import ToolsManager + + +def test_dataframes_query_uses_requested_output_format_by_default(monkeypatch): + captured = {} + + def fake_query_dataframes(sql: str, output_format: str = "matrix"): + captured["sql"] = sql + captured["output_format"] = output_format + return { + "result": [{"id": 1}], + "rows": 1, + "columns": 1, + "schema": [{"name": "id", "dtype": "Int64"}], + "output_format": output_format, + } + + monkeypatch.setattr("tools.tools_manager.query_dataframes", fake_query_dataframes) + + manager = ToolsManager(token=None, ctx=None) + response = asyncio.run( + manager.dataframes_query( + sql="SELECT 1 AS id ORDER BY id LIMIT 1 OFFSET 0", + output_format="matrix", + result_format="auto", + ) + ) + + assert response.error is None + assert captured["output_format"] == "matrix" + + +def test_dataframes_query_forces_records_when_result_format_dataframe(monkeypatch): + captured = {} + + def fake_query_dataframes(sql: str, output_format: str = "matrix"): + captured["sql"] = sql + captured["output_format"] = output_format + return { + "result": [{"id": 1}], + "rows": 1, + "columns": 1, + "schema": [{"name": "id", "dtype": "Int64"}], + "output_format": output_format, + } + + monkeypatch.setattr("tools.tools_manager.query_dataframes", fake_query_dataframes) + + manager = ToolsManager(token=None, ctx=None) + response = asyncio.run( + manager.dataframes_query( + sql="SELECT 1 AS id ORDER BY id LIMIT 1 OFFSET 0", + output_format="matrix", + result_format="dataframe", + ) + ) + + assert response.error is None + assert captured["output_format"] == "records" + assert response.info is not None + assert any( + "dataframes_query uses records internally for dataframe storage" in message + for message in response.info + ) diff --git a/tests/test_tools_manager_dataframes_remove.py b/tests/test_tools_manager_dataframes_remove.py new file mode 100644 index 0000000..913db7b --- /dev/null +++ b/tests/test_tools_manager_dataframes_remove.py @@ -0,0 +1,48 @@ +import asyncio + +from tools.dataframe_manager import clear_dataframes, register_dataframe, get_dataframe_metadata +from tools.tools_manager import ToolsManager + + +def _clear_dataframes(): + asyncio.run(clear_dataframes()) + + +def test_dataframes_remove_supports_list_of_ids(): + _clear_dataframes() + first = asyncio.run( + register_dataframe( + result=[{"id": 1, "name": "a"}], + origin_manager="tests", + origin_action="seed1", + json_size_chars=9001, + ) + ) + second = asyncio.run( + register_dataframe( + result=[{"id": 2, "name": "b"}], + origin_manager="tests", + origin_action="seed2", + json_size_chars=9001, + ) + ) + + manager = ToolsManager(token=None, ctx=None) + response = asyncio.run( + manager.dataframes_remove(dataframe_id_list=[first["dataframe_id"], second["dataframe_id"]]) + ) + + assert response.error is None + assert response.result is not None + assert len(response.result) == 2 + assert all(item["removed"] is True for item in response.result) + assert asyncio.run(get_dataframe_metadata(first["dataframe_id"])) is None + assert asyncio.run(get_dataframe_metadata(second["dataframe_id"])) is None + + +def test_dataframes_remove_requires_non_empty_list(): + _clear_dataframes() + manager = ToolsManager(token=None, ctx=None) + response = asyncio.run(manager.dataframes_remove(dataframe_id_list=[])) + assert response.error is not None + assert "dataframe_id_list" in response.error diff --git a/tests/test_tools_manager_polling_format.py b/tests/test_tools_manager_polling_format.py new file mode 100644 index 0000000..e50ce8d --- /dev/null +++ b/tests/test_tools_manager_polling_format.py @@ -0,0 +1,129 @@ +import tools.async_task_manager as task_manager +from tools.tools_manager import ToolsManager +import asyncio +from models.result import BaseResult + + +def _clear_tasks(): + task_manager._tasks.clear() + + +def test_operation_call_line_uses_all_named_params(): + action_payload = { + "manager": "ExecutionManager", + "method": "list", + "params": { + "test_id": 15332595, + "limit": 1, + "offset": 0, + "purpose": "diagnostics", + }, + } + + line = ToolsManager._operation_call_line(action_payload) + assert line == "execution.list" + + +def test_polling_message_includes_operation_task_and_batch_summary(): + _clear_tasks() + record = task_manager.TaskRecord( + task_id="7k2p9m4q", + action={ + "manager": "ExecutionManager", + "method": "list", + "params": {"test_id": 15332595, "limit": 1, "offset": 0}, + }, + created_at=0.0, + last_updated_at=0.0, + time_to_live_ms=None, + status=task_manager.STATUS_WORKING, + status_message="Task is currently running.", + status_info="", + ) + task_manager._tasks[record.task_id] = record + + message = ToolsManager._polling_message( + task_record=record, + poll_count=3, + elapsed_seconds=12, + next_poll_seconds=1.0, + window_seconds=30.0, + ) + + assert "Polling 7k2p9m4q[execution.list] (working) attempt=3 elapsed=12s/30s next=1s" in message + assert "batch summary: total=1 completed=0 working=1 parking=0 failed=0" in message + + +def test_tasks_list_returns_minimal_snapshot_without_action_payload(): + _clear_tasks() + record = task_manager.TaskRecord( + task_id="abc123xy", + action={ + "manager": "TestManager", + "method": "upload_assets", + "params": {"test_id": 1, "file_paths": ["/very/long/path"]}, + }, + created_at=0.0, + last_updated_at=0.0, + time_to_live_ms=None, + status=task_manager.STATUS_WORKING, + status_message="Task is currently running.", + status_info="", + ) + task_manager._tasks[record.task_id] = record + + manager = ToolsManager(token=None, ctx=None) + response = asyncio.run(manager.tasks_list()) + assert response.result is not None + item = response.result[0] + assert item["task_id"] == "abc123xy" + assert item["operation"] == "test.upload_assets" + assert "action" not in item + + +def test_tasks_status_terminal_omits_task_result_payload(): + _clear_tasks() + record = task_manager.TaskRecord( + task_id="done1234", + action={"manager": "ExecutionManager", "method": "list", "params": {"limit": 1, "offset": 0}}, + created_at=0.0, + last_updated_at=0.0, + time_to_live_ms=None, + status=task_manager.STATUS_COMPLETED, + status_message="Task completed.", + status_info="", + result=BaseResult(result=[{"id": 1, "name": "result"}]), + ) + task_manager._tasks[record.task_id] = record + + manager = ToolsManager(token=None, ctx=None) + response = asyncio.run(manager.tasks_status("done1234")) + assert response.result is not None + item = response.result[0] + assert item["task_id"] == "done1234" + assert "task_result" not in item + assert response.info is not None + assert "Use tasks_get to retrieve task_result" in response.info[0] + + +def test_tasks_get_terminal_includes_task_result_payload(): + _clear_tasks() + record = task_manager.TaskRecord( + task_id="done5678", + action={"manager": "ExecutionManager", "method": "list", "params": {"limit": 1, "offset": 0}}, + created_at=0.0, + last_updated_at=0.0, + time_to_live_ms=None, + status=task_manager.STATUS_COMPLETED, + status_message="Task completed.", + status_info="", + result=BaseResult(result=[{"id": 2, "name": "final"}]), + ) + task_manager._tasks[record.task_id] = record + + manager = ToolsManager(token=None, ctx=None) + response = asyncio.run(manager.tasks_get("done5678", remove_on_terminal=False)) + assert response.result is not None + item = response.result[0] + assert item["task_id"] == "done5678" + assert "task_result" in item diff --git a/tests/test_tools_manager_schema_groups_info.py b/tests/test_tools_manager_schema_groups_info.py new file mode 100644 index 0000000..6e9fe77 --- /dev/null +++ b/tests/test_tools_manager_schema_groups_info.py @@ -0,0 +1,37 @@ +import asyncio + +from tools.dataframe_manager import clear_dataframes, register_dataframe +from tools.tools_manager import ToolsManager + + +def _clear_dataframes(): + asyncio.run(clear_dataframes()) + + +def test_schema_groups_info_adds_critical_hint_when_variations_exist(): + _clear_dataframes() + # Use flatten=False to preserve nested schema for variation detection (configuration struct) + asyncio.run( + register_dataframe( + result=[{"id": 1, "configuration": {"threads": 10}}], + origin_manager="tests", + origin_action="seed1", + json_size_chars=9001, + flatten=False, + ) + ) + asyncio.run( + register_dataframe( + result=[{"id": 2, "configuration": {"threads": "10"}}], + origin_manager="tests", + origin_action="seed2", + json_size_chars=9001, + flatten=False, + ) + ) + + manager = ToolsManager(token=None, ctx=None) + response = asyncio.run(manager.dataframes_schema_groups()) + + assert response.info is not None + assert any("Column variations were detected" in msg for msg in response.info) diff --git a/tests/test_utils_batch.py b/tests/test_utils_batch.py new file mode 100644 index 0000000..f99be22 --- /dev/null +++ b/tests/test_utils_batch.py @@ -0,0 +1,118 @@ +import asyncio + +from models.result import BaseResult +from tools.dataframe_manager import clear_dataframes, list_dataframes_metadata +from tools.utils import execute_batch_calls, process_batch_sub_action +from tools import utils + + +def test_execute_batch_calls_requires_non_empty_list(): + async def process_call(call): + return BaseResult(result=[call]) + + response = asyncio.run(execute_batch_calls({}, process_call)) + assert response.error == "batch_calls must be a non-empty list of dicts with 'action' and 'args'" + + +def test_execute_batch_calls_collects_results_and_exceptions(): + async def process_call(call): + if call.get("action") == "boom": + raise RuntimeError("failure") + return BaseResult(result=[call.get("action")]) + + response = asyncio.run( + execute_batch_calls( + [{"action": "ok"}, {"action": "boom"}], + process_call, + ) + ) + + assert response.result is not None + assert isinstance(response.result[0], BaseResult) + assert response.result[0].result == ["ok"] + assert isinstance(response.result[1], BaseResult) + assert response.result[1].error == "Unhandled exception: failure" + + +def test_execute_batch_calls_respects_max_concurrency_kwarg(): + active = {"current": 0, "max": 0} + + async def process_call(call): + active["current"] += 1 + active["max"] = max(active["max"], active["current"]) + try: + await asyncio.sleep(0.02) + return BaseResult(result=[call]) + finally: + active["current"] -= 1 + + response = asyncio.run( + execute_batch_calls( + [{"action": str(i)} for i in range(6)], + process_call, + max_concurrency=2, + ) + ) + assert response.error is None + assert active["max"] <= 2 + + +def test_process_batch_sub_action_wraps_sub_action_exception(): + async def dispatch_sub_action(sub_action, sub_args): + raise RuntimeError("boom") + + response = asyncio.run( + process_batch_sub_action( + {"action": "read_skill", "args": {"skill_name": "x"}}, + dispatch_sub_action, + "support msg", + ) + ) + assert isinstance(response, BaseResult) + assert response.error is not None + assert "Error in sub-action read_skill:" in response.error + + +def test_process_batch_sub_action_forces_task_mode_for_sub_actions(): + class _Dummy: + @utils.run_as_task() + async def read(self) -> BaseResult: + await asyncio.sleep(0.01) + return BaseResult(result=[{"ok": True}]) + + async def dispatch_sub_action(sub_action, sub_args): + manager = _Dummy() + return await manager.read() + + response = asyncio.run( + process_batch_sub_action( + {"action": "read", "args": {}}, + dispatch_sub_action, + ) + ) + + assert isinstance(response, BaseResult) + assert response.result is not None + assert isinstance(response.result[0], dict) + assert "task_id" in response.result[0] + assert response.info is not None + assert "Long-running operation accepted" in response.info[0] + + +def test_tool_result_batch_never_materializes_dataframe_even_when_requested(): + async def _run(): + await clear_dataframes() + + @utils.tool_result(excluded_actions={"batch"}) + async def handler(action: str, args: dict) -> BaseResult: + return BaseResult(result=[{"payload": "x" * 9000}]) + + response = await handler("batch", {"result_format": "dataframe"}) + metadata = await list_dataframes_metadata() + return response, metadata + + response, metadata = asyncio.run(_run()) + assert response.error is None + assert response.result is not None + assert response.result[0].get("stored_as_dataframe") is not True + assert metadata == [] diff --git a/tests/test_utils_normalize_action_args.py b/tests/test_utils_normalize_action_args.py new file mode 100644 index 0000000..4cbbcd7 --- /dev/null +++ b/tests/test_utils_normalize_action_args.py @@ -0,0 +1,57 @@ +from tools.utils import normalize_action_args + + +def test_normalize_action_args_standard_format(): + """Standard format: action + args nested.""" + action, args = normalize_action_args({ + "action": "list", + "args": {"limit": 5, "project_id": 158903, "result_format": "dataframe"}, + }) + assert action == "list" + assert args == {"limit": 5, "project_id": 158903, "result_format": "dataframe"} + + +def test_normalize_action_args_flat_format(): + """Flat format: action + params at top level merged into args.""" + action, args = normalize_action_args({ + "action": "read", + "test_id": 123, + }) + assert action == "read" + assert args == {"test_id": 123} + + +def test_normalize_action_args_double_wrapped(): + """Double-wrapped format: {"arguments": {"action": "x", "args": {...}}}.""" + action, args = normalize_action_args({ + "arguments": { + "action": "list", + "args": { + "limit": 5, + "project_id": 158903, + "result_format": "dataframe", + }, + }, + }) + assert action == "list" + assert args == {"limit": 5, "project_id": 158903, "result_format": "dataframe"} + + +def test_normalize_action_args_double_wrapped_with_action_only(): + """Double-wrapped with only action (no args) still unwraps.""" + action, args = normalize_action_args({ + "arguments": {"action": "list_help_categories"}, + }) + assert action == "list_help_categories" + assert args == {} + + +def test_normalize_action_args_does_not_unwrap_when_extra_keys(): + """When top-level has other keys besides 'arguments', do not unwrap.""" + action, args = normalize_action_args({ + "arguments": {"action": "x", "args": {}}, + "other_key": "value", + }) + assert action == "" + assert "arguments" in args + assert args["other_key"] == "value" diff --git a/tests/test_utils_required_args.py b/tests/test_utils_required_args.py new file mode 100644 index 0000000..3808054 --- /dev/null +++ b/tests/test_utils_required_args.py @@ -0,0 +1,15 @@ +from tools.utils import validate_required_args + + +def test_validate_required_args_returns_error_for_missing_keys(): + result = validate_required_args("read", {}, ["project_id"]) + assert result is not None + assert result.error is not None + assert "Missing required args for action 'read'" in result.error + assert "project_id" in result.error + assert "within 'args'" in result.error + + +def test_validate_required_args_accepts_present_keys(): + result = validate_required_args("read", {"project_id": 123}, ["project_id"]) + assert result is None diff --git a/tests/test_utils_ttl_cache.py b/tests/test_utils_ttl_cache.py new file mode 100644 index 0000000..3d4a443 --- /dev/null +++ b/tests/test_utils_ttl_cache.py @@ -0,0 +1,167 @@ +import asyncio + +from models.result import BaseResult +from tools import utils + + +class _Token: + def __init__(self, token_id: str): + self.id = token_id + + +class _DummyManager: + def __init__(self): + self.token = _Token("token-a") + self.calls = 0 + + @utils.ttl_cache_method(ttl_seconds=30) + async def read_ok(self, entity_id: int) -> BaseResult: + self.calls += 1 + await asyncio.sleep(0.01) + return BaseResult(result=[{"entity_id": entity_id, "call": self.calls}]) + + @utils.ttl_cache_method(ttl_seconds=30) + async def read_error(self) -> BaseResult: + self.calls += 1 + await asyncio.sleep(0.01) + return BaseResult(error="boom") + + +class _DummyTaskManager: + def __init__(self): + self.token = _Token("token-b") + + @utils.run_as_task(fast_response_threshold_seconds=0.2) + async def read_ok(self) -> BaseResult: + await asyncio.sleep(0.01) + return BaseResult(result=[{"ok": True}]) + + + +def _clear_method_cache(): + utils._method_cache.clear() + utils._method_cache_inflight.clear() + + +def test_ttl_cache_reuses_successful_read_result(): + _clear_method_cache() + manager = _DummyManager() + + async def scenario(): + first = await manager.read_ok(10) + second = await manager.read_ok(10) + return first, second + + first, second = asyncio.run(scenario()) + assert manager.calls == 1 + assert first.result == second.result + + +def test_ttl_cache_does_not_cache_errors(): + _clear_method_cache() + manager = _DummyManager() + + async def scenario(): + first = await manager.read_error() + second = await manager.read_error() + return first, second + + first, second = asyncio.run(scenario()) + assert first.error == "boom" + assert second.error == "boom" + assert manager.calls == 2 + + +def test_ttl_cache_single_flight_for_concurrent_calls(): + _clear_method_cache() + manager = _DummyManager() + + async def scenario(): + results = await asyncio.gather( + manager.read_ok(99), + manager.read_ok(99), + manager.read_ok(99), + ) + return results + + results = asyncio.run(scenario()) + assert manager.calls == 1 + assert all(result.result == results[0].result for result in results) + + +def test_tool_result_adds_tool_call_timing_fields(): + @utils.tool_result() + async def tool_handler(action: str) -> BaseResult: + await asyncio.sleep(0.01) + return BaseResult(result=[{"ok": True}]) + + previous = utils.is_result_debug_enabled() + utils.set_result_debug_enabled(True) + try: + response = asyncio.run(tool_handler("read")) + assert response.tool_call_started_at is not None + assert response.tool_call_finished_at is not None + assert isinstance(response.tool_call_duration_ms, int) + assert response.tool_call_duration_ms >= 0 + assert response.debug is not None + assert response.debug.get("network", {}).get("http_calls") == 0 + assert response.debug.get("network", {}).get("http_total_ms") == 0 + finally: + utils.set_result_debug_enabled(previous) + + +def test_tool_result_adds_network_debug_metrics(): + @utils.tool_result() + async def tool_handler(action: str) -> BaseResult: + utils._accumulate_network_debug(120) + utils._accumulate_network_debug(80) + return BaseResult(result=[{"ok": True}]) + + previous = utils.is_result_debug_enabled() + utils.set_result_debug_enabled(True) + try: + response = asyncio.run(tool_handler("read")) + assert response.debug is not None + network = response.debug.get("network", {}) + assert network.get("http_calls") == 2 + assert network.get("http_total_ms") == 200 + finally: + utils.set_result_debug_enabled(previous) + + +def test_tool_result_debug_disabled_by_default(): + @utils.tool_result() + async def tool_handler(action: str) -> BaseResult: + return BaseResult(result=[{"ok": True}]) + + previous = utils.is_result_debug_enabled() + utils.set_result_debug_enabled(False) + try: + response = asyncio.run(tool_handler("read")) + assert response.tool_call_started_at is not None + assert response.tool_call_finished_at is not None + assert isinstance(response.tool_call_duration_ms, int) + assert response.tool_call_duration_ms >= 0 + assert response.debug is None + finally: + utils.set_result_debug_enabled(previous) + + +def test_tool_result_timing_always_present_with_run_as_task_when_debug_disabled(): + @utils.tool_result() + async def tool_handler(action: str, args: dict) -> BaseResult: + manager = _DummyTaskManager() + return await manager.read_ok() + + previous = utils.is_result_debug_enabled() + utils.set_result_debug_enabled(False) + try: + response = asyncio.run(tool_handler("read", {"result_format": "raw"})) + assert response.error is None + assert response.tool_call_started_at is not None + assert response.tool_call_finished_at is not None + assert isinstance(response.tool_call_duration_ms, int) + assert response.tool_call_duration_ms >= 0 + assert response.debug is None + finally: + utils.set_result_debug_enabled(previous) diff --git a/tools/account_manager.py b/tools/account_manager.py index 4671978..482b9a9 100644 --- a/tools/account_manager.py +++ b/tools/account_manager.py @@ -14,6 +14,7 @@ limitations under the License. """ from typing import Optional, Dict, Any + import httpx from mcp.server.fastmcp import Context @@ -22,7 +23,14 @@ from formatters.account import format_accounts from models.manager import Manager from models.result import BaseResult -from tools.utils import api_request, format_sanitized_traceback +from tools.utils import ( + api_request, + normalize_action_args, + run_as_task, + tool_result, + ttl_cache_method, + validate_required_args, format_sanitized_traceback, +) class AccountManager(Manager): @@ -34,10 +42,9 @@ class AccountManager(Manager): def __init__(self, token: Optional[BzmToken], ctx: Context): super().__init__(token, ctx) - async def read(self, account_id: Optional[int]) -> BaseResult: - if not isinstance(account_id, int) or account_id < 1: - return BaseResult(error="Missing or invalid required argument 'account_id'. Expected integer.") - + @ttl_cache_method(ttl_seconds=30) + @run_as_task() + async def read(self, account_id: int) -> BaseResult: account_result = await api_request( self.token, "GET", @@ -55,9 +62,8 @@ async def read(self, account_id: Optional[int]) -> BaseResult: else: return account_result + @run_as_task() async def list(self, limit: int = 50, offset: int = 0) -> BaseResult: - if not isinstance(limit, int) or not isinstance(offset, int): - return BaseResult(error="Invalid arguments 'limit'/'offset'. Expected integers.") # Note: Not it's needed to control AI consent at this level @@ -75,6 +81,7 @@ async def list(self, limit: int = 50, offset: int = 0) -> BaseResult: params=parameters ) + def register(mcp, token: Optional[BzmToken]) -> None: @mcp.tool( name=f"{TOOLS_PREFIX}_account", @@ -83,23 +90,30 @@ def register(mcp, token: Optional[BzmToken]) -> None: Use this when a user needs to select a account. Actions: - read: Read a Account. Get the information of a account. - args(dict): Dictionary with the following required parameters: - account_id (int): The id of the account to get information. + args(dict): Dictionary with the following parameters: + account_id (int, required): The id of the account to get information. - list: List all accounts. - args(dict): Dictionary with the following required parameters: - limit (int, default=10, valid=[1 to 50]): The number of tests to list. - offset (int, default=0): Number of tests to skip. + args(dict): Dictionary with optional pagination (all other keys ignored for this action): + limit (int, optional, default=50, valid=[1 to 50 when result_format=auto/raw, 1000 when result_format=dataframe]): Max number of accounts to return. + offset (int, optional, default=0): Number of accounts to skip. Hints: - If you need to get the default account, use the project id to get the workspace and with that the account. - Use the read operation if AI consent information is needed. The AI Consent it's located at account level. +- Optional result formatting in args: `result_format` = `auto` (default), `dataframe` (force dataframe), `raw` (disable dataframe materialization). - **CRITICAL**: Always follow the action schema exactly. If args are required, include args with exact names/types. """ ) - async def account(action: str, args: Dict[str, Any], ctx: Context) -> BaseResult: + @tool_result() + async def account(arguments: Dict[str, Any] = None, ctx: Context = None) -> BaseResult: + action, args = normalize_action_args(arguments) + if not action: + return BaseResult(error="Missing required argument 'action' within tool arguments.") account_manager = AccountManager(token, ctx) try: match action: case "read": + if validation_error := validate_required_args(action, args, ["account_id"]): + return validation_error return await account_manager.read(args.get("account_id")) case "list": return await account_manager.list(args.get("limit", 50), args.get("offset", 0)) diff --git a/tools/async_task_manager.py b/tools/async_task_manager.py new file mode 100644 index 0000000..c89c4a2 --- /dev/null +++ b/tools/async_task_manager.py @@ -0,0 +1,254 @@ +""" +Copyright 2025 Perforce Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import logging +import secrets +import time +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Awaitable, Callable, Dict, Optional + +from models.result import BaseResult +from tools.dataframe_manager import materialize_large_result_if_needed + +STATUS_WORKING = "working" +STATUS_PARKING = "parking" +STATUS_INPUT_REQUIRED = "input_required" +STATUS_COMPLETED = "completed" +STATUS_FAILED = "failed" +STATUS_CANCELLED = "cancelled" + +TERMINAL_STATES = {STATUS_COMPLETED, STATUS_FAILED, STATUS_CANCELLED} +ACTIVE_STATES = {STATUS_PARKING, STATUS_WORKING, STATUS_INPUT_REQUIRED} +MAX_PARALLEL_TASKS = 10 +TASK_ID_ALPHABET = "0123456789abcdefghjkmnpqrstvwxyz" +TASK_ID_LENGTH = 8 +TASK_ID_MAX_ATTEMPTS = 10 + +logger = logging.getLogger(__name__) + +STATUS_INFO = { + STATUS_WORKING: ( + "The request is currently being processed." + ), + STATUS_PARKING: ( + "The request is queued and waiting for an execution slot." + ), + STATUS_INPUT_REQUIRED: ( + "The receiver needs input from the requestor. " + "Use tasks_status for lightweight tracking and tasks_get to receive input requests." + ), + STATUS_COMPLETED: ( + "The request completed successfully and results are available." + ), + STATUS_FAILED: ( + "The associated request did not complete successfully." + ), + STATUS_CANCELLED: ( + "The request was cancelled before completion." + ), +} + +_semaphore = asyncio.Semaphore(MAX_PARALLEL_TASKS) + + +@dataclass +class TaskRecord: + task_id: str + action: Dict[str, Any] + created_at: float + last_updated_at: float + time_to_live_ms: Optional[int] + status: str + status_message: str + status_info: str + result: Optional[BaseResult] = None + asyncio_task: Optional[asyncio.Task] = None + started_running_at: Optional[float] = None + finished_at: Optional[float] = None + + def set_status(self, status: str, status_message: str): + self.status = status + self.status_message = status_message + self.status_info = STATUS_INFO.get(status, "") + self.last_updated_at = time.time() + if status == STATUS_WORKING and self.started_running_at is None: + self.started_running_at = self.last_updated_at + if status in TERMINAL_STATES: + self.finished_at = self.last_updated_at + + +_tasks: Dict[str, TaskRecord] = {} + + +def _to_iso(timestamp: float) -> str: + return datetime.fromtimestamp(timestamp).isoformat() + + +def _normalize_result(result: Any) -> BaseResult: + if isinstance(result, BaseResult): + return result + return BaseResult(result=[result]) + + +def _normalize_task_id(task_id: str) -> str: + return str(task_id).strip().lower() + + +def _generate_task_id() -> str: + return "".join(secrets.choice(TASK_ID_ALPHABET) for _ in range(TASK_ID_LENGTH)) + + +def _allocate_task_id() -> str: + for _ in range(TASK_ID_MAX_ATTEMPTS): + candidate = _generate_task_id() + if candidate not in _tasks: + return candidate + + logger.error( + "Unable to allocate unique 8-char task id after 10 attempts. " + "attempts=%s id_length=%s alphabet=crockford32 active_pool_size=%s", + TASK_ID_MAX_ATTEMPTS, + TASK_ID_LENGTH, + len(_tasks), + ) + raise RuntimeError("Unable to allocate unique 8-char task id after 10 attempts.") + + +async def _task_runner(task_record: TaskRecord, coro_factory: Callable[[], Awaitable[Any]]): + task_record.set_status(STATUS_PARKING, "Task is waiting for an available execution slot.") + try: + async with _semaphore: + task_record.set_status(STATUS_WORKING, "Task is currently running.") + if task_record.time_to_live_ms is None: + action_result = await coro_factory() + else: + action_result = await asyncio.wait_for(coro_factory(), timeout=task_record.time_to_live_ms / 1000) + normalized = _normalize_result(action_result) + result_format = str(task_record.action.get("result_format", "auto")).strip().lower() + if result_format != "raw": + normalized = await materialize_large_result_if_needed( + base_result=normalized, + origin_manager=task_record.action.get("manager", "unknown"), + origin_action=task_record.action.get("method", "unknown"), + force=(result_format == "dataframe"), + ) + task_record.result = normalized + if normalized.error: + task_record.set_status(STATUS_FAILED, f"Task finished with error: {normalized.error}") + else: + task_record.set_status(STATUS_COMPLETED, "Task finished successfully.") + except asyncio.TimeoutError: + timeout_message = ( + f"Task timed out after {task_record.time_to_live_ms} ms." + if task_record.time_to_live_ms is not None + else "Task timed out." + ) + task_record.result = BaseResult(error=timeout_message) + task_record.set_status(STATUS_CANCELLED, timeout_message) + except asyncio.CancelledError: + cancel_message = "Task was cancelled." + task_record.result = BaseResult(error=cancel_message) + task_record.set_status(STATUS_CANCELLED, cancel_message) + except Exception as exc: + error_message = f"Task failed with exception: {str(exc)}" + task_record.result = BaseResult(error=error_message) + task_record.set_status(STATUS_FAILED, error_message) + + +def submit_task( + action: Dict[str, Any], + coro_factory: Callable[[], Awaitable[Any]], + time_to_live_ms: Optional[int] = None +) -> str: + now = time.time() + task_id = _allocate_task_id() + task_record = TaskRecord( + task_id=task_id, + action=action, + created_at=now, + last_updated_at=now, + time_to_live_ms=time_to_live_ms, + status=STATUS_PARKING, + status_message="Task accepted and pending scheduling.", + status_info=STATUS_INFO[STATUS_PARKING], + ) + async_task = asyncio.create_task(_task_runner(task_record, coro_factory)) + task_record.asyncio_task = async_task + _tasks[task_id] = task_record + return task_id + + +def get_task_record(task_id: str) -> Optional[TaskRecord]: + normalized_task_id = _normalize_task_id(task_id) + return _tasks.get(normalized_task_id) + + +def remove_task(task_id: str) -> bool: + normalized_task_id = _normalize_task_id(task_id) + return _tasks.pop(normalized_task_id, None) is not None + + +def task_snapshot(task_record: TaskRecord, include_result: bool = False) -> Dict[str, Any]: + snapshot = { + "task_id": task_record.task_id, + "action": task_record.action, + "created_at": task_record.created_at, + "created_at_iso": _to_iso(task_record.created_at), + "last_updated_at": task_record.last_updated_at, + "last_updated_at_iso": _to_iso(task_record.last_updated_at), + "time_to_live_ms": task_record.time_to_live_ms, + "status": task_record.status, + "status_message": task_record.status_message, + "status_info": task_record.status_info, + "started_running_at": task_record.started_running_at, + "started_running_at_iso": _to_iso(task_record.started_running_at) if task_record.started_running_at else None, + "finished_at": task_record.finished_at, + "finished_at_iso": _to_iso(task_record.finished_at) if task_record.finished_at else None, + } + if include_result and task_record.result is not None: + snapshot["task_result"] = task_record.result.model_dump() + return snapshot + + +def list_tasks(status_list: Optional[list[str]] = None) -> list[TaskRecord]: + if not status_list: + return list(_tasks.values()) + expected = {status.lower() for status in status_list} + return [task for task in _tasks.values() if task.status.lower() in expected] + + +def is_terminal_status(status: str) -> bool: + return status in TERMINAL_STATES + + +def is_active_status(status: str) -> bool: + return status in ACTIVE_STATES + + +def cancel_task(task_id: str) -> Optional[TaskRecord]: + normalized_task_id = _normalize_task_id(task_id) + task_record = _tasks.get(normalized_task_id) + if not task_record: + return None + if task_record.asyncio_task and not task_record.asyncio_task.done(): + task_record.asyncio_task.cancel() + else: + task_record.set_status(STATUS_CANCELLED, "Task was already finished and marked as cancelled.") + if task_record.result is None: + task_record.result = BaseResult(error="Task was cancelled.") + return task_record diff --git a/tools/billing_manager.py b/tools/billing_manager.py index 57f9916..41d7f63 100644 --- a/tools/billing_manager.py +++ b/tools/billing_manager.py @@ -23,7 +23,9 @@ from models.manager import Manager from models.result import BaseResult from tools.billing_utils import calculate_test_cost -from tools.utils import format_sanitized_traceback +from tools.utils import format_sanitized_traceback, normalize_action_args, tool_result, validate_non_empty_str_arg, \ + validate_required_args, \ + run_as_task class BillingManager(Manager): @@ -31,8 +33,12 @@ class BillingManager(Manager): def __init__(self, token: Optional[BzmToken], ctx: Context): super().__init__(token, ctx) + @run_as_task() async def calculate_cost_from_config(self, args: Dict) -> BaseResult: - result = calculate_test_cost(args) + try: + result = calculate_test_cost(args) + except ValueError as e: + return BaseResult(error=str(e)) return BaseResult(result=[ { "cost": result, @@ -53,8 +59,8 @@ def register(mcp, token: Optional[BzmToken]) -> None: Actions: - calculate_cost_from_config: Calculate the cost of a test based on test configuration and workspace allowance type. args(dict): Dictionary with the following parameters: - allowance_type (str, required, valid=["credits", "virtualUserHours", "actualThreads", "serverHours", "functionalRequests"]): The workspace allowance. - test_type (str, required, default="performance", valid=["performance", "browser_performance", "gui_functional", "api_monitoring", "service_virtualization"]): Type of test. + allowance_type (str, required, non-empty, valid=["credits", "virtualUserHours", "actualThreads", "serverHours", "functionalRequests"]): The workspace allowance. + test_type (str, optional, default="performance", valid=["performance", "browser_performance", "gui_functional", "api_monitoring", "service_virtualization"]): Type of test. concurrency (int, required for performance/browser_performance): Maximum concurrent virtual users/threads. duration_minutes (float, required for performance/browser_performance/gui_functional): Test duration in minutes. iterations (int, optional, alternative to duration_minutes): Number of iterations. @@ -82,14 +88,23 @@ def register(mcp, token: Optional[BzmToken]) -> None: - Test Data usage increases cost by 50% for all test types. - Server hours calculation can use provided number_of_servers or estimate based on concurrency (~1000 users per engine). - All calculations are based on official BlazeMeter documentation (blazemeter-usage-billing skill). +- Optional result formatting in args: `result_format` = `auto` (default), `dataframe` (force dataframe), `raw` (disable dataframe materialization). - **CRITICAL**: Always follow the action schema exactly. If args are required, include args with exact names/types. """ ) - async def billing(action: str, args: Dict[str, Any], ctx: Context) -> BaseResult: + @tool_result() + async def billing(arguments: Dict[str, Any] = None, ctx: Context = None) -> BaseResult: + action, args = normalize_action_args(arguments) + if not action: + return BaseResult(error="Missing required argument 'action' within tool arguments.") billing_manager = BillingManager(token, ctx) try: match action: case "calculate_cost_from_config": + if validation_error := validate_required_args(action, args, ["allowance_type"]): + return validation_error + if err := validate_non_empty_str_arg(action, args, "allowance_type"): + return err return await billing_manager.calculate_cost_from_config(args) case _: return BaseResult( diff --git a/tools/bridge.py b/tools/bridge.py index 39a30ee..905c512 100644 --- a/tools/bridge.py +++ b/tools/bridge.py @@ -28,9 +28,14 @@ async def read_account(token: BzmToken, ctx: Context, account_id: int) -> BaseRe return await AccountManager(token, ctx).read(account_id) -async def read_project(token: BzmToken, ctx: Context, project_id: int) -> BaseResult: +async def read_project( + token: BzmToken, + ctx: Context, + project_id: int, + include_tests_count: bool = False +) -> BaseResult: from tools.project_manager import ProjectManager - return await ProjectManager(token, ctx).read(project_id) + return await ProjectManager(token, ctx).read(project_id, include_tests_count=include_tests_count) async def read_workspace(token: BzmToken, ctx: Context, workspace_id: int) -> BaseResult: @@ -46,7 +51,7 @@ async def read_test(token: BzmToken, ctx: Context, test_id: int) -> BaseResult: async def count_project_tests(token: BzmToken, ctx: Context, project_id: int) -> int: from tools.test_manager import TestManager return ( - await TestManager(token, ctx).list(project_id=project_id, limit=1, offset=0, control_ai_consent=False)).total + await TestManager(token, ctx).list(project_id_list=[project_id], limit=1, offset=0, control_ai_consent=False)).total async def read_execution(token: BzmToken, ctx: Context, execution_id: int) -> BaseResult: diff --git a/tools/dataframe_manager.py b/tools/dataframe_manager.py new file mode 100644 index 0000000..c781abe --- /dev/null +++ b/tools/dataframe_manager.py @@ -0,0 +1,777 @@ +""" +Copyright 2025 Perforce Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import asyncio +import hashlib +import json +import re +import uuid +from dataclasses import dataclass, asdict +from datetime import datetime +from typing import Any, Dict, List, Optional + +import polars as pl +from models.result import BaseResult + + +DATAFRAME_JSON_SIZE_THRESHOLD = 8000 +_dataframes: Dict[str, "DataFrameRecord"] = {} +_write_lock = asyncio.Lock() +_sql_context = pl.SQLContext() + +_DISALLOWED_SQL_PATTERN = re.compile( + r"\b(insert|update|delete|create|drop|alter|truncate|replace|merge|call|copy|grant|revoke)\b", + re.IGNORECASE, +) +_LEADING_SQL_COMMENTS_PATTERN = re.compile( + r"^(?:\s*(?:--[^\n]*\n|/\*.*?\*/))*\s*", + re.DOTALL, +) +_SQL_LINE_COMMENT_PATTERN = re.compile(r"--[^\n]*") +_SQL_BLOCK_COMMENT_PATTERN = re.compile(r"/\*.*?\*/", re.DOTALL) +_ORDER_BY_PATTERN = re.compile(r"\border\s+by\b", re.IGNORECASE) +_LIMIT_PATTERN = re.compile(r"\blimit\b", re.IGNORECASE) +_OFFSET_PATTERN = re.compile(r"\boffset\b", re.IGNORECASE) + + +@dataclass +class DataFrameRecord: + dataframe_id: str + table_name: str + created_at: str + origin_manager: str + origin_action: str + rows: int + columns: int + schema: List[Dict[str, str]] + schema_hash: str + json_size_chars: int + dataframe: pl.DataFrame + + def to_metadata(self, include_schema: bool = True) -> Dict[str, Any]: + metadata = asdict(self) + metadata.pop("dataframe", None) + if not include_schema: + metadata.pop("schema", None) + return metadata + + +def _json_default_serializer(value: Any) -> Any: + if hasattr(value, "model_dump"): + return value.model_dump(mode="json") + if hasattr(value, "isoformat"): + return value.isoformat() + return str(value) + + +def serialize_result_to_compact_json(result: List[Any]) -> str: + return json.dumps(result, separators=(",", ":"), ensure_ascii=False, default=_json_default_serializer) + + +def build_dataframe_from_result(result: List[Any]) -> pl.DataFrame: + normalized = json.loads(serialize_result_to_compact_json(result)) + + # matrix envelope: [{"columns":[...], "rows":[...]}] + if ( + isinstance(normalized, list) + and len(normalized) == 1 + and isinstance(normalized[0], dict) + and set(normalized[0].keys()) == {"columns", "rows"} + and isinstance(normalized[0]["columns"], list) + and isinstance(normalized[0]["rows"], list) + ): + matrix = normalized[0] + return pl.DataFrame(matrix["rows"], schema=[str(c) for c in matrix["columns"]], orient="row") + + # columnar envelope: [{"colA":[...], "colB":[...]}] + if ( + isinstance(normalized, list) + and len(normalized) == 1 + and isinstance(normalized[0], dict) + and normalized[0] + and all(isinstance(v, list) for v in normalized[0].values()) + ): + col_lengths = {len(v) for v in normalized[0].values()} + if len(col_lengths) == 1: + return pl.DataFrame(normalized[0]) + + if isinstance(normalized, list): + if not normalized: + return pl.DataFrame() + if all(isinstance(item, dict) for item in normalized): + return pl.DataFrame(normalized) + return pl.DataFrame({"value": normalized}) + if isinstance(normalized, dict): + return pl.DataFrame([normalized]) + return pl.DataFrame({"value": [normalized]}) + + +def auto_flatten_wide( + df: pl.DataFrame, + max_passes: int = 30, + sep: str = "__", +) -> pl.DataFrame: + """ + Flatten nested structures in a DataFrame for SQL queryability. + + - Nested structs: expanded into flat columns with path-style names + (e.g. configuration__threads, config__inner__b). Only flattens down to leaf scalars. + - List columns: flattened to scalar (first element). List of structs becomes + the struct fields of the first element with path prefix; list of scalars + becomes the first scalar. + - Preserves original row count. + - Safe for schemas with configuration, override_executions, and similar nested structures. + """ + for _ in range(max_passes): + struct_cols = [c for c, dt in df.schema.items() if isinstance(dt, pl.Struct)] + list_cols = [c for c, dt in df.schema.items() if isinstance(dt, pl.List)] + + if not struct_cols and not list_cols: + break + + # Flatten list columns: take first element, then unnest if struct + for col in list_cols: + inner = getattr(df.schema[col], "inner", None) + is_struct_inner = inner is not None and isinstance(inner, pl.Struct) + + temp = f"{col}{sep}temp" + expr = pl.col(col).fill_null([]).list.first() + df = df.with_columns(expr.alias(temp)) + + if is_struct_inner: + df = df.unnest(temp).drop(col) + # Rename to path format: col__field_name + fields = getattr(inner, "fields", []) + rename_map = {f.name: f"{col}{sep}{f.name}" for f in fields} + df = df.rename(rename_map) + else: + df = df.drop(col).rename({temp: col}) + + # Unnest struct columns one at a time, renaming to path format + struct_cols = [c for c, dt in df.schema.items() if isinstance(dt, pl.Struct)] + for col in struct_cols: + struct_dtype = df.schema[col] + fields = getattr(struct_dtype, "fields", []) + rename_map = {f.name: f"{col}{sep}{f.name}" for f in fields} + df = df.unnest(col) + df = df.rename(rename_map) + + return df + + +def _to_schema_rows(dataframe: pl.DataFrame) -> List[Dict[str, str]]: + schema = dataframe.schema + return [{"name": name, "dtype": str(dtype)} for name, dtype in schema.items()] + + +def _schema_hash(schema_rows: List[Dict[str, str]]) -> str: + payload = json.dumps(schema_rows, separators=(",", ":"), ensure_ascii=False) + return hashlib.md5(payload.encode("utf-8")).hexdigest() + + +def _stable_hash(payload: str) -> str: + return hashlib.md5(payload.encode("utf-8")).hexdigest() + + +def _normalize_root_dtype(dtype: str) -> str: + dtype_str = str(dtype or "").strip() + if dtype_str.startswith("Struct("): + return "Struct" + if dtype_str.startswith("List("): + inner = dtype_str[5:-1].strip() if dtype_str.endswith(")") else "" + if inner.startswith("Struct("): + return "List(Struct)" + return "List" + if dtype_str.startswith("Array("): + inner = dtype_str[6:-1].strip() if dtype_str.endswith(")") else "" + if inner.startswith("Struct("): + return "Array(Struct)" + return "Array" + return dtype_str + + +def _canonicalize_top_schema(schema_rows: List[Dict[str, str]]) -> List[Dict[str, str]]: + top_level = [ + {"name": str(row.get("name", "")), "dtype": _normalize_root_dtype(str(row.get("dtype", "")))} + for row in schema_rows + ] + return sorted(top_level, key=lambda col: col["name"]) + + +async def register_dataframe( + result: List[Any], + origin_manager: str, + origin_action: str, + json_size_chars: int, + flatten: bool = True, +) -> Dict[str, Any]: + dataframe = build_dataframe_from_result(result) + return await _register_dataframe_instance( + dataframe, origin_manager, origin_action, json_size_chars, flatten=flatten + ) + + +async def _register_dataframe_instance( + dataframe: pl.DataFrame, + origin_manager: str, + origin_action: str, + json_size_chars: int, + flatten: bool = True, +) -> Dict[str, Any]: + if flatten: + try: + dataframe = auto_flatten_wide(dataframe) + except Exception: + pass # Keep original dataframe if flattening fails + dataframe_id = str(uuid.uuid4()) + table_name = f"df_{dataframe_id.replace('-', '_')}" + record = DataFrameRecord( + dataframe_id=dataframe_id, + table_name=table_name, + created_at=datetime.utcnow().isoformat(), + origin_manager=origin_manager, + origin_action=origin_action, + rows=dataframe.height, + columns=dataframe.width, + schema=(schema_rows := _to_schema_rows(dataframe)), + schema_hash=_schema_hash(schema_rows), + json_size_chars=json_size_chars, + dataframe=dataframe, + ) + async with _write_lock: + _sql_context.register(table_name, dataframe) + _dataframes[dataframe_id] = record + return record.to_metadata() + + +async def materialize_large_result_if_needed( + base_result: BaseResult, + origin_manager: str, + origin_action: str, + force: bool = False +) -> BaseResult: + if not isinstance(base_result, BaseResult) or base_result.error or base_result.result is None: + return base_result + if ( + isinstance(base_result.result, list) + and len(base_result.result) == 1 + and isinstance(base_result.result[0], dict) + and base_result.result[0].get("stored_as_dataframe") is True + and base_result.result[0].get("dataframe_id") + ): + # Avoid rematerializing a payload that is already a dataframe reference. + return base_result + try: + compact_json = serialize_result_to_compact_json(base_result.result) + json_size_chars = len(compact_json) + except Exception as exc: + base_result.append_warnings( + [f"Result size check failed, skipping dataframe materialization: {exc}"] + ) + return base_result + + if not force and json_size_chars <= DATAFRAME_JSON_SIZE_THRESHOLD: + return base_result + + try: + dataframe_preview = build_dataframe_from_result(base_result.result) + except Exception as exc: + base_result.append_warnings( + [f"Result dataframe preview failed, skipping dataframe materialization: {exc}"] + ) + return base_result + + if dataframe_preview.height == 0: + base_result.append_info([ + "Result contains no rows; dataframe was not created." + ]) + return base_result + + metadata = await _register_dataframe_instance( + dataframe=dataframe_preview, + origin_manager=origin_manager, + origin_action=origin_action, + json_size_chars=json_size_chars, + ) + base_result.result = [{ + "stored_as_dataframe": True, + "dataframe_id": metadata["dataframe_id"], + "table_name": metadata["table_name"], + "rows": metadata["rows"], + "columns": metadata["columns"], + "schema_hash": metadata["schema_hash"], + "json_size_chars": metadata["json_size_chars"], + }] + base_result.append_info([ + "Large result was stored as an in-memory dataframe. Use blazemeter_tools with action " + "'dataframes_list'/'dataframes_get' to inspect metadata and 'dataframes_query' to read data with SQL.", + "ORDER BY + LIMIT + OFFSET are mandatory in every dataframe query.", + "Use a prudent default page size of up to 100 rows (for example, LIMIT 100 OFFSET 0), then continue paging as needed.", + "When the dataframe is no longer needed, free resources with 'dataframes_remove' or 'dataframes_clear'.", + ]) + return base_result + + +async def list_dataframes_metadata(include_schema: bool = False) -> List[Dict[str, Any]]: + async with _write_lock: + return [record.to_metadata(include_schema=include_schema) for record in _dataframes.values()] + + +async def get_dataframe_metadata(dataframe_id: str, include_schema: bool = True) -> Optional[Dict[str, Any]]: + async with _write_lock: + record = _dataframes.get(dataframe_id) + if not record: + return None + return record.to_metadata(include_schema=include_schema) + + +async def group_dataframe_schemas(dataframe_id_list: Optional[List[str]] = None) -> Dict[str, Any]: + async with _write_lock: + if dataframe_id_list: + requested = [str(df_id) for df_id in dataframe_id_list] + selected = [record for df_id in requested if (record := _dataframes.get(df_id))] + missing = [df_id for df_id in requested if df_id not in _dataframes] + else: + selected = list(_dataframes.values()) + missing = [] + top_groups: Dict[str, Dict[str, Any]] = {} + for record in selected: + top_schema = _canonicalize_top_schema(record.schema) + top_signature = json.dumps(top_schema, separators=(",", ":"), ensure_ascii=False) + top_hash = _stable_hash(top_signature) + group = top_groups.setdefault( + top_hash, + { + "dataframes": [], + "_columns": {}, + }, + ) + group["dataframes"].append( + record.dataframe_id + ) + + schema_by_name = {str(col.get("name", "")): str(col.get("dtype", "")) for col in record.schema} + for top_col in top_schema: + column_name = top_col["name"] + full_dtype = schema_by_name.get(column_name, "__MISSING__") + schema_preview = full_dtype + version_signature = json.dumps({"dtype": full_dtype}, separators=(",", ":"), ensure_ascii=False) + column_hash = _stable_hash(version_signature) + + column_group = group["_columns"].setdefault( + column_name, + { + "name": column_name, + "_versions": {}, + }, + ) + version_group = column_group["_versions"].setdefault( + column_hash, + { + "hash": column_hash, + "column_schema": schema_preview, + "dataframes": [], + }, + ) + version_group["dataframes"].append( + record.dataframe_id + ) + + top_level_groups = [] + df_sets: Dict[str, str] = {} + dataframe_set_index: Dict[tuple[str, ...], str] = {} + next_df_set_id = 1 + + def _register_dataframe_set(ids: List[str]) -> str: + nonlocal next_df_set_id + normalized = tuple(sorted(set(ids))) + if normalized in dataframe_set_index: + return dataframe_set_index[normalized] + set_id = str(next_df_set_id) + next_df_set_id += 1 + dataframe_set_index[normalized] = set_id + df_sets[set_id] = ",".join(normalized) + return set_id + + for top_hash in sorted(top_groups.keys()): + group = top_groups[top_hash] + columns = [] + varying_columns: List[str] = [] + for column_name in sorted(group["_columns"].keys()): + column_group = group["_columns"][column_name] + variations = [] + for column_hash in sorted(column_group["_versions"].keys()): + version = dict(column_group["_versions"][column_hash]) + version["df_ref"] = _register_dataframe_set(version.pop("dataframes")) + version.pop("hash", None) + variations.append(version) + if len(variations) > 1: + varying_columns.append(column_group["name"]) + columns.append( + { + "name": column_group["name"], + "variations": variations, + } + ) + top_level_groups.append( + { + "df_ref": _register_dataframe_set(group["dataframes"]), + "varying_columns": ",".join(varying_columns), + "columns": columns, + } + ) + + return { + "groups": top_level_groups, + "df_sets": df_sets, + "missing_df_ids": ",".join(missing), + } + + +def _sanitize_sql_for_keyword_scan(sql: str) -> str: + without_comments = _SQL_BLOCK_COMMENT_PATTERN.sub(" ", _SQL_LINE_COMMENT_PATTERN.sub(" ", sql)) + # Remove string and quoted identifier contents to avoid false positives, e.g. SELECT 'delete' + sanitized = re.sub(r"'(?:''|[^'])*'", "''", without_comments) + sanitized = re.sub(r'"(?:""|[^"])*"', '""', sanitized) + sanitized = re.sub(r"`(?:``|[^`])*`", "``", sanitized) + return sanitized + + +def _validate_sql_read_only(sql: str) -> Optional[str]: + without_comments = _LEADING_SQL_COMMENTS_PATTERN.sub("", sql or "") + lowered = without_comments.strip().lower() + if not lowered: + return "SQL query is empty. Provide a SELECT query." + if not (lowered.startswith("select") or lowered.startswith("with")): + return ( + "Only read-only SQL is allowed. Start queries with SELECT or WITH (CTE + SELECT). " + "ORDER BY + LIMIT + OFFSET are mandatory for all dataframe queries." + ) + sanitized = _sanitize_sql_for_keyword_scan(lowered) + if not _ORDER_BY_PATTERN.search(sanitized): + return ( + "Deterministic pagination required. ORDER BY is mandatory for dataframe queries. " + "Use ORDER BY + LIMIT + OFFSET in every query, for example: " + "SELECT * FROM df_x ORDER BY created_at DESC LIMIT 100 OFFSET 0." + ) + if not _LIMIT_PATTERN.search(sanitized): + return ( + "Deterministic pagination required. LIMIT is mandatory for dataframe queries. " + "Use ORDER BY + LIMIT + OFFSET in every query, for example: " + "SELECT * FROM df_x ORDER BY created_at DESC LIMIT 100 OFFSET 0." + ) + if not _OFFSET_PATTERN.search(sanitized): + return ( + "Deterministic pagination required. OFFSET is mandatory for dataframe queries. " + "Use ORDER BY + LIMIT + OFFSET in every query, for example: " + "SELECT * FROM df_x ORDER BY created_at DESC LIMIT 100 OFFSET 0." + ) + disallowed = _DISALLOWED_SQL_PATTERN.search(sanitized) + if disallowed: + return ( + f"SQL statement '{disallowed.group(1).upper()}' is not allowed in dataframe queries. " + "Allowed entry points are SELECT and WITH." + ) + return None + + +def query_dataframes(sql: str, output_format: str = "matrix") -> Dict[str, Any]: + sql_error = _validate_sql_read_only(sql) + if sql_error: + return {"error": sql_error} + normalized_output_format = str(output_format or "matrix").strip().lower() + if normalized_output_format not in {"matrix", "columnar", "records"}: + return {"error": "Invalid output_format. Allowed values: matrix, columnar, records."} + try: + query_result = _sql_context.execute(sql) + if hasattr(query_result, "collect"): + query_result = query_result.collect() + if not isinstance(query_result, pl.DataFrame): + query_result = pl.DataFrame(query_result) + if normalized_output_format == "columnar": + result_payload = [query_result.to_dict(as_series=False)] + elif normalized_output_format == "records": + result_payload = query_result.to_dicts() + else: + result_payload = [{ + "columns": query_result.columns, + "rows": [list(row) for row in query_result.rows()], + }] + return { + "result": result_payload, + "rows": query_result.height, + "columns": query_result.width, + "schema": _to_schema_rows(query_result), + "output_format": normalized_output_format, + } + except Exception as exc: + error_text = str(exc) + lowered = error_text.lower() + if "not found" in lowered and ("table" in lowered or "relation" in lowered): + guidance = ( + "Table not found in SQL context. Use dataframes_list to discover available table_name values, " + "then retry your query." + ) + elif "column" in lowered and "not found" in lowered: + guidance = ( + "Column not found. Use dataframes_get to inspect schema and exact column names before querying." + ) + elif "syntax" in lowered or "parse" in lowered: + guidance = ( + "SQL syntax error. Use dataframes_sql_help for allowed SQL operations and examples." + ) + else: + guidance = ( + "Use dataframes_sql_help first for supported SQL semantics. " + "For multi-dataframe queries, run dataframes_schema_groups before broad schema inspection. " + "Use dataframes_get selectively for outliers or ambiguous fields." + ) + return { + "error": ( + f"SQL query failed: {exc}. {guidance}" + ) + } + + +async def remove_dataframe(dataframe_id: str) -> bool: + async with _write_lock: + record = _dataframes.pop(dataframe_id, None) + if not record: + return False + try: + _sql_context.unregister(record.table_name) + except Exception: + pass + return True + + +async def clear_dataframes() -> int: + async with _write_lock: + ids = list(_dataframes.keys()) + for dataframe_id in ids: + record = _dataframes.pop(dataframe_id, None) + if not record: + continue + try: + _sql_context.unregister(record.table_name) + except Exception: + pass + return len(ids) + + +def get_sql_capabilities() -> Dict[str, Any]: + return { + "engine_scope": { + "query_entrypoints": ["SELECT", "WITH"], + "mode": "read-only", + "description": "SQL support is available for SELECT/WITH queries with BlazeMeter MCP query constraints.", + }, + "allowed_entrypoints": ["SELECT", "WITH"], + "disallowed_statements": [ + "INSERT", + "UPDATE", + "DELETE", + "CREATE", + "DROP", + "ALTER", + "TRUNCATE", + "REPLACE", + "MERGE", + "CALL", + "COPY", + "GRANT", + "REVOKE", + ], + "allowed_features": [ + "JOIN", + "UNION", + "UNION ALL", + "CTE (WITH)", + "GROUP BY", + "HAVING", + "ORDER BY", + "LIMIT", + "OFFSET", + "aggregations", + "UNNEST", + ], + "supported_functions": [ + "ABS", "ACOS", "ACOSD", "ARRAY_CONTAINS", "ARRAY_GET", "ARRAY_LENGTH", "ARRAY_LOWER", "ARRAY_MEAN", + "ARRAY_REVERSE", "ARRAY_SUM", "ARRAY_TO_STRING", "ARRAY_UNIQUE", "ARRAY_UPPER", "ASIN", "ASIND", + "ATAN", "ATAN2", "ATAN2D", "ATAND", "AVG", "BIT_LENGTH", "CBRT", "CEIL", "COALESCE", "CONCAT", + "CONCAT_WS", "COS", "COSD", "COT", "COTD", "COUNT", "DATE", "DATE_PART", "DEGREES", "ENDS_WITH", + "EXP", "EXTRACT", "FIRST", "FLOOR", "GREATEST", "IF", "IFNULL", "INITCAP", "LAST", "LEAST", "LEFT", + "LENGTH", "LN", "LOG", "LOG1P", "LOG10", "LOG2", "LOWER", "LTRIM", "MAX", "MEDIAN", "MIN", "MOD", + "NULLIF", "OCTET_LENGTH", "PI", "POW", "RADIANS", "REGEXP_LIKE", "REPLACE", "REVERSE", "RIGHT", + "ROUND", "RTRIM", "SIGN", "SIN", "SIND", "SQRT", "STARTS_WITH", "STDDEV", "STRPOS", "SUBSTRING", + "SUM", "TAN", "TAND", "UNNEST", "UPPER", "VARIANCE" + ], + "unsupported_functions": [ + {"name": "STRUCT_EXTRACT", "reason": "Not recognized in this SQL context."}, + {"name": "TO_JSON", "reason": "Not recognized in this SQL context."}, + ], + "unsupported_or_unstable_patterns": [ + "Complex chained nested access with mixed subscript and dot notation in a single expression", + "Casting LIST/STRUCT directly to STRING for inspection", + "Nested extraction without staged CTE when list expansion is required", + "Unqualified join keys that create ambiguous column references", + ], + "ai_common_mistakes": [ + "Assuming generic warehouse helper functions are available", + "Building one very large query instead of staged CTEs", + "Skipping aliases in JOIN/CTE steps", + "Trying direct list aggregations (for example list_max on nested overrides) instead of UNNEST + staged CTE", + "Trying unsupported helper functions before checking supported_functions", + "Assuming nested/scalar fields are homogeneous across dataframes without checking schemas first", + "Inspecting every dataframe with dataframes_get before checking grouped schema differences", + "Using direct nested extraction in the first multi-dataframe query after schema groups reports column variations", + "Assuming single dataframe justifies bypassing the robust UNNEST/CTE pattern for nested/list fields", + "Trying the 'fast' direct nested access first when the query touches nested/list fields", + "Try-fast: attempting the simplest path first and retrying on failure instead of reasoning through the design before executing", + "Not considering all values in a nested list when searching for max/min, which can miss important extreme values", + "Using only the first element of a nested list instead of aggregating over all its values", + ], + "query_rules": [ + "CRITICAL: Before writing queries that combine 2 or more dataframes, run dataframes_schema_groups first to validate schema compatibility across all involved dataframes.", + "CRITICAL: Use dataframes_get only for targeted drill-down on dataframes flagged by schema groups as different or ambiguous for required fields.", + "CRITICAL: Hard gate: if schema groups reports column variations, direct nested extraction is forbidden in the first query.", + "CRITICAL: If the query touches nested/list fields, direct nested access is forbidden. Always use the robust pattern: UNNEST -> aggregate -> join-back in CTEs. No exception for single dataframe.", + "IMPORTANT: Validate schema compatibility before using nested fields.", + "ORDER BY + LIMIT + OFFSET are mandatory in every dataframe query.", + "Use deterministic pagination: ORDER BY + LIMIT + OFFSET.", + "Recommended default page size: LIMIT 100 OFFSET 0, then continue paging.", + "If loading data with result_format=dataframe, prefer one initial fetch with the maximum allowed tool limit, then paginate/filter in dataframes_query.", + "CRITICAL: When a query includes UNNEST + CTE + JOIN, always enforce explicit join-key renaming and qualification. Rename the base key in the UNNEST CTE (e.g. test_id AS base_test_id) and use only that renamed key downstream.", + "For CTE-heavy joins, rename join keys in intermediate CTEs (for example test_id AS t_id or base_test_id).", + "Single dataframe query flow (scalar-only): dataframes_sql_help -> dataframes_get -> dataframes_query.", + "Single dataframe query flow (nested/list fields): dataframes_sql_help -> dataframes_get -> staged CTE (UNNEST -> aggregate -> join-back) -> dataframes_query. Same robust pattern as multi-dataframe.", + "Multi-dataframe nested flow: dataframes_sql_help -> dataframes_schema_groups -> targeted dataframes_get -> staged CTE (UNNEST -> aggregate -> join-back) -> final query.", + "If schema groups returns a CRITICAL variation warning, call dataframes_sql_help again immediately before writing the final query.", + "Direct nested access is allowed only when each required nested column has exactly one variation across all relevant dataframes in schema groups.", + ], + "nested_unnest_intro": ( + "To query and aggregate data from a list of structs (e.g., override_executions), use UNNEST in a CTE to flatten the list, " + "then aggregate and compare with scalar fields using GREATEST/LEAST. See query_examples.good for the compact pattern." + ), + "nested_list_pre_sql_checklist": [ + "Step 1: Identify if the query touches nested/list (List, Struct, Array in schema). Step 2: If yes, confirm robust pattern. Step 3: Design the CTE structure. Step 4: Execute. Do not skip to Step 4.", + "Before launching SQL that touches nested/list fields, explicitly confirm: 'There are nested/list fields; I use the robust UNNEST -> aggregate -> join-back pattern.'", + "Do not attempt the 'fast' direct nested extraction first. Start with the robust CTE pattern.", + "Single dataframe is NOT an exception: use the same robust pattern when querying nested/list columns.", + "Anti-ambiguity checklist (UNNEST+CTE+JOIN): No unqualified key columns in SELECT, JOIN, GROUP BY, or ORDER BY; UNNEST CTE key is renamed (base_* or src_*); join-back uses different left/right key names; final projection is scalar-only; query ends with ORDER BY ... LIMIT ... OFFSET.", + ], + "pre_execution_reasoning": [ + "Before dataframes_query: reason step-by-step. (1) Schema check: what columns and types? (2) Nested/list? If List, Struct, Array → robust pattern. (3) Pattern selection: scalar-only vs UNNEST/CTE. (4) Design the query structure. (5) Confirm, then execute.", + "Do not try-fast. Design before do.", + ], + "recommended_patterns": [ + "Prefer one final aggregation query over multiple partial queries when feasible", + "Build queries incrementally: base SELECT -> UNNEST CTE -> aggregate -> join -> final sort/page", + "Use a dedicated CTE for UNNEST operations on nested arrays/lists", + "If a nested field fails, use the robust pattern: UNNEST -> aggregate -> join-back.", + "First nested-field query must use the robust UNNEST/CTE pattern. Never try direct nested access first. No exception for single dataframe.", + "Alias every table and CTE explicitly", + "Rename join keys in CTEs (for example: t_id, base_test_id, src_test_id) before joins to avoid ambiguous references", + "Join-key hygiene for UNNEST+CTE+JOIN: (1) In UNNEST CTE rename base key immediately (test_id AS base_test_id). (2) In downstream CTEs use only the renamed key (GROUP BY base_test_id). (3) In join-back use fully-qualified names (ON b.test_id = a.base_test_id). (4) In final SELECT prefix columns with table alias. (5) Never reuse generic key names across CTE boundaries.", + "Use COALESCE/CASE for fallback values after joins", + "UNION ALL only scalar projections; avoid UNION over nested struct/list columns", + "Validate each CTE with a small LIMIT before composing final query", + "For multi-dataframe analysis, use schema groups first, then perform targeted per-dataframe inspection only when needed.", + "To get the maximum value between a scalar field and all values in a nested list per record, use UNNEST on the list, then GROUP BY and GREATEST(MAX(list.field), MAX(scalar)).", + ], + "known_engine_pitfalls": [ + "CTE + JOIN resolution may treat same-name keys as ambiguous even when aliases are present; rename join keys in the UNNEST stage (base_*/src_*) to guarantee deterministic resolution", + "Alias/join-key resolution may fail in some CTE + JOIN combinations", + "Ambiguous join keys are common if columns are not fully qualified", + "Nested schema drift across tables can break field resolution", + "UNION over nested struct/list columns is fragile; normalize to scalar output first", + "Direct list aggregation over nested overrides is brittle; UNNEST + MAX + join-back is more reliable", + ], + "nested_query_recipe": [ + "Base table CTE", + "UNNEST CTE", + "Aggregate CTE (for example MAX over nested field)", + "Join aggregate back to base", + "Apply null-safe metric expression (for example GREATEST(COALESCE(default1,0), COALESCE(default2,0), COALESCE(override_max,0)))", + "Emit scalar projection only", + "UNION ALL scalar projections only", + "Final ORDER BY + LIMIT + OFFSET", + ], + "debug_ladder": [ + "A) Run schema groups for all candidate dataframes and identify only the outliers to inspect with dataframes_get.", + "B) base SELECT LIMIT 10", + "C) UNNEST stage LIMIT 10", + "D) aggregate stage LIMIT 10", + "E) join result LIMIT 10", + "F) add ranking and pagination", + "G) add next table to UNION ALL and repeat", + ], + "query_examples": { + "good": [ + "SELECT * FROM df_tests ORDER BY test_id LIMIT 100 OFFSET 0", + "WITH expanded AS (SELECT t.test_id, UNNEST(t.override_executions) AS ov FROM df_tests t), " + "agg AS (SELECT e.test_id, MAX(e.ov.concurrency) AS max_concurrency FROM expanded e GROUP BY e.test_id) " + "SELECT t.test_id, t.test_name, " + "GREATEST(COALESCE(t.configuration.threads, 0), COALESCE(a.max_concurrency, 0)) AS max_concurrency_used " + "FROM df_tests t LEFT JOIN agg a ON t.test_id = a.test_id " + "ORDER BY max_concurrency_used DESC, t.test_id ASC LIMIT 10 OFFSET 0", + "WITH a AS (SELECT test_id, test_name FROM df_a), b AS (SELECT test_id, test_name FROM df_b) " + "SELECT * FROM a UNION ALL SELECT * FROM b ORDER BY test_id LIMIT 100 OFFSET 0", + "WITH s1_exp AS (SELECT t.test_id AS t_id, UNNEST(t.override_executions) AS ov FROM df_a t), " + "s1_agg AS (SELECT e.t_id, MAX(e.ov.concurrency) AS ov_max FROM s1_exp e GROUP BY e.t_id), " + "s1 AS (SELECT t.test_id, t.test_name, GREATEST(COALESCE(t.configuration.threads,0), COALESCE(a.ov_max,0)) " + "AS max_concurrency_used FROM df_a t LEFT JOIN s1_agg a ON t.test_id = a.t_id), " + "s2_exp AS (SELECT t.test_id AS t_id, UNNEST(t.override_executions) AS ov FROM df_b t), " + "s2_agg AS (SELECT e.t_id, MAX(e.ov.concurrency) AS ov_max FROM s2_exp e GROUP BY e.t_id), " + "s2 AS (SELECT t.test_id, t.test_name, GREATEST(COALESCE(t.configuration.threads,0), COALESCE(a.ov_max,0)) " + "AS max_concurrency_used FROM df_b t LEFT JOIN s2_agg a ON t.test_id = a.t_id), " + "all_rows AS (SELECT test_id, test_name, max_concurrency_used FROM s1 UNION ALL " + "SELECT test_id, test_name, max_concurrency_used FROM s2) " + "SELECT test_name, test_id, max_concurrency_used FROM all_rows " + "ORDER BY max_concurrency_used DESC, test_id ASC LIMIT 10 OFFSET 0", + "WITH expanded AS (SELECT t.test_id, t.test_name, t.configuration.threads AS threads, UNNEST(t.override_executions) AS ov FROM df_tests t), " + "agg AS (SELECT test_id, test_name, GREATEST(COALESCE(MAX(ov.concurrency), 0), COALESCE(MAX(threads), 0)) AS max_concurrency FROM expanded GROUP BY test_id, test_name) " + "SELECT test_id, test_name, max_concurrency FROM agg ORDER BY max_concurrency DESC, test_id ASC LIMIT 10 OFFSET 0", + ], + "bad": [ + "SELECT * FROM df_tests", + "SELECT * FROM df_tests WHERE status = 'ERROR'", + "SELECT test_id, MAX(UNNEST(override_executions).concurrency) FROM df_tests GROUP BY test_id", + "SELECT * FROM df_a JOIN df_b ON test_id = test_id ORDER BY test_id LIMIT 100 OFFSET 0", + "SELECT * FROM df_a UNION ALL SELECT * FROM df_b ORDER BY test_id LIMIT 100 OFFSET 0", + "SELECT TO_JSON(configuration) FROM df_tests ORDER BY test_id LIMIT 10 OFFSET 0", + ], + }, + "troubleshooting_hints": [ + "If you get ambiguous column errors, alias every table/CTE and qualify join keys.", + "If nested field access fails, move expansion into a dedicated CTE using UNNEST.", + "If a query is too complex, split it into 2-4 CTE stages and validate each stage independently.", + "If an inferred function fails, verify against supported_functions and unsupported_functions.", + "If aggregation results over nested lists do not reflect expected values, check that you are using UNNEST and aggregation (MAX, MIN, etc.) correctly, and that you compare against the scalar field with GREATEST/LEAST.", + ], + "notes": [ + "All in-memory dataframe tables are queryable in the same SQL context.", + "Function name typo seen in some sources: STRPOST; use STRPOS.", + "This help defines practical usage constraints for BlazeMeter MCP SQL queries.", + ], + "references": [ + "https://docs.pola.rs/py-polars/html/reference/sql/index.html", + "https://docs.pola.rs/py-polars/html/reference/sql/functions/index.html", + "https://docs.pola.rs/py-polars/html/reference/sql/clauses.html", + "https://docs.pola.rs/py-polars/html/reference/sql/table_operations.html", + "https://docs.pola.rs/py-polars/html/reference/sql/set_operations.html" + ], + } + diff --git a/tools/execution_manager.py b/tools/execution_manager.py index d48cbbe..33c15bd 100644 --- a/tools/execution_manager.py +++ b/tools/execution_manager.py @@ -25,7 +25,8 @@ from models.result import BaseResult from tools import bridge, search_utils from tools.report_manager import ReportManager -from tools.utils import api_request, timeout, user_agent, format_sanitized_traceback, require_confirmation, Operations +from tools.utils import api_request, timeout, user_agent, format_sanitized_traceback, require_confirmation, Operations, \ + run_as_task, normalize_action_args, validate_required_args, tool_result class ExecutionManager(Manager): @@ -97,12 +98,9 @@ def _handle_analyzer_http_error(self, e: httpx.HTTPStatusError, execution_id: in except: return BaseResult(error=f"HTTP {status_code}: {e.response.text[:200]}") - @require_confirmation(operation=Operations.CREATE) - async def start(self, test_id: Optional[int], delayed_start_ready: bool = True, + @run_as_task() + async def start(self, test_id: int, delayed_start_ready: bool = True, is_debug_run: bool = False) -> BaseResult: - if not isinstance(test_id, int) or test_id < 1: - return BaseResult(error="Missing or invalid required argument 'test_id'. Expected integer.") - # Check if it's valid or allowed test_result = await bridge.read_test(self.token, self.ctx, test_id) if test_result.error: @@ -119,10 +117,8 @@ async def start(self, test_id: Optional[int], delayed_start_ready: bool = True, json=start_body ) - async def read(self, execution_id: Optional[int]) -> BaseResult: - if not isinstance(execution_id, int) or execution_id < 1: - return BaseResult(error="Missing or invalid required argument 'execution_id'. Expected integer.") - + @run_as_task() + async def read(self, execution_id: int) -> BaseResult: execution_response = await api_request( self.token, "GET", @@ -183,12 +179,8 @@ def _get_execution_status_context() -> str: "When it is archived, it is not possible to read the detailed execution information.\n" ) - async def list(self, test_id: Optional[int], limit: int = 50, offset: int = 0) -> BaseResult: - if not isinstance(test_id, int) or test_id < 1: - return BaseResult(error="Missing or invalid required argument 'test_id'. Expected integer.") - if not isinstance(limit, int) or not isinstance(offset, int): - return BaseResult(error="Invalid arguments 'limit'/'offset'. Expected integers.") - + @run_as_task() + async def list(self, test_id: int, limit: int = 50, offset: int = 0) -> BaseResult: test_result = await bridge.read_test(self.token, self.ctx, test_id) if test_result.error: return test_result @@ -229,10 +221,8 @@ async def search_filter_values(self, account_id: int, filter_names: List[str]) - return await search_utils.test_execution_search_filter_values("master", account_id, self.token, filter_names) - async def ai_analysis(self, execution_id: Optional[int]) -> BaseResult: - if not isinstance(execution_id, int) or execution_id < 1: - return BaseResult(error="Missing or invalid required argument 'execution_id'. Expected integer.") - + @run_as_task() + async def ai_analysis(self, execution_id: int) -> BaseResult: execution_response = await self.read(execution_id) if execution_response.error: return execution_response @@ -442,13 +432,13 @@ def register(mcp, token: Optional[BzmToken]): args(dict): Dictionary with the following required parameters: execution_id (int): The execution ID to get the information. - list: List all executions for a test ID. - args(dict): Dictionary with the following required parameters: - test_id (int): The id of the test to list the execution from - limit (int, default=10, valid=[1 to 50]): The number of test executions to list. - offset (int, default=0): Number of test executions to skip. -- search: Search all executions - args(dict): Dictionary with the following optional filter parameters: - account_id (int, mandatory): The id of the account to use. + args(dict): Dictionary with the following parameters: + test_id (int, required): The id of the test whose executions to list. + limit (int, optional, default=50, valid=[1 to 50 when result_format=auto/raw, 1000 when result_format=dataframe]): Max executions to return. + offset (int, optional, default=0): Number of executions to skip. +- search: Search executions (master reports). Requires an account scope; all other filters are optional. + args(dict): Dictionary with the following parameters: + account_id (int, required): Account id for the search scope. execution_name (str): A case and diacritic insensitive search (ilike) for execution name (also known as report name). workspace_id_list (list[int], values= use first search_filter_values tool with 'workspace_id_list'): The workspace IDs to filter the execution results. time_frame (str, default='latest', values['latest','last24','lastWeek','lastMonth', 'custom']): @@ -465,8 +455,8 @@ def register(mcp, token: Optional[BzmToken]): virtual_users_list (list[dict[str,str]], values = use first search_filter_values tool with 'virtual_users_list'): The number of virtual users filter, operator as key and value as value. Example: [{">=", 5}]. page_index (int, default=1), The current page number. If the result mention has_next_page in true, asks the user if they want to see the next page. - search_filter_values: List the values needed for search filters - args(dict): Dictionary with the following required filter parameters: - account_id (int, mandatory): The id of the account to use. + args(dict): Dictionary with the following required parameters: + account_id (int, required): The id of the account to use. filter_names (list[str], values=['workspace_id_list', 'cloud_provider_name_list', 'created_by_id_list', 'locations_id_list', 'project_id_list', 'duration_list', 'number_of_engines_list', 'virtual_users_list']): The filter name list. - read_summary: get the summary report for a given execution ID. args(dict): Dictionary with the following required parameters: @@ -491,41 +481,87 @@ def register(mcp, token: Optional[BzmToken]): The action will check if the execution is running or finished, then either retrieve existing analysis status or create a new analysis entry. It provides dynamic responses indicating whether the analysis is ready or still processing. Hints: +- Optional result formatting in args: `result_format` = `auto` (default), `dataframe` (force dataframe), `raw` (disable dataframe materialization). - **CRITICAL**: Always follow the action schema exactly. If args are required, include args with exact names/types. """ ) - async def execution(action: str, args: Dict[str, Any], ctx: Context) -> BaseResult: + @tool_result() + async def execution(arguments: Dict[str, Any] = None, ctx: Context = None) -> BaseResult: + action, args = normalize_action_args(arguments) + if not action: + return BaseResult(error="Missing required argument 'action' within tool arguments.") execution_manager = ExecutionManager(token, ctx) report_manager = ReportManager(token, ctx) try: match action: case "start": + if validation_error := validate_required_args(action, args, ["test_id"]): + return validation_error return await execution_manager.start(args.get("test_id")) case "read": + if validation_error := validate_required_args(action, args, ["execution_id"]): + return validation_error return await execution_manager.read(args.get("execution_id")) case "list": + if validation_error := validate_required_args(action, args, ["test_id"]): + return validation_error return await execution_manager.list( args.get("test_id"), args.get("limit", 50), args.get("offset", 0), ) case "search": + if validation_error := validate_required_args(action, args, ["account_id"]): + return validation_error return await execution_manager.search(args) case "search_filter_values": - return await execution_manager.search_filter_values(args.get("account_id"), - args.get("filter_names", [])) + if validation_error := validate_required_args(action, args, ["account_id", "filter_names"]): + return validation_error + filter_names = args.get("filter_names") + if not isinstance(filter_names, list) or len(filter_names) == 0: + return BaseResult( + error=( + f"Missing required args for action '{action}': filter_names must be a non-empty list " + f"within 'args'. Required args: filter_names (list[str], non-empty)." + ) + ) + return await execution_manager.search_filter_values( + args.get("account_id"), filter_names + ) case "read_summary": + if validation_error := validate_required_args(action, args, ["execution_id"]): + return validation_error return await report_manager.read_summary(args.get("execution_id")) case "read_errors": + if validation_error := validate_required_args(action, args, ["execution_id"]): + return validation_error return await report_manager.read_error(args.get("execution_id")) case "read_request_stats": + if validation_error := validate_required_args(action, args, ["execution_id"]): + return validation_error return await report_manager.read_request_stats(args.get("execution_id")) case "read_all_reports": - return await execution_manager.read_all_reports(args.get("execution_id")) + if validation_error := validate_required_args(action, args, ["execution_id"]): + return validation_error + execution_id = args.get("execution_id") + summary_result = await report_manager.read_summary(execution_id) + error_result = await report_manager.read_error(execution_id) + stats_result = await report_manager.read_request_stats(execution_id) + return BaseResult( + result=[{ + "summary": summary_result.result or None, + "error": error_result.result or None, + "request_stats": stats_result.result or None + }] + ) case "read_anomalies_stats": + if validation_error := validate_required_args(action, args, ["execution_id"]): + return validation_error return await report_manager.read_anomalies_stats(args.get("execution_id")) case "ai_analysis": + if validation_error := validate_required_args(action, args, ["execution_id"]): + return validation_error return await execution_manager.ai_analysis(args.get("execution_id")) case _: return BaseResult( diff --git a/tools/help_manager.py b/tools/help_manager.py index 6de307d..9016968 100644 --- a/tools/help_manager.py +++ b/tools/help_manager.py @@ -29,7 +29,8 @@ from models.manager import Manager from models.result import BaseResult from tools.help_utils import convert_js_to_py_dict -from tools.utils import http_request, format_sanitized_traceback +from tools.utils import http_request, format_sanitized_traceback, run_as_task, tool_result, normalize_action_args, \ + execute_batch_calls, validate_required_args, validate_non_empty_str_arg class HelpManager(Manager): @@ -37,7 +38,6 @@ class HelpManager(Manager): help_items_index = {} help_index_nodes = {} help_content_cache = {} - MAX_BATCH_CONCURRENCY = 100 CONTENT_TRUST = "trusted" CONTENT_TRUST_NOTE = ( "Help content is sourced from curated BlazeMeter documentation domains and is trusted by design." @@ -136,6 +136,7 @@ async def fetch_chunk(chunk_url: str): help_tree['root_category'] = help_tree.pop('') # Assign a name to the root category HelpManager.help_tree = help_tree + @run_as_task() async def list_help_categories(self) -> BaseResult: if HelpManager.help_tree is None: await self._load_help_tree() @@ -151,11 +152,8 @@ async def list_help_categories(self) -> BaseResult: info=["A list of subcategories is provided for each category"] ) + @run_as_task() async def list_help_category_content(self, category_id: str, subcategory_id_list: List[str]) -> BaseResult: - if not isinstance(subcategory_id_list, list) or not subcategory_id_list: - return BaseResult( - error="Missing required argument 'subcategory_id_list'. Please provide a non-empty list." - ) if HelpManager.help_tree is None: await self._load_help_tree() results = [] @@ -229,6 +227,7 @@ async def get_help_object(category_id: str, subcategory_id: str, help_id: str) - return help_object + @run_as_task() async def read_help_info(self, category_id: str, subcategory_id: str, help_id_list: List[str]) -> BaseResult: if not isinstance(help_id_list, list) or not help_id_list: return BaseResult( @@ -255,6 +254,25 @@ async def read_help_info(self, category_id: str, subcategory_id: str, help_id_li def register(mcp, token: Optional[BzmToken]): + async def _dispatch_batch_help(batch_calls: Any, ctx: Context): + async def _process_help_batch_call(call: Dict[str, Any]) -> BaseResult: + if not isinstance(call, dict): + return BaseResult(error="Each batch call must be a dict with 'action' and optional 'args'.") + sub_action = call.get("action", "") + raw_sub_args = call.get("args", {}) + sub_args = dict(raw_sub_args) if isinstance(raw_sub_args, dict) else {} + try: + return await help_main({"action": sub_action, "args": sub_args}, ctx) + except httpx.HTTPStatusError: + return BaseResult(error=f"HTTP error in sub-action {sub_action}: {format_sanitized_traceback()}") + except Exception: + return BaseResult(error=f"Error in sub-action {sub_action}: {format_sanitized_traceback()}") + + return await execute_batch_calls( + batch_calls, + _process_help_batch_call, + ) + @mcp.tool( name=f"{TOOLS_PREFIX}_help", description=""" @@ -264,83 +282,84 @@ def register(mcp, token: Optional[BzmToken]): Actions: - list_help_categories: List all category_ids and for each of them list their subcategory_ids. - list_help_category_content: List all help_id list related with a category_id and subcategory_id. + args(dict): Dictionary with the following parameters: + category_id (str, optional, default when omitted: `home`): The category id. + subcategory_id_list (List[str], required): The subcategory id list. +- read_help_info: Read the content of a help_id providing category_id, subcategory_id and help_id_list. args(dict): Dictionary with the following required parameters: - category_id (str): The category id. - subcategory_id_list (List[str]): The subcategory id list. -- read_help_info: Read the content of a help_id providing category_id, subcategory_id and help_id - args(dict): Dictionary with the following required parameters: - category_id (str): The category id. - subcategory_id (str): The sub-category id. - help_id_list (List[str]): The help id list to read. + category_id (str): The category id (non-empty after trim). + subcategory_id (str): The sub-category id (required; may be empty string to resolve as `self` in the help tree). + help_id_list (List[str]): Non-empty list of help ids to read. - batch: Execute multiple actions in one call. args(dict): Dictionary with the following required parameters: batch_calls (List[Dict]): List of Actions dictionaries (excluding the action batch), each with 'action' (str) and 'args' (Dict). Hints: - Always generates the url attributes as a link in markdown format (like command_url). - **CRITICAL**: For multiple actions, always use the 'batch' action. +- **IMPORTANT**: `batch` sub-actions execute directly (no forced task mode); responses are returned inline in this same call. +- Optional result formatting in args: `result_format` = `auto` (default), `dataframe` (force dataframe), `raw` (disable dataframe materialization). - **CRITICAL**: Always follow the action schema exactly. If args are required, include args with exact names/types. """ ) + @tool_result(excluded_actions={"batch"}) async def help_main( - action: str = Field(description="The action id to execute"), - args: Dict[str, Any] = Field(description="Dictionary with parameters", default=None), + arguments: Dict[str, Any] = Field( + description="Tool arguments: action, args, and any action-specific params", default=None), ctx: Context = Field(description="Context object providing access to MCP capabilities") ) -> BaseResult: - if args is None: - args = {} - + action, args = normalize_action_args(arguments) + if not action: + return BaseResult(error="Missing required argument 'action' within tool arguments.") help_manager = HelpManager(token, ctx) try: match action: case "list_help_categories": return await help_manager.list_help_categories() case "list_help_category_content": + if validation_error := validate_required_args(action, args, ["subcategory_id_list"]): + return validation_error + category_raw = args.get("category_id", "home") + category_id = ( + category_raw.strip() + if isinstance(category_raw, str) and category_raw.strip() + else "home" + ) return await help_manager.list_help_category_content( - args.get("category_id", "home"), - args.get("subcategory_id_list") + category_id, + args.get("subcategory_id_list", []), ) case "read_help_info": + if validation_error := validate_required_args( + action, args, ["category_id", "subcategory_id", "help_id_list"] + ): + return validation_error + if err := validate_non_empty_str_arg(action, args, "category_id"): + return err + subcategory_id = args.get("subcategory_id") + help_id_list = args.get("help_id_list") + if not isinstance(subcategory_id, str): + return BaseResult( + error=( + f"Missing required args for action '{action}': subcategory_id must be a string " + f"within 'args'. Required args: subcategory_id." + ) + ) + if not isinstance(help_id_list, list) or len(help_id_list) == 0: + return BaseResult( + error=( + f"Missing required args for action '{action}': help_id_list must be a non-empty list " + f"within 'args'. Required args: help_id_list (list[str], non-empty)." + ) + ) return await help_manager.read_help_info( - args.get("category_id", "home"), - args.get("subcategory_id", ""), - args.get("help_id_list") + str(args.get("category_id")).strip(), subcategory_id, help_id_list ) case "batch": # Make sure this initialization doesn't run in parallel if HelpManager.help_tree is None: await help_manager._load_help_tree() - batch_calls = args.get("batch_calls", []) - if not isinstance(batch_calls, list) or not batch_calls: - return BaseResult( - error="batch_calls must be a non-empty list of dicts with 'action' and 'args'") - - semaphore = asyncio.Semaphore(HelpManager.MAX_BATCH_CONCURRENCY) - - async def process_call(call: Dict[str, Any]) -> BaseResult | List[BaseResult]: - sub_action = call.get("action", "") - sub_args = call.get("args", {}) - async with semaphore: - try: - # Recursively call the help function itself - return await help_main(sub_action, sub_args, ctx) - except httpx.HTTPStatusError: - return BaseResult( - error=f"HTTP error in sub-action {sub_action}: {format_sanitized_traceback()}" - ) - except Exception: - return BaseResult( - error=f"Error in sub-action {sub_action}: {format_sanitized_traceback()}\n{SUPPORT_MESSAGE}") - - # Parallel execution with asyncio.gather - results = await asyncio.gather(*[process_call(call) for call in batch_calls], - return_exceptions=True) - # Handle any exceptions returned - processed_results = [ - r if not isinstance(r, Exception) else BaseResult(error=f"Unhandled exception: {str(r)}") - for r in results - ] - return BaseResult(result=processed_results) + return await _dispatch_batch_help(args.get("batch_calls", []), ctx) case _: return BaseResult( error=f"Action {action} not found in help manager tool" diff --git a/tools/project_manager.py b/tools/project_manager.py index fc0d1af..09fce2c 100644 --- a/tools/project_manager.py +++ b/tools/project_manager.py @@ -24,7 +24,8 @@ from models.manager import Manager from models.result import BaseResult from tools import bridge -from tools.utils import api_request, format_sanitized_traceback +from tools.utils import api_request, format_sanitized_traceback, ttl_cache_method, run_as_task, normalize_action_args, \ + tool_result, validate_required_args class ProjectManager(Manager): @@ -32,10 +33,9 @@ class ProjectManager(Manager): def __init__(self, token: Optional[BzmToken], ctx: Context): super().__init__(token, ctx) - async def read(self, project_id: Optional[int]) -> BaseResult: - if not isinstance(project_id, int) or project_id < 1: - return BaseResult(error="Missing or invalid required argument 'project_id'. Expected integer.") - + @ttl_cache_method(ttl_seconds=30) + @run_as_task() + async def read(self, project_id: int, include_tests_count: bool = True) -> BaseResult: project_result = await api_request( self.token, "GET", @@ -52,15 +52,13 @@ async def read(self, project_id: Optional[int]) -> BaseResult: if workspace_result.error: return workspace_result - # Get the amount of test - project_element.tests_count = await bridge.count_project_tests(self.token, self.ctx, project_id) + if include_tests_count: + # Optional enrichment; can be disabled for fast hierarchy validation paths. + project_element.tests_count = await bridge.count_project_tests(self.token, self.ctx, project_id) return project_result - async def list(self, workspace_id: Optional[int], limit: int = 50, offset: int = 0) -> BaseResult: - if not isinstance(workspace_id, int) or workspace_id < 1: - return BaseResult(error="Missing or invalid required argument 'workspace_id'. Expected integer.") - if not isinstance(limit, int) or not isinstance(offset, int): - return BaseResult(error="Invalid arguments 'limit'/'offset'. Expected integers.") + @run_as_task() + async def list(self, workspace_id: int, limit: int = 50, offset: int = 0) -> BaseResult: # Check if it's valid or allowed workspace_result = await bridge.read_workspace(self.token, self.ctx, workspace_id) @@ -82,6 +80,7 @@ async def list(self, workspace_id: Optional[int], limit: int = 50, offset: int = params=parameters ) + def register(mcp, token: Optional[BzmToken]): @mcp.tool( name=f"{TOOLS_PREFIX}_project", @@ -90,27 +89,38 @@ def register(mcp, token: Optional[BzmToken]): Use this when a user needs to select a project for test allocation. Actions: - read: Read a Project. Obtain information about a particular project. - args(dict): Dictionary with the following required parameters: - project_id (int): The id of the project to get information. + args(dict): Dictionary with the following parameters: + project_id (int, required): The id of the project to get information. - list: List all projects. - args(dict): Dictionary with the following required parameters: - workspace_id (int): The id of the workspace to list projects from. - limit (int, default=10, valid=[1 to 50]): The number of projects to list. - offset (int, default=0): Number of projects to skip. + args(dict): Dictionary with the following parameters: + workspace_id (int, required): The id of the workspace to list projects from. + limit (int, optional, default=50, valid=[1 to 50 when result_format=auto/raw, 1000 when result_format=dataframe]): Max projects to return. + offset (int, optional, default=0): Number of projects to skip. Hints: - For a particular project, go directly to the read action (you don't need account or workspace information). - Reading also allows you to obtain the number of tests the project has without having to use a list to count. +- Optional result formatting in args: `result_format` = `auto` (default), `dataframe` (force dataframe), `raw` (disable dataframe materialization). - **CRITICAL**: Always follow the action schema exactly. If args are required, include args with exact names/types. """ ) - async def project(action: str, args: Dict[str, Any], ctx: Context) -> BaseResult: + @tool_result() + async def project(arguments: Dict[str, Any] = None, ctx: Context = None) -> BaseResult: + action, args = normalize_action_args(arguments) + if not action: + return BaseResult(error="Missing required argument 'action' within tool arguments.") project_manager = ProjectManager(token, ctx) try: match action: case "read": + if validation_error := validate_required_args(action, args, ["project_id"]): + return validation_error return await project_manager.read(args.get("project_id")) case "list": - return await project_manager.list(args.get("workspace_id"), args.get("limit", 10), args.get("offset", 0)) + if validation_error := validate_required_args(action, args, ["workspace_id"]): + return validation_error + limit = args.get("limit", 50) + offset = args.get("offset", 0) + return await project_manager.list(args.get("workspace_id"), limit, offset) case _: return BaseResult( error=f"Action {action} not found in project manager tool" diff --git a/tools/report_manager.py b/tools/report_manager.py index 8ea5161..c9a1c5e 100644 --- a/tools/report_manager.py +++ b/tools/report_manager.py @@ -28,7 +28,7 @@ from models.manager import Manager from models.result import BaseResult from tools import bridge -from tools.utils import api_request +from tools.utils import api_request, run_as_task EXECUTION_ARCHIVED_MSG = ("Execution report is archived. It is not possible to read execution " "information from an archived execution.") @@ -52,6 +52,7 @@ def _evaluate_archived(execution_result: BaseResult) -> bool: return (execution_result.result and len(execution_result.result) > 0 and execution_result.result[0].get("result").archived) + @run_as_task() async def read_summary(self, master_id: int): execution_result = await bridge.read_execution(self.token, self.ctx, master_id) if execution_result.error: @@ -76,14 +77,12 @@ async def read_summary(self, master_id: int): } ) - async def read_error(self, master_id: Optional[int]): + @run_as_task() + async def read_error(self, master_id: int): """ Get error report for a given master_id with formatted, AI-friendly structure. Includes execution metadata and explanatory context about error metrics. """ - if not isinstance(master_id, int) or master_id < 1: - return BaseResult(error="Missing or invalid required argument 'execution_id'. Expected integer.") - # Check if it's valid or allowed execution_result = await bridge.read_execution(self.token, self.ctx, master_id) if execution_result.error: @@ -106,14 +105,12 @@ async def read_error(self, master_id: Optional[int]): } ) - async def read_request_stats(self, master_id: Optional[int]): + @run_as_task() + async def read_request_stats(self, master_id: int): """ Get request statistics report for a given master_id with formatted, AI-friendly structure. Includes execution metadata and explanatory context about metrics per endpoint. """ - if not isinstance(master_id, int) or master_id < 1: - return BaseResult(error="Missing or invalid required argument 'execution_id'. Expected integer.") - # Check if it's valid or allowed execution_result = await bridge.read_execution(self.token, self.ctx, master_id) if execution_result.error: @@ -137,16 +134,14 @@ async def read_request_stats(self, master_id: Optional[int]): } ) - async def read_anomalies_stats(self, master_id: Optional[int]): + @run_as_task() + async def read_anomalies_stats(self, master_id: int): """ Get anomaly statistics for a given master_id (test execution). Returns a structured report: no anomalies, full per-anomaly details when permitted, or statistics_unavailable when the API returns no stats (e.g. account without anomaly access). """ - if not isinstance(master_id, int) or master_id < 1: - return BaseResult(error="Missing or invalid required argument 'execution_id'. Expected integer.") - execution_result = await bridge.read_execution(self.token, self.ctx, master_id) if execution_result.error: return execution_result diff --git a/tools/skills_manager.py b/tools/skills_manager.py index 0b222d1..7f24f33 100644 --- a/tools/skills_manager.py +++ b/tools/skills_manager.py @@ -25,7 +25,15 @@ from config.token import BzmToken from models.manager import Manager from models.result import BaseResult -from tools.utils import format_sanitized_traceback +from tools.utils import ( + format_sanitized_traceback, + run_as_task, + execute_batch_calls, + tool_result, + normalize_action_args, + validate_non_empty_str_arg, + validate_required_args, +) from tools.skills_utils import list_skills, read_skill_definition, read_skill_file, parse_skill_uri, \ is_skill_uri, list_skill_resources_uri @@ -35,7 +43,6 @@ class SkillsManager(Manager): skills = None # Static to share between different instance of SkillsManager - MAX_BATCH_CONCURRENCY = 100 CONTENT_TRUST = "trusted" CONTENT_TRUST_NOTE = ( "Skills content is sourced from curated repository resources and is trusted by design." @@ -44,8 +51,8 @@ class SkillsManager(Manager): def __init__(self, token: Optional[BzmToken], ctx: Context): super().__init__(token, ctx) - @staticmethod - async def list_skills() -> BaseResult: + @run_as_task() + async def list_skills(self) -> BaseResult: errors = [] if SkillsManager.skills is None: skills, errors = list_skills() @@ -58,13 +65,8 @@ async def list_skills() -> BaseResult: error=errors[0] if errors and len(errors) > 0 else None # Only the first error ) - @staticmethod - async def read_skill(skill_name: Optional[str]) -> BaseResult: - if not isinstance(skill_name, str) or not skill_name.strip(): - return BaseResult( - error="Missing required argument 'skill_name'. Please specify a non-empty skill name." - ) - skill_name = skill_name.strip() + @run_as_task() + async def read_skill(self, skill_name: str) -> BaseResult: skill_content, error = read_skill_definition(skill_name) # Trust policy note for future audits: # Skills and their resources are curated project artifacts and considered trusted by design. @@ -79,8 +81,7 @@ async def read_skill(skill_name: Optional[str]) -> BaseResult: error=error ) - @staticmethod - async def read_skill_file_path(skill_name: str, file_path: str) -> BaseResult: + async def read_skill_file_path(self, skill_name: str, file_path: str) -> BaseResult: skill_content, error = read_skill_file(skill_name, file_path) return BaseResult( result=[{ @@ -93,18 +94,9 @@ async def read_skill_file_path(skill_name: str, file_path: str) -> BaseResult: error=error ) - @staticmethod - async def list_skill_resources(skill_name: Optional[str]) -> BaseResult: - if not isinstance(skill_name, str) or not skill_name.strip(): - return BaseResult( - error="Missing required argument 'skill_name'. Please specify a non-empty skill name." - ) - skill_name = skill_name.strip() - try: - skill_resources = list_skill_resources_uri(skill_name) - except ValueError as e: - return BaseResult(error=str(e)) - + @run_as_task() + async def list_skill_resources(self, skill_name: str) -> BaseResult: + skill_resources = list_skill_resources_uri(skill_name) return BaseResult( result=[{ "skill_name": skill_name, @@ -116,13 +108,8 @@ async def list_skill_resources(skill_name: Optional[str]) -> BaseResult: has_more=False, ) - @staticmethod - async def read_skill_resource_uri(skill_uri: Optional[str]) -> BaseResult: - if not isinstance(skill_uri, str) or not skill_uri.strip(): - return BaseResult( - error="Missing required argument 'skill_resource_uri'. Please specify a non-empty skill URI." - ) - skill_uri = skill_uri.strip() + @run_as_task() + async def read_skill_resource_uri(self, skill_uri: str) -> BaseResult: if is_skill_uri(skill_uri): skill_name, file_path = parse_skill_uri(skill_uri) skill_content, error = read_skill_file(skill_name, file_path) @@ -141,14 +128,10 @@ async def read_skill_resource_uri(skill_uri: Optional[str]) -> BaseResult: error=f"Invalid Skill URI: {skill_uri}" ) - @staticmethod - async def read_skill_resource_uri_list(skill_uri_list: Optional[List[str]]) -> BaseResult: - if not isinstance(skill_uri_list, list) or not skill_uri_list: - return BaseResult( - error="Missing required argument 'skill_resource_uri_list'. Please provide a non-empty list of skill URIs." - ) + @run_as_task() + async def read_skill_resource_uri_list(self, skill_uri_list: List[str]) -> BaseResult: results = await asyncio.gather( - *(SkillsManager.read_skill_resource_uri(skill_uri) for skill_uri in skill_uri_list) + *(self.read_skill_resource_uri(skill_uri) for skill_uri in skill_uri_list) ) return BaseResult( result=results, @@ -165,6 +148,25 @@ def universal_skills_handler(skill_name: str, path: str) -> str: return error return content + async def _dispatch_batch_skills(batch_calls: Any, ctx: Context): + async def _process_skills_batch_call(call: Dict[str, Any]) -> BaseResult: + if not isinstance(call, dict): + return BaseResult(error="Each batch call must be a dict with 'action' and optional 'args'.") + sub_action = call.get("action", "") + raw_sub_args = call.get("args", {}) + sub_args = dict(raw_sub_args) if isinstance(raw_sub_args, dict) else {} + try: + return await skills({"action": sub_action, "args": sub_args}, ctx) + except httpx.HTTPStatusError: + return BaseResult(error=f"HTTP error in sub-action {sub_action}: {format_sanitized_traceback()}") + except Exception: + return BaseResult(error=f"Error in sub-action {sub_action}: {format_sanitized_traceback()}") + + return await execute_batch_calls( + batch_calls, + _process_skills_batch_call, + ) + @mcp.tool( name=f"{TOOLS_PREFIX}_skills", description=""" @@ -175,78 +177,78 @@ def universal_skills_handler(skill_name: str, path: str) -> str: - list_skills: List all the Skills available to learn. - read_skill: Read detailed information about a specific skill_name. args(dict): Dictionary with the following required parameters: - skill_name (str): The skill name. + skill_name (str, required, non-empty): The skill name. - list_skill_resources: List all the Skills Resources available to learn. args(dict): Dictionary with the following required parameters: - skill_name (str): The skill name. + skill_name (str, required, non-empty): The skill name. - read_skill_resource_uri: Read file content based on a Skill Resource URI (blazemeter-skill-{skill_name}://{resource_path}). args(dict): Dictionary with the following required parameters: - skill_resource_uri (str): The skill URI. + skill_resource_uri (str, required, non-empty): The skill URI. - read_skill_resource_uri_list: Read file content based on a Skill Resource URI list (['blazemeter-skill-{skill_name}://{resource_path}', ...]). args(dict): Dictionary with the following required parameters: - skill_resource_uri_list (List[str]): The skill URI list. + skill_resource_uri_list (List[str], required, non-empty): The skill URI list. - batch: Execute multiple actions in one call. args(dict): Dictionary with the following required parameters: batch_calls (List[Dict]): List of Actions dictionaries (excluding the action batch), each with 'action' (str) and 'args' (Dict). Hints: - Always generates the url attributes as a link in markdown format (like command_url). - **CRITICAL**: For multiple actions, always use the 'batch' action. +- **IMPORTANT**: `batch` sub-actions execute directly (no forced task mode); responses are returned inline in this same call. +- Optional result formatting in args: `result_format` = `auto` (default), `dataframe` (force dataframe), `raw` (disable dataframe materialization). - **CRITICAL**: Always follow the action schema exactly. If args are required, include args with exact names/types. """ ) + @tool_result(excluded_actions={"batch"}) async def skills( - action: str = Field(description="The action id to execute"), - args: Dict[str, Any] = Field(description="Dictionary with parameters", default=None), + arguments: Dict[str, Any] = Field(description="Tool arguments: action, args, and any action-specific params", default=None), ctx: Context = Field(description="Context object providing access to MCP capabilities") ) -> BaseResult: - if args is None: - args = {} - + action, args = normalize_action_args(arguments) + if not action: + return BaseResult(error="Missing required argument 'action' within tool arguments.") skills_manager = SkillsManager(token, ctx) try: match action: case "list_skills": return await skills_manager.list_skills() case "read_skill": - return await skills_manager.read_skill(args.get("skill_name")) + if validation_error := validate_required_args(action, args, ["skill_name"]): + return validation_error + if err := validate_non_empty_str_arg(action, args, "skill_name"): + return err + return await skills_manager.read_skill(str(args.get("skill_name")).strip()) case "list_skill_resources": - return await skills_manager.list_skill_resources(args.get("skill_name")) + if validation_error := validate_required_args(action, args, ["skill_name"]): + return validation_error + if err := validate_non_empty_str_arg(action, args, "skill_name"): + return err + return await skills_manager.list_skill_resources(str(args.get("skill_name")).strip()) case "read_skill_resource_uri": - return await skills_manager.read_skill_resource_uri(args.get("skill_resource_uri")) + if validation_error := validate_required_args(action, args, ["skill_resource_uri"]): + return validation_error + if err := validate_non_empty_str_arg(action, args, "skill_resource_uri"): + return err + return await skills_manager.read_skill_resource_uri( + str(args.get("skill_resource_uri")).strip() + ) case "read_skill_resource_uri_list": - return await skills_manager.read_skill_resource_uri_list(args.get("skill_resource_uri_list")) - case "batch": - batch_calls = args.get("batch_calls", []) - if not isinstance(batch_calls, list) or not batch_calls: + if validation_error := validate_required_args(action, args, ["skill_resource_uri_list"]): + return validation_error + skill_resource_uri_list = args.get("skill_resource_uri_list", []) + if ( + not isinstance(skill_resource_uri_list, list) + or len(skill_resource_uri_list) == 0 + ): return BaseResult( - error="batch_calls must be a non-empty list of dicts with 'action' and 'args'") - - semaphore = asyncio.Semaphore(SkillsManager.MAX_BATCH_CONCURRENCY) - - async def process_call(call: Dict[str, Any]) -> BaseResult | List[BaseResult]: - sub_action = call.get("action", "") - sub_args = call.get("args", {}) - async with semaphore: - try: - # Recursively call the skills function itself - return await skills(sub_action, sub_args, ctx) - except httpx.HTTPStatusError: - return BaseResult( - error=f"HTTP error in sub-action {sub_action}: {format_sanitized_traceback()}" - ) - except Exception: - return BaseResult( - error=f"Error in sub-action {sub_action}: {format_sanitized_traceback()}\n{SUPPORT_MESSAGE}") - - # Parallel execution with asyncio.gather - results = await asyncio.gather(*[process_call(call) for call in batch_calls], - return_exceptions=True) - # Handle any exceptions returned - processed_results = [ - r if not isinstance(r, Exception) else BaseResult(error=f"Unhandled exception: {str(r)}") - for r in results - ] - return BaseResult(result=processed_results) + error=( + f"Missing required args for action '{action}': skill_resource_uri_list must be a " + f"non-empty list within 'args'. Required args: skill_resource_uri_list (list[str], " + f"non-empty)." + ) + ) + return await skills_manager.read_skill_resource_uri_list(skill_resource_uri_list) + case "batch": + return await _dispatch_batch_skills(args.get("batch_calls", []), ctx) case _: return BaseResult( error=f"Action {action} not found in skills manager tool" diff --git a/tools/test_manager.py b/tools/test_manager.py index 1719d16..c2a1f19 100644 --- a/tools/test_manager.py +++ b/tools/test_manager.py @@ -37,7 +37,8 @@ from models.performance_test import PerformanceTestObject from models.result import BaseResult from tools import bridge, search_utils -from tools.utils import api_request, require_confirmation, Operations, format_sanitized_traceback +from tools.utils import api_request, require_confirmation, Operations, format_sanitized_traceback, run_as_task, \ + normalize_action_args, validate_non_empty_str_arg, validate_required_args, tool_result, ttl_cache_method logger = logging.getLogger(__name__) @@ -49,9 +50,9 @@ def __init__(self, token: Optional[BzmToken], ctx: Context): super().__init__(token, ctx) self.path_mapper = PathMapperFactory.create_strategy() - async def read(self, test_id: Optional[int]) -> BaseResult: - if not isinstance(test_id, int) or test_id < 1: - return BaseResult(error="Missing or invalid required argument 'test_id'. Expected integer.") + @ttl_cache_method(ttl_seconds=30) + @run_as_task() + async def read(self, test_id: int) -> BaseResult: test_result = await api_request( self.token, @@ -70,11 +71,8 @@ async def read(self, test_id: Optional[int]) -> BaseResult: return test_result @require_confirmation(operation=Operations.CREATE) - async def create(self, test_name: Optional[str], project_id: Optional[int]) -> BaseResult: - if not isinstance(test_name, str) or not test_name.strip(): - return BaseResult(error="Missing or invalid required argument 'test_name'. Expected non-empty string.") - if not isinstance(project_id, int) or project_id < 1: - return BaseResult(error="Missing or invalid required argument 'project_id'. Expected integer.") + @run_as_task() + async def create(self, test_name: str, project_id: int) -> BaseResult: # Check if it's valid or allowed project_result = await bridge.read_project(self.token, self.ctx, project_id) @@ -100,10 +98,8 @@ async def create(self, test_name: Optional[str], project_id: Optional[int]) -> B ) @require_confirmation(operation=Operations.DELETE) - async def delete(self, test_id: Optional[int]) -> BaseResult: - if not isinstance(test_id, int) or test_id < 1: - return BaseResult(error="Missing or invalid required argument 'test_id'. Expected integer.") - + @run_as_task() + async def delete(self, test_id: int) -> BaseResult: test_result = await self.read(test_id) if test_result.error: return test_result @@ -170,13 +166,9 @@ def _process_upload_results(upload_results: List[Dict[str, Any]], valid_files: L }) @require_confirmation(operation=Operations.CREATE) - async def upload_assets(self, test_id: Optional[int], file_paths: Optional[List[str]], - main_script: Optional[str] = None) -> Dict[ + @run_as_task() + async def upload_assets(self, test_id: int, file_paths: List[str], main_script: Optional[str] = None) -> Dict[ str, Any]: - if not isinstance(test_id, int) or test_id < 1: - return {"error": "Missing or invalid required argument 'test_id'. Expected integer."} - if not isinstance(file_paths, list) or not file_paths: - return {"error": "Missing or invalid required argument 'file_paths'. Expected non-empty list."} # Check if it's valid or allowed test_data = await self.read(test_id) @@ -325,32 +317,67 @@ def _get_script_type(file_name: str) -> str: return script_types.get(extension, 'unknown') - async def list(self, project_id: Optional[int], limit: int = 50, - offset: int = 0, control_ai_consent: bool = True) -> BaseResult: - if not isinstance(project_id, int) or project_id < 1: - return BaseResult(error="Missing or invalid required argument 'project_id'. Expected integer.") - if not isinstance(limit, int) or not isinstance(offset, int): - return BaseResult(error="Invalid arguments 'limit'/'offset'. Expected integers.") + @run_as_task() + async def list( + self, + project_id_list: List[int], + limit: int = 50, + offset: int = 0, + control_ai_consent: bool = True + ) -> BaseResult: + if not project_id_list or not isinstance(project_id_list, list): + return BaseResult( + error="Missing required args for action 'list': project_id_list must be a non-empty list of project IDs.") + project_ids = [int(pid) for pid in project_id_list if pid is not None] + if not project_ids: + return BaseResult( + error="Missing required args for action 'list': project_id_list must be a non-empty list of project IDs.") + + # Preserve order while deduplicating. + project_ids = list(dict.fromkeys(project_ids)) if control_ai_consent: - # Check if it's valid or allowed - project_result = await bridge.read_project(self.token, self.ctx, project_id) - if project_result.error: - return project_result - - parameters = { - "projectId": project_id, - "limit": limit, - "skip": offset, - "sort[]": "-updated" - } + for pid in project_ids: + project_result = await bridge.read_project(self.token, self.ctx, pid) + if project_result.error: + return project_result + + async def _list_project(pid: int) -> BaseResult: + parameters = { + "projectId": pid, + "limit": limit, + "skip": offset, + "sort[]": "-updated" + } + return await api_request( + self.token, + "GET", + f"{TESTS_ENDPOINT}", + result_formatter=format_tests, + params=parameters + ) - return await api_request( - self.token, - "GET", - f"{TESTS_ENDPOINT}", - result_formatter=format_tests, - params=parameters + if len(project_ids) == 1: + return await _list_project(project_ids[0]) + + responses = await asyncio.gather(*[_list_project(pid) for pid in project_ids]) + for response in responses: + if response.error: + return response + + merged_result = [] + has_more = False + for response in responses: + merged_result.extend(response.result or []) + has_more = has_more or bool(response.has_more) + + return BaseResult( + result=merged_result, + total=len(merged_result), + has_more=has_more, + info=[ + f"Merged tests list from {len(project_ids)} projects into one unified result." + ] ) async def search(self, args: dict[str, Any]) -> BaseResult: @@ -418,6 +445,7 @@ def _normalize_configuration_override(configuration: dict, test_data_override: d return test_data_override @require_confirmation(operation=Operations.UPDATE) + @run_as_task() async def configure(self, performance_test: PerformanceTestObject) -> BaseResult: if not performance_test.is_valid(): raise ValueError("PerformanceTestObject must have a valid test_id") @@ -494,20 +522,21 @@ def register(mcp, token: Optional[BzmToken]): When presenting failure_criteria to the user, use meta.general_labels, meta.rule_field_labels, meta.kpi_labels, and meta.condition_labels for readable text; avoid leading with raw kpi ids or op codes. - create: Create a new test. Do not create a test if the user has not confirmed the location for validation of workspace, project and account. args(dict): Dictionary with the following required parameters: - test_name (str): The required name of the test to create. - project_id (int): The id of the project to list tests from. + test_name (str, required, non-empty): The required name of the test to create. + project_id (int): The id of the project where the new test will be created. - delete: Delete a test. args(dict): Dictionary with the following required parameters: test_id (int): The only required parameter. The id of the test to be deleted. -- list: List all tests. - args(dict): Dictionary with the following required parameters: - project_id (int): The id of the project to list tests from. - limit (int, default=10, valid=[1 to 50]): The number of tests to list. - offset (int, default=0): Number of tests to skip. +- list: List all tests from one or more projects. Results are merged into a single unified payload. + args(dict): Dictionary with required and optional parameters: + project_id_list (list[int], required): List of project IDs to list and merge tests from. Must be non-empty. + project_id (int, optional): Accepted as alias for project_id_list when listing from a single project. Use project_id_list for multiple projects. + limit (int, optional, default=50, valid=[1 to 50 when result_format=auto/raw, 1000 when result_format=dataframe]): The number of tests to list. + offset (int, optional, default=0): Number of tests to skip. Each listed test may include failure_criteria; when describing it to the user, use meta labels like read (see read action). -- search: Search all executions - args(dict): Dictionary with the following optional filter parameters: - account_id (int, mandatory): The id of the account to use. +- search: Search tests (reports / executions metadata as exposed by the tests API). Requires an account scope; all other filters are optional. + args(dict): Dictionary with the following parameters: + account_id (int, required): Account id for the search scope. test_name (str): A case and diacritic insensitive search (ilike) for test name (also known as report name). workspace_id_list (list[int], values= use first search_filter_values tool with 'workspace_id_list'): The workspace IDs to filter the execution results. time_frame (str, default='latest', values['latest','last24','lastWeek','lastMonth', 'custom']): @@ -524,13 +553,13 @@ def register(mcp, token: Optional[BzmToken]): virtual_users_list (list[dict[str,str]], values = use first search_filter_values tool with 'virtual_users_list'): The number of virtual users filter, operator as key and value as value. Example: [{">=", 5}]. page_index (int, default=1), The current page number. If the result mention has_next_page in true, asks the user if they want to see the next page. - search_filter_values: List the values needed for search filters - args(dict): Dictionary with the following required filter parameters: - account_id (int, mandatory): The id of the account to use. + args(dict): Dictionary with the following required parameters: + account_id (int, required): The id of the account to use. filter_names (list[str], values=['workspace_id_list', 'cloud_provider_name_list', 'created_by_id_list', 'locations_id_list', 'project_id_list', 'duration_list', 'number_of_engines_list', 'virtual_users_list']): The filter name list. - configure_load: Configure the load of a test for the given test id. The test id is the only required parameter. The test will be configured based on the following parameters only if user confirms the configuration: args(dict): Dictionary with the following parameters: - test_id (int): The only required parameter. The id of the test to configure. + test_id (int, required): The id of the test to configure. iterations (int, default=1, infinite=-1): The number of iterations to run the test with. Don't use if hold-for is provided. hold-for (str, default=1m): The length of time the test will run at the peak concurrency. Values can be provided in m (minutes) only. Don't use if iterations is provided. concurrency (int, default=20, disable=0, max=500000): The number of concurrent virtual users simulated to run. For example, 20 will set the test to run with 20 concurrent users. To disable it set to 0. @@ -540,17 +569,17 @@ def register(mcp, token: Optional[BzmToken]): - configure_locations: Configure the distribution of a test for given test id. The test id is the only required parameter. The test will be configured based on the following parameters only if user confirms the configuration: args(dict): Dictionary with the following parameters: - test_id (int): The only required parameter. The id of the test to configure. + test_id (int, required): The id of the test to configure. locations (list[str]): List of all locations with their percentage distribution of user load in a key value format "location_id=percent_value". Example: ["us-east4-a=25", "us-east1-b=25", "us-west1-a=25", "us-central1-a=25"] - upload_assets: Upload main script test as well as multiple related assets to a test. Supports .zip, .csv, .jmx, .yaml and other file types. args(dict): Dictionary with the following required parameters: test_id (int): The id of the test to upload assets to. - file_paths (list): List of full file paths to upload. + file_paths (list[str], required, non-empty): List of full file paths to upload. main_script (str, optional): Path to the main script file. If provided, will update test configuration to use this script. - failure_criteria_meta: Read-only catalog: overview (layers), top_level_tool_args, rule_fields, general, general_labels, rule_field_labels, kpis, conditions. Field names align with reading and configuring tests. No BlazeMeter API call. args(dict): Optional; may be empty {}. Unknown keys are ignored. - configure_failure_criteria: Set failure criteria (BlazeMeter configuration.enableFailureCriteria and configuration.plugins.thresholds). Replaces the full rules list for the test. - args(dict): Dictionary with the following parameters: + args(dict): Dictionary with the following parameters (test_id, enabled, and rules are required keys; rules may be an empty list to clear): test_id (int): Required. The test id. enabled (bool): Required. Master switch for the Failure Criteria section (API enableFailureCriteria). rules (list): Required. List of rule objects; use an empty list to clear all rules. Each object may include: @@ -578,43 +607,105 @@ def register(mcp, token: Optional[BzmToken]): criteria_overridden_in_interface (bool, optional): Threshold-block metadata (maps to plugins.thresholds when merging); omit to preserve existing. Reading a test and configuring failure criteria use the same field names; BlazeMeter’s REST JSON is only used in HTTP calls inside the server. Hints: +- Optional result formatting in args: `result_format` = `auto` (default), `dataframe` (force dataframe), `raw` (disable dataframe materialization). - **CRITICAL**: Always follow the action schema exactly. If args are required, include args with exact names/types. - Before configure_failure_criteria, prefer failure_criteria_meta for kpi/condition codes and labels, then read if you must merge with existing rules. - For configure_failure_criteria, call read first and merge client-side if you must keep existing rules; providing rules replaces all criteria rows for that test. """ ) - async def tests(action: str, args: Dict[str, Any], ctx: Context) -> BaseResult: + @tool_result() + async def tests(arguments: Dict[str, Any] = None, ctx: Context = None) -> BaseResult: + action, args = normalize_action_args(arguments) + if not action: + return BaseResult(error="Missing required argument 'action' within tool arguments.") test_manager = TestManager(token, ctx) try: match action: case "read": + if validation_error := validate_required_args(action, args, ["test_id"]): + return validation_error return await test_manager.read(args.get("test_id")) case "create": - return await test_manager.create(args.get("test_name"), args.get("project_id")) + if validation_error := validate_required_args(action, args, ["test_name", "project_id"]): + return validation_error + if err := validate_non_empty_str_arg(action, args, "test_name"): + return err + return await test_manager.create( + str(args.get("test_name")).strip(), args.get("project_id") + ) case "delete": + if validation_error := validate_required_args(action, args, ["test_id"]): + return validation_error return await test_manager.delete(args.get("test_id")) case "list": - return await test_manager.list(args.get("project_id"), args.get("limit", 50), args.get("offset", 0)) + # Accept project_id (singular) as alias for project_id_list when listing from one project + if "project_id_list" not in args and args.get("project_id") is not None: + args = dict(args) + args["project_id_list"] = [args["project_id"]] + if validation_error := validate_required_args(action, args, ["project_id_list"]): + return validation_error + project_id_list = args.get("project_id_list") + if not isinstance(project_id_list, list) or not project_id_list: + return BaseResult( + error="Missing required args for action 'list': project_id_list must be a non-empty list of project IDs within 'args'. " + "Required args: project_id_list (list[int], non-empty)." + ) + return await test_manager.list( + project_id_list, + args.get("limit", 50), + args.get("offset", 0) + ) case "search": + if validation_error := validate_required_args(action, args, ["account_id"]): + return validation_error return await test_manager.search(args) case "search_filter_values": - return await test_manager.search_filter_values(args.get("account_id"), args.get("filter_names", [])) + if validation_error := validate_required_args(action, args, ["account_id", "filter_names"]): + return validation_error + filter_names = args.get("filter_names") + if not isinstance(filter_names, list) or len(filter_names) == 0: + return BaseResult( + error=( + f"Missing required args for action '{action}': filter_names must be a non-empty list " + f"within 'args'. Required args: filter_names (list[str], non-empty)." + ) + ) + return await test_manager.search_filter_values( + args.get("account_id"), filter_names + ) case "configure_load": + if validation_error := validate_required_args(action, args, ["test_id"]): + return validation_error performance_test = PerformanceTestObject.from_args(args) return await test_manager.configure(performance_test) case "configure_locations": + if validation_error := validate_required_args(action, args, ["test_id"]): + return validation_error performance_test = PerformanceTestObject.from_args(args) return await test_manager.configure(performance_test) case "upload_assets": + if validation_error := validate_required_args(action, args, ["test_id", "file_paths"]): + return validation_error + file_paths = args.get("file_paths") + if not isinstance(file_paths, list) or len(file_paths) == 0: + return BaseResult( + error=( + f"Missing required args for action '{action}': file_paths must be a non-empty list " + f"within 'args'. Required args: file_paths (list[str], non-empty)." + ) + ) upload_result = await test_manager.upload_assets( args.get("test_id"), args.get("file_paths"), - args.get("main_script"), - ) + args.get("main_script")) if isinstance(upload_result, dict) and upload_result.get("error"): return BaseResult(error=upload_result["error"]) return BaseResult(result=[upload_result]) case "configure_failure_criteria": + if validation_error := validate_required_args( + action, args, ["test_id", "enabled", "rules"] + ): + return validation_error return await test_manager.configure_failure_criteria(args) case "failure_criteria_meta": return await test_manager.failure_criteria_meta(args) diff --git a/tools/tools_manager.py b/tools/tools_manager.py new file mode 100644 index 0000000..7089ca5 --- /dev/null +++ b/tools/tools_manager.py @@ -0,0 +1,736 @@ +""" +Copyright 2025 Perforce Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import traceback +import asyncio +import time +import re +from typing import Any, Dict, Optional + +import httpx +from mcp.server.fastmcp import Context + +from config.blazemeter import TOOLS_PREFIX, SUPPORT_MESSAGE +from config.token import BzmToken +from models.manager import Manager +from models.result import BaseResult +from tools.async_task_manager import ( + cancel_task, + get_task_record, + is_active_status, + is_terminal_status, + list_tasks, + remove_task, + task_snapshot, +) +from tools.utils import ( + normalize_action_args, + tool_result, + validate_non_empty_str_arg, + validate_required_args, +) +from tools.dataframe_manager import ( + clear_dataframes, + get_dataframe_metadata, + get_sql_capabilities, + group_dataframe_schemas, + list_dataframes_metadata, + query_dataframes, + remove_dataframe, +) + + +class ToolsManager(Manager): + def __init__(self, token: Optional[BzmToken], ctx: Context): + super().__init__(token, ctx) + + @staticmethod + def _should_continue_polling(status: str) -> bool: + return status in {"parking", "working", "input_required"} + + @staticmethod + def _to_snake_case(value: str) -> str: + s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", value) + return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + @classmethod + def _operation_name(cls, action_payload: Dict[str, Any]) -> str: + manager = str(action_payload.get("manager", "tool")) + method = str(action_payload.get("method", "action")) + tool_name = manager[:-7] if manager.endswith("Manager") else manager + tool_name = cls._to_snake_case(tool_name) + return f"{tool_name}.{method}" + + @staticmethod + def _format_operation_value(value: Any) -> str: + if isinstance(value, str): + return repr(value) + return repr(value) + + @classmethod + def _operation_call_line(cls, action_payload: Dict[str, Any]) -> str: + # Kept for backwards compatibility with tests/helpers that may still use this. + return cls._operation_name(action_payload) + + @staticmethod + def _batch_summary_line() -> str: + records = list_tasks() + counts = { + "completed": 0, + "working": 0, + "parking": 0, + "failed": 0, + "cancelled": 0, + "input_required": 0, + } + for record in records: + status = str(record.status).lower() + if status in counts: + counts[status] += 1 + summary = ( + f"batch summary: total={len(records)} completed={counts['completed']} " + f"working={counts['working']} parking={counts['parking']} failed={counts['failed']}" + ) + if counts["cancelled"] > 0: + summary += f" cancelled={counts['cancelled']}" + if counts["input_required"] > 0: + summary += f" input_required={counts['input_required']}" + return summary + + @classmethod + def _task_status_line( + cls, + task_record, + poll_count: Optional[int], + elapsed_seconds: int, + next_poll_seconds: Optional[float], + window_seconds: Optional[float] = None, + include_polling_prefix: bool = True, + ) -> str: + operation = cls._operation_name(task_record.action) + prefix = "Polling " if include_polling_prefix else "" + line = f"{prefix}{task_record.task_id}[{operation}] ({task_record.status})" + if poll_count is not None: + line += f" attempt={poll_count}" + line += f" elapsed={elapsed_seconds}s" + if window_seconds is not None: + line += f"/{int(window_seconds)}s" + if next_poll_seconds is not None and cls._should_continue_polling(task_record.status): + line += f" next={int(next_poll_seconds)}s" + if task_record.status == "parking" and task_record.status_message: + line += f" note={repr(task_record.status_message)}" + return line + + @classmethod + def _polling_message( + cls, + task_record, + poll_count: int, + elapsed_seconds: int, + next_poll_seconds: float, + window_seconds: float, + ) -> str: + line = cls._task_status_line( + task_record=task_record, + poll_count=poll_count, + elapsed_seconds=elapsed_seconds, + next_poll_seconds=next_poll_seconds, + window_seconds=window_seconds, + include_polling_prefix=True, + ) + return f"{line} | {cls._batch_summary_line()}" + + @classmethod + def _polling_finished_message(cls, task_record, elapsed_seconds: int) -> str: + line = cls._task_status_line( + task_record=task_record, + poll_count=None, + elapsed_seconds=elapsed_seconds, + next_poll_seconds=None, + window_seconds=None, + include_polling_prefix=True, + ) + return f"{line} | {cls._batch_summary_line()}" + + async def tasks_get( + self, + task_id: str, + remove_on_terminal: bool = True, + wait_for_terminal_ms: int = 0, + poll_interval_ms: int = 1000 + ) -> BaseResult: + if wait_for_terminal_ms < 0: + return BaseResult(error="wait_for_terminal_ms must be greater than or equal to 0.") + if poll_interval_ms <= 0: + return BaseResult(error="poll_interval_ms must be greater than 0.") + + task_record = get_task_record(task_id) + if not task_record: + return BaseResult(error=f"Task ID {task_id} was not found.") + + polling_exhausted = False + if wait_for_terminal_ms > 0 and not is_terminal_status(task_record.status): + start_time = time.monotonic() + wait_seconds = wait_for_terminal_ms / 1000.0 + poll_seconds = poll_interval_ms / 1000.0 + attempt = 0 + + while True: + task_record = get_task_record(task_id) + if not task_record: + return BaseResult(error=f"Task ID {task_id} was not found.") + if is_terminal_status(task_record.status): + break + + elapsed = time.monotonic() - start_time + if elapsed >= wait_seconds: + polling_exhausted = True + break + + attempt += 1 + progress = min(100.0, (elapsed / wait_seconds) * 100.0) if wait_seconds > 0 else 100.0 + try: + await self.ctx.report_progress( + progress=progress, + total=100.0, + message=self._polling_message( + task_record=task_record, + poll_count=attempt, + elapsed_seconds=int(elapsed), + next_poll_seconds=poll_seconds, + window_seconds=wait_seconds, + ) + ) + except Exception: + pass + + remaining = wait_seconds - elapsed + await asyncio.sleep(min(poll_seconds, remaining)) + + try: + final_elapsed = min(wait_seconds, time.monotonic() - start_time) + await self.ctx.report_progress( + progress=100.0, + total=100.0, + message=self._polling_finished_message( + task_record=task_record, + elapsed_seconds=int(final_elapsed), + ) + ) + except Exception: + pass + + terminal = is_terminal_status(task_record.status) + snapshot = task_snapshot(task_record, include_result=terminal) + snapshot["should_continue_polling"] = self._should_continue_polling(task_record.status) + snapshot["next_poll_after_ms"] = poll_interval_ms if snapshot["should_continue_polling"] else 0 + + if terminal and remove_on_terminal: + remove_task(task_id) + return BaseResult( + result=[snapshot], + info=[ + "Task result retrieved successfully and removed automatically from the task registry. " + "It will not be available in subsequent queries." + ] + ) + + if terminal: + return BaseResult( + result=[snapshot], + info=[ + "Task result retrieved successfully and kept in the task registry. " + "Use tasks_remove to delete it when no longer needed." + ] + ) + + return BaseResult( + result=[snapshot], + info=[ + ( + "Task is still in progress after the polling window. Query tasks_status again in a few moments." + if polling_exhausted + else "Task is still in progress. Query tasks_status again in a few moments to check updated state." + ) + ] + ) + + async def tasks_status( + self, + task_id: str, + wait_for_terminal_ms: int = 0, + poll_interval_ms: int = 1000 + ) -> BaseResult: + if wait_for_terminal_ms < 0: + return BaseResult(error="wait_for_terminal_ms must be greater than or equal to 0.") + if poll_interval_ms <= 0: + return BaseResult(error="poll_interval_ms must be greater than 0.") + + task_record = get_task_record(task_id) + if not task_record: + return BaseResult(error=f"Task ID {task_id} was not found.") + + polling_exhausted = False + if wait_for_terminal_ms > 0 and not is_terminal_status(task_record.status): + start_time = time.monotonic() + wait_seconds = wait_for_terminal_ms / 1000.0 + poll_seconds = poll_interval_ms / 1000.0 + attempt = 0 + + while True: + task_record = get_task_record(task_id) + if not task_record: + return BaseResult(error=f"Task ID {task_id} was not found.") + if is_terminal_status(task_record.status): + break + + elapsed = time.monotonic() - start_time + if elapsed >= wait_seconds: + polling_exhausted = True + break + + attempt += 1 + progress = min(100.0, (elapsed / wait_seconds) * 100.0) if wait_seconds > 0 else 100.0 + try: + await self.ctx.report_progress( + progress=progress, + total=100.0, + message=self._polling_message( + task_record=task_record, + poll_count=attempt, + elapsed_seconds=int(elapsed), + next_poll_seconds=poll_seconds, + window_seconds=wait_seconds, + ) + ) + except Exception: + pass + + remaining = wait_seconds - elapsed + await asyncio.sleep(min(poll_seconds, remaining)) + + try: + final_elapsed = min(wait_seconds, time.monotonic() - start_time) + await self.ctx.report_progress( + progress=100.0, + total=100.0, + message=self._polling_finished_message( + task_record=task_record, + elapsed_seconds=int(final_elapsed), + ) + ) + except Exception: + pass + + snapshot = task_snapshot(task_record, include_result=False) + snapshot["should_continue_polling"] = self._should_continue_polling(task_record.status) + snapshot["next_poll_after_ms"] = poll_interval_ms if snapshot["should_continue_polling"] else 0 + info_message = ( + "Task is terminal. Use tasks_get to retrieve task_result when needed." + if is_terminal_status(task_record.status) + else ( + "Task is still in progress after the polling window. Query tasks_status again in a few moments." + if polling_exhausted + else "Task is still in progress. Query tasks_status again in a few moments to check updated state." + ) + ) + return BaseResult(result=[snapshot], info=[info_message]) + + async def tasks_list(self, status: Optional[str] = None, status_list: Optional[list[str]] = None) -> BaseResult: + filters = status_list if status_list else ([status] if status else None) + records = list_tasks(filters) + snapshots = [] + for record in records: + base_snapshot = task_snapshot(record, include_result=False) + snapshots.append( + { + "task_id": record.task_id, + "operation": self._operation_name(record.action), + "status": record.status, + "status_message": record.status_message, + "created_at": record.created_at, + "created_at_iso": base_snapshot["created_at_iso"], + "last_updated_at": record.last_updated_at, + "last_updated_at_iso": base_snapshot["last_updated_at_iso"], + "started_running_at": record.started_running_at, + "started_running_at_iso": base_snapshot["started_running_at_iso"], + "finished_at": record.finished_at, + "finished_at_iso": base_snapshot["finished_at_iso"], + "time_to_live_ms": record.time_to_live_ms, + } + ) + return BaseResult(result=snapshots, total=len(snapshots), has_more=False) + + async def tasks_cancel(self, task_id: str) -> BaseResult: + task_record = cancel_task(task_id) + if not task_record: + return BaseResult(error=f"Task ID {task_id} was not found.") + return BaseResult( + result=[task_snapshot(task_record, include_result=True)], + info=["Task cancellation was requested successfully."] + ) + + async def tasks_remove(self, task_id: str) -> BaseResult: + task_record = get_task_record(task_id) + if not task_record: + return BaseResult(error=f"Task ID {task_id} was not found.") + + cancel_requested = False + if is_active_status(task_record.status): + cancel_requested = True + cancel_task(task_id) + if task_record.asyncio_task: + try: + await asyncio.wait_for(asyncio.shield(task_record.asyncio_task), timeout=0.2) + except asyncio.TimeoutError: + pass + except asyncio.CancelledError: + pass + + snapshot = task_snapshot(task_record, include_result=is_terminal_status(task_record.status)) + removed = remove_task(task_id) + if not removed: + return BaseResult(error=f"Task ID {task_id} could not be removed.") + + info_message = ( + "Task was active. Cancellation was requested before removal." + if cancel_requested + else "Task was removed from task registry." + ) + return BaseResult( + result=[{ + "task_id": task_id, + "removed": True, + "cancel_requested": cancel_requested, + "task_snapshot": snapshot + }], + info=[info_message] + ) + + async def dataframes_list(self) -> BaseResult: + metadata = await list_dataframes_metadata(include_schema=False) + return BaseResult( + result=metadata, + total=len(metadata), + has_more=False, + info=[ + "Schema is omitted in dataframes_list to reduce payload size.", + "Use dataframes_schema_groups to compare shared/different schemas across dataframes.", + "Use dataframes_get for full metadata and schema of a specific dataframe.", + ], + ) + + async def dataframes_get(self, dataframe_id: str) -> BaseResult: + metadata = await get_dataframe_metadata(dataframe_id) + if not metadata: + return BaseResult( + error=f"Dataframe ID {dataframe_id} was not found. Use dataframes_list to discover available dataframes." + ) + return BaseResult(result=[metadata]) + + async def dataframes_schema_groups(self, dataframe_id_list: Optional[list[str]] = None) -> BaseResult: + grouped = await group_dataframe_schemas(dataframe_id_list) + mandatory_review_groups = [ + grp for grp in grouped.get("groups", []) + if isinstance(grp, dict) and str(grp.get("varying_columns", "")).strip() + ] + info_messages = [ + "Grouped dataframe schemas by top-level contract and per-column schema variations.", + "Dataframe ID lists are deduplicated in 'df_sets'; groups and variations reference them via 'df_ref'.", + "If dataframe_id_list is omitted, all current dataframes are included.", + ] + if mandatory_review_groups: + info_messages.append( + "CRITICAL: Column variations were detected. Perform mandatory detailed schema review for varying columns before the final query." + ) + info_messages.append( + "CRITICAL: Do not try-fast. Reason step-by-step: schema check → pattern selection → design → execute." + ) + info_messages.append( + "IMPORTANT: Before planning and executing the final dataframe query, call dataframes_sql_help synchronously in a separate call." + ) + return BaseResult( + result=[grouped], + info=info_messages, + ) + + async def dataframes_query( + self, + sql: str, + output_format: str = "matrix", + result_format: str = "auto", + ) -> BaseResult: + normalized_result_format = str(result_format or "auto").strip().lower() + effective_output_format = output_format + info_messages = [ + "Query executed successfully against the in-memory SQL context.", + "ORDER BY + LIMIT + OFFSET are mandatory in every dataframe query.", + "Use a prudent default page size of up to 100 rows (for example, LIMIT 100 OFFSET 0), then continue paging as needed.", + "ORDER BY + LIMIT + OFFSET provides deterministic pagination.", + ] + if normalized_result_format == "dataframe": + # For dataframe materialization, we must preserve raw row records. + effective_output_format = "records" + info_messages.append( + "When result_format=dataframe, dataframes_query uses records internally for dataframe storage and ignores output_format only for storage." + ) + query_response = query_dataframes(sql, output_format=effective_output_format) + if query_response.get("error"): + return BaseResult( + error=query_response["error"] + ) + return BaseResult( + result=query_response["result"], + total=query_response["rows"], + has_more=False, + info=info_messages + ) + + async def dataframes_remove(self, dataframe_id_list: list[str]) -> BaseResult: + if not dataframe_id_list or not isinstance(dataframe_id_list, list): + return BaseResult( + error="Missing required args for action 'dataframes_remove': dataframe_id_list must be a non-empty list of dataframe IDs." + ) + ids = [str(df_id).strip() for df_id in dataframe_id_list if str(df_id).strip()] + if not ids: + return BaseResult( + error="Missing required args for action 'dataframes_remove': dataframe_id_list must be a non-empty list of dataframe IDs." + ) + + # Preserve order while deduplicating IDs. + unique_ids = list(dict.fromkeys(ids)) + removed_count = 0 + removed_results = [] + missing_ids = [] + for df_id in unique_ids: + removed = await remove_dataframe(df_id) + removed_results.append({ + "dataframe_id": df_id, + "removed": removed, + }) + if removed: + removed_count += 1 + else: + missing_ids.append(df_id) + + if len(unique_ids) == 1 and removed_count == 0: + only_id = unique_ids[0] + return BaseResult( + error=f"Dataframe ID {only_id} was not found. Use dataframes_list to discover available dataframes." + ) + + info_messages = [ + f"Requested removal for {len(unique_ids)} dataframe(s). Removed: {removed_count}. Missing: {len(missing_ids)}." + ] + if removed_count > 0: + info_messages.append("Removed dataframes were unregistered from SQL context.") + if missing_ids: + info_messages.append( + "Some dataframe IDs were not found: " + ", ".join(missing_ids) + "." + ) + return BaseResult( + result=removed_results, + total=len(removed_results), + has_more=False, + info=info_messages + ) + + async def dataframes_clear(self) -> BaseResult: + removed_count = await clear_dataframes() + return BaseResult( + result=[{ + "removed_count": removed_count, + "remaining": 0 + }], + info=[ + "All in-memory dataframes were removed and unregistered from SQL context." + ] + ) + + async def dataframes_sql_help(self) -> BaseResult: + return BaseResult( + result=[get_sql_capabilities()], + info=[ + "Only read-only SQL is allowed in dataframe queries." + ] + ) + + +def register(mcp, token: Optional[BzmToken]): + @mcp.tool( + name=f"{TOOLS_PREFIX}_tools", + description=""" +Operations for asynchronous task lifecycle management. +Actions: +- tasks_get: Get task metadata by task ID and return task_result when task is terminal. + args(dict): Dictionary with required/optional parameters: + task_id (str, required, non-empty): The task id to query. + remove_on_terminal (bool, default=True): Removes task automatically if status is terminal. + wait_for_terminal_ms (int, default=0): Internal polling window in milliseconds. + poll_interval_ms (int, default=1000): Delay between polling attempts in milliseconds. +- tasks_status: Get lightweight task status by task ID (without task_result payload). + args(dict): Dictionary with required/optional parameters: + task_id (str, required, non-empty): The task id to query. + wait_for_terminal_ms (int, default=0): Internal polling window in milliseconds. + poll_interval_ms (int, default=1000): Delay between polling attempts in milliseconds. +- tasks_list: List tasks currently stored in the task registry. + args(dict): Dictionary with optional parameters: + status (str): Optional single status filter. + status_list (list[str]): Optional list of statuses to filter. +- tasks_cancel: Cancel a running/queued task. + args(dict): Dictionary with required parameters: + task_id (str, required, non-empty): The task id to cancel. +- tasks_remove: Remove a task from registry. + args(dict): Dictionary with required parameters: + task_id (str, required, non-empty): The task id to remove. +- dataframes_list: List all in-memory dataframes and their metadata. +- dataframes_get: Get dataframe metadata and schema by dataframe ID. Use this for detailed inspection of a specific dataframe after schema groups indicates differences or ambiguity. + args(dict): Dictionary with required parameters: + dataframe_id (str, required, non-empty): The dataframe id to query. +- dataframes_schema_groups: Group dataframe schemas hierarchically by top-level contract and per-column schema variants. This is the default first step for schema validation when a query involves 2 or more dataframes. + args(dict): Dictionary with optional parameters: + dataframe_id_list (list[str], optional): Specific dataframe IDs to analyze. If omitted, all dataframes are included. +- dataframes_query: Execute read-only SQL against all in-memory dataframe tables. + args(dict): Dictionary with required parameters: + sql (str, required, non-empty): SQL query string. Supports SELECT and WITH queries. + output_format (str, default="matrix", valid=["matrix", "columnar", "records"]): Query result shape. + result_format (str, optional): If set to "dataframe", query data is stored internally from records format; output_format is ignored only for storage. + precondition: Before planning any dataframe SQL query, call dataframes_sql_help first. + requirement: ORDER BY + LIMIT + OFFSET are mandatory in every query. + recommendation: Use a prudent default page size of up to 100 rows (for example, LIMIT 100 OFFSET 0), then continue paging as needed. +- dataframes_remove: Remove one or more dataframes from memory and SQL context. + args(dict): Dictionary with required parameters: + dataframe_id_list (list[str], required): List of dataframe IDs to remove. Must be non-empty. +- dataframes_clear: Remove all dataframes from memory and SQL context. +- dataframes_sql_help: Describe supported SQL usage and blocked operations. +Hints: +- **CRITICAL**: Always follow the action schema exactly. If args are required, include args with exact names/types. +- **CRITICAL**: Before writing any dataframe SQL query, call `dataframes_sql_help` first. +- **CRITICAL**: If the query involves 2 or more dataframes, call `dataframes_schema_groups` before any broad `dataframes_get` usage. +- **CRITICAL**: Use `dataframes_get` selectively for outliers or ambiguous fields detected by schema groups; do not scan all dataframes by default. +- **CRITICAL**: If the query touches nested/list fields, use the robust UNNEST -> aggregate -> join-back pattern in CTEs. No exception for single dataframe. Do not try direct nested access first. +- **CRITICAL**: Before SQL that touches nested/list: explicitly confirm "there are nested/list fields; I use the robust pattern." +- **CRITICAL**: For dataframe SQL: reason step-by-step before executing. Design your approach (schema check, nested check, pattern choice), verify against rules, then execute. Do not try-fast. +- Use dataframes_schema_groups to compare schema similarities/differences across multiple dataframes without repeating schema payload. +- Single dataframe (scalar-only): dataframes_sql_help -> dataframes_get -> dataframes_query. +- Single dataframe (nested/list): dataframes_sql_help -> dataframes_get -> staged CTE (UNNEST -> aggregate -> join-back) -> dataframes_query. Same robust pattern as multi-dataframe. +- Multi-dataframe query flow: dataframes_sql_help -> dataframes_schema_groups -> targeted dataframes_get -> dataframes_query. +- Multi-dataframe nested flow: dataframes_sql_help -> dataframes_schema_groups -> targeted dataframes_get -> staged CTE (UNNEST -> aggregate -> join-back) -> final query. +- Large tool results may be automatically materialized as in-memory dataframes and returned as references. +- Optional result formatting in args: `result_format` = `auto` (default), `dataframe` (force dataframe), `raw` (disable dataframe materialization). +- If you plan joins, filtering, sorting, grouping, or multi-step analysis, prefer `result_format=dataframe` and use `dataframes_query` instead of merging large inline results in AI context. +- ORDER BY + LIMIT + OFFSET are mandatory in every dataframe query. +- Use a prudent default page size of up to 100 rows (for example, LIMIT 100 OFFSET 0), then continue paging as needed. +- Use ORDER BY + LIMIT + OFFSET for deterministic pagination when reading dataframes. +- All registered dataframe tables are available in the same SQL context, including JOIN and UNION scenarios. +- When a dataframe is no longer needed, release memory using dataframes_remove or dataframes_clear. +- When a task is no longer needed (especially terminal tasks), release it from registry using tasks_remove. +""" + ) + @tool_result(excluded_actions={ + "tasks_list", + "tasks_status", + "dataframes_list", + "dataframes_get", + "dataframes_schema_groups", + "dataframes_query", + "dataframes_remove", + "dataframes_clear", + "dataframes_sql_help", + }) + async def tools(arguments: Dict[str, Any] = None, ctx: Context = None) -> BaseResult: + action, args = normalize_action_args(arguments) + if not action: + return BaseResult(error="Missing required argument 'action' within tool arguments.") + tools_manager = ToolsManager(token, ctx) + try: + match action: + case "tasks_get": + if validation_error := validate_required_args(action, args, ["task_id"]): + return validation_error + if err := validate_non_empty_str_arg(action, args, "task_id"): + return err + return await tools_manager.tasks_get( + args.get("task_id"), + args.get("remove_on_terminal", True), + args.get("wait_for_terminal_ms", 0), + args.get("poll_interval_ms", 1000) + ) + case "tasks_list": + return await tools_manager.tasks_list(args.get("status"), args.get("status_list")) + case "tasks_status": + if validation_error := validate_required_args(action, args, ["task_id"]): + return validation_error + if err := validate_non_empty_str_arg(action, args, "task_id"): + return err + return await tools_manager.tasks_status( + args.get("task_id"), + args.get("wait_for_terminal_ms", 0), + args.get("poll_interval_ms", 1000) + ) + case "tasks_cancel": + if validation_error := validate_required_args(action, args, ["task_id"]): + return validation_error + if err := validate_non_empty_str_arg(action, args, "task_id"): + return err + return await tools_manager.tasks_cancel(args.get("task_id")) + case "tasks_remove": + if validation_error := validate_required_args(action, args, ["task_id"]): + return validation_error + if err := validate_non_empty_str_arg(action, args, "task_id"): + return err + return await tools_manager.tasks_remove(args.get("task_id")) + case "dataframes_list": + return await tools_manager.dataframes_list() + case "dataframes_get": + if validation_error := validate_required_args(action, args, ["dataframe_id"]): + return validation_error + if err := validate_non_empty_str_arg(action, args, "dataframe_id"): + return err + return await tools_manager.dataframes_get(args.get("dataframe_id")) + case "dataframes_schema_groups": + return await tools_manager.dataframes_schema_groups(args.get("dataframe_id_list")) + case "dataframes_query": + if validation_error := validate_required_args(action, args, ["sql"]): + return validation_error + if err := validate_non_empty_str_arg(action, args, "sql"): + return err + return await tools_manager.dataframes_query( + args.get("sql"), + args.get("output_format", "matrix"), + args.get("result_format", "auto"), + ) + case "dataframes_remove": + if validation_error := validate_required_args(action, args, ["dataframe_id_list"]): + return validation_error + dataframe_id_list = args.get("dataframe_id_list") + if not isinstance(dataframe_id_list, list) or not dataframe_id_list: + return BaseResult( + error="Missing required args for action 'dataframes_remove': dataframe_id_list must be a non-empty list of dataframe IDs within 'args'. " + "Required args: dataframe_id_list (list[str], non-empty)." + ) + return await tools_manager.dataframes_remove(dataframe_id_list) + case "dataframes_clear": + return await tools_manager.dataframes_clear() + case "dataframes_sql_help": + return await tools_manager.dataframes_sql_help() + case _: + return BaseResult(error=f"Action {action} not found in tools manager tool") + except httpx.HTTPStatusError: + return BaseResult(error=f"Error: {traceback.format_exc()}") + except Exception: + return BaseResult( + error=f"Error: {traceback.format_exc()}\n{SUPPORT_MESSAGE}" + ) diff --git a/tools/user_manager.py b/tools/user_manager.py index a3859e0..2f86b46 100644 --- a/tools/user_manager.py +++ b/tools/user_manager.py @@ -24,7 +24,8 @@ from formatters.user import format_users from models.manager import Manager from models.result import BaseResult -from tools.utils import api_request, format_sanitized_traceback +from tools.utils import api_request, format_sanitized_traceback, ttl_cache_method, run_as_task, normalize_action_args, \ + tool_result class UserManager(Manager): @@ -32,6 +33,8 @@ class UserManager(Manager): def __init__(self, token: Optional[BzmToken], ctx: Context): super().__init__(token, ctx) + @ttl_cache_method(ttl_seconds=30) + @run_as_task() async def read(self) -> BaseResult: return await api_request( self.token, @@ -50,15 +53,18 @@ def register(mcp, token: Optional[BzmToken]): - read: Read a current user information from BlazeMeter. Hints: - For default account, workspace and project, use the 'read' action. +- Optional result formatting in args: `result_format` = `auto` (default), `dataframe` (force dataframe), `raw` (disable dataframe materialization). - **CRITICAL**: Always follow the action schema exactly. If args are required, include args with exact names/types. """ ) + @tool_result() async def user( - action: str = Field(description="The action id to execute"), - args: Dict[str, Any] = Field(description="Dictionary with parameters"), + arguments: Dict[str, Any] = Field(description="Tool arguments: action, args, and any action-specific params", default=None), ctx: Context = Field(description="Context object providing access to MCP capabilities") ) -> BaseResult: - + action, args = normalize_action_args(arguments) + if not action: + return BaseResult(error="Missing required argument 'action' within tool arguments.") user_manager = UserManager(token, ctx) try: match action: diff --git a/tools/utils.py b/tools/utils.py index 37d8599..a391679 100644 --- a/tools/utils.py +++ b/tools/utils.py @@ -16,26 +16,36 @@ """ Simple utilities for BlazeMeter MCP tools. """ +import asyncio +import contextvars +import copy import functools +import inspect import os import platform import re import sys +import time import traceback from datetime import datetime, timezone from enum import Enum -from typing import Optional, Callable, Awaitable from importlib import resources from pathlib import Path +from typing import Optional, Callable, Awaitable, Any, Dict import httpx +from mcp.types import CallToolResult from pydantic import BaseModel from config.blazemeter import BZM_API_BASE_URL from config.security import validate_http_request_endpoint from config.token import BzmToken from config.version import __version__ -from models.result import BaseResult, HttpBaseResult +from models.result import BaseResult, HttpBaseResult, ToolResult +from tools.async_task_manager import submit_task, get_task_record, remove_task, task_snapshot +from tools.dataframe_manager import ( + materialize_large_result_if_needed, +) so = platform.system() # "Windows", "Linux", "Darwin" version = platform.version() # kernel / build version @@ -59,10 +69,10 @@ ) unix_abs_path_pattern = re.compile( r"/(?:" - r"Users|home|root" # User home directories (macOS, Linux) - r"|var|tmp|etc|opt|srv" # Standard Linux directories - r"|mnt|run|media" # Mount points and runtime (Linux) - r"|app|data" # Common Docker container directories + r"Users|home|root" # User home directories (macOS, Linux) + r"|var|tmp|etc|opt|srv" # Standard Linux directories + r"|mnt|run|media" # Mount points and runtime (Linux) + r"|app|data" # Common Docker container directories r"|System|Library|Applications|private|Volumes" # macOS directories r")/[^\n\r\t\"']+" ) @@ -130,6 +140,19 @@ class ConfirmMode(Enum): _confirm_mode = ConfirmMode.DELETE +_task_management_enabled = contextvars.ContextVar("task_management_enabled", default=False) +_method_cache: Dict[str, tuple[float, Any]] = {} +_method_cache_inflight: Dict[str, asyncio.Future] = {} +_method_cache_lock = asyncio.Lock() +_method_cache_max_entries = 2048 +_http_clients_lock = asyncio.Lock() +_bzm_http_client: Optional[httpx.AsyncClient] = None +_generic_http_client: Optional[httpx.AsyncClient] = None +_network_debug_context = contextvars.ContextVar("network_debug_context", default=None) +_cache_debug_context = contextvars.ContextVar("cache_debug_context", default=None) +_result_debug_enabled = False +_result_format_context = contextvars.ContextVar("result_format_context", default="auto") +_force_task_response_context = contextvars.ContextVar("force_task_response_context", default=False) class Operations(Enum): @@ -139,6 +162,99 @@ class Operations(Enum): DELETE = "D" # Delete +def set_result_debug_enabled(enabled: bool): + global _result_debug_enabled + _result_debug_enabled = bool(enabled) + + +def is_result_debug_enabled() -> bool: + return _result_debug_enabled + + +async def _get_bzm_http_client() -> httpx.AsyncClient: + global _bzm_http_client + if _bzm_http_client is not None: + return _bzm_http_client + async with _http_clients_lock: + if _bzm_http_client is None: + _bzm_http_client = httpx.AsyncClient(base_url=BZM_API_BASE_URL, http2=True, timeout=timeout) + return _bzm_http_client + + +async def _get_generic_http_client() -> httpx.AsyncClient: + global _generic_http_client + if _generic_http_client is not None: + return _generic_http_client + async with _http_clients_lock: + if _generic_http_client is None: + _generic_http_client = httpx.AsyncClient(base_url="", http2=True, timeout=timeout) + return _generic_http_client + + +def _start_network_debug_scope() -> contextvars.Token: + if not _result_debug_enabled: + return _network_debug_context.set(None) + return _network_debug_context.set({"http_calls": 0, "http_total_ms": 0}) + + +def _get_network_debug_snapshot() -> Dict[str, int]: + current = _network_debug_context.get() + if not isinstance(current, dict): + return {"http_calls": 0, "http_total_ms": 0} + return { + "http_calls": int(current.get("http_calls", 0)), + "http_total_ms": int(current.get("http_total_ms", 0)), + } + + +def _accumulate_network_debug(elapsed_ms: int): + current = _network_debug_context.get() + if not isinstance(current, dict): + return + current["http_calls"] = int(current.get("http_calls", 0)) + 1 + current["http_total_ms"] = int(current.get("http_total_ms", 0)) + max(0, int(elapsed_ms)) + + +def _start_cache_debug_scope() -> contextvars.Token: + if not _result_debug_enabled: + return _cache_debug_context.set(None) + return _cache_debug_context.set( + { + "hits": 0, + "misses": 0, + "shared_wait_ms": 0, + "lock_wait_ms": 0, + "deepcopy_ms": 0, + } + ) + + +def _accumulate_cache_debug(metric: str, value: int = 1): + current = _cache_debug_context.get() + if not isinstance(current, dict): + return + current[metric] = int(current.get(metric, 0)) + int(value) + + +def _get_cache_debug_snapshot() -> Dict[str, int]: + current = _cache_debug_context.get() + if not isinstance(current, dict): + return { + "hits": 0, + "misses": 0, + "shared_wait_ms": 0, + "lock_wait_ms": 0, + "deepcopy_ms": 0, + } + return { + "hits": int(current.get("hits", 0)), + "misses": int(current.get("misses", 0)), + "shared_wait_ms": int(current.get("shared_wait_ms", 0)), + "lock_wait_ms": int(current.get("lock_wait_ms", 0)), + "deepcopy_ms": int(current.get("deepcopy_ms", 0)), + } + + async def api_request(token: Optional[BzmToken], method: str, endpoint: str, result_formatter: Callable = None, result_formatter_params: Optional[dict] = None, @@ -156,60 +272,38 @@ async def api_request(token: Optional[BzmToken], method: str, endpoint: str, headers["Authorization"] = token.as_basic_auth() headers["User-Agent"] = user_agent - async with (httpx.AsyncClient(base_url=BZM_API_BASE_URL, http2=True, timeout=timeout) as client): - try: - resp = await client.request(method, endpoint, headers=headers, **kwargs) - resp.raise_for_status() - content_type = resp.headers.get("content-type", "") - if "application/json" in content_type.lower(): - response_dict = resp.json() - result = response_dict.get("result", []) - else: - response_dict = {} - result = resp.text - default_total = 0 - if not isinstance(result, list): # Generalize result always as a list - result = [result] - default_total = 1 - final_result = result_formatter(result, result_formatter_params) if result_formatter else result + client = await _get_bzm_http_client() + request_started = time.monotonic() + try: + resp = await client.request(method, endpoint, headers=headers, **kwargs) + resp.raise_for_status() + content_type = resp.headers.get("content-type", "") + if "application/json" in content_type.lower(): + response_dict = resp.json() + result = response_dict.get("result", []) + else: + response_dict = {} + result = resp.text + default_total = 0 + if not isinstance(result, list): # Generalize result always as a list + result = [result] + default_total = 1 + final_result = result_formatter(result, result_formatter_params) if result_formatter else result + return BaseResult( + result=final_result, + error=response_dict.get("error", None), + total=response_dict.get("total", default_total), + has_more=response_dict.get("total", 0) - ( + response_dict.get("skip", 0) + response_dict.get("limit", 0)) > 0 + ) + except httpx.HTTPStatusError as e: + if e.response.status_code in [401, 403]: return BaseResult( - result=final_result, - error=response_dict.get("error", None), - total=response_dict.get("total", default_total), - has_more=response_dict.get("total", 0) - ( - response_dict.get("skip", 0) + response_dict.get("limit", 0)) > 0 + error="Invalid credentials" ) - except httpx.HTTPStatusError as e: - status_code = e.response.status_code - error_msg = None - if status_code in [401, 403]: - # Try to extract detailed error message from response body - error_msg = "Invalid credentials" - - error_body = e.response.json() - if isinstance(error_body, dict): - api_error = error_body.get("error") - if api_error: - if isinstance(api_error, dict): - error_msg = api_error.get("message", error_msg) - else: - error_msg = str(api_error) - elif "message" in error_body: - error_msg = error_body.get("message", error_msg) - - # Check for data retention related keywords - error_text = str(error_body).lower() - if any(keyword in error_text for keyword in ["retention", "expired", "no longer available"]): - error_msg = "Data retention period expired: Report data is no longer available due to data retention policy" - - elif status_code in [404]: - error_msg = "Not Found. Please ask the user to verify if the request is valid." - - if error_msg: - return BaseResult( - error=error_msg - ) - raise + raise + finally: + _accumulate_network_debug(int((time.monotonic() - request_started) * 1000)) async def http_request(method: str, endpoint: str, @@ -227,23 +321,26 @@ async def http_request(method: str, endpoint: str, headers = kwargs.pop("headers", {}) headers["User-Agent"] = user_agent - async with (httpx.AsyncClient(base_url="", http2=True, timeout=timeout) as client): - try: - resp = await client.request(method, endpoint, headers=headers, **kwargs) - resp.raise_for_status() - result = resp.text - error = None - final_result = result_formatter(result, result_formatter_params) if result_formatter else result + client = await _get_generic_http_client() + request_started = time.monotonic() + try: + resp = await client.request(method, endpoint, headers=headers, **kwargs) + resp.raise_for_status() + result = resp.text + error = None + final_result = result_formatter(result, result_formatter_params) if result_formatter else result + return HttpBaseResult( + result=final_result, + error=error, + ) + except httpx.HTTPStatusError as e: + if e.response.status_code in [401, 403]: return HttpBaseResult( - result=final_result, - error=error, + error="Invalid credentials" ) - except httpx.HTTPStatusError as e: - if e.response.status_code in [401, 403]: - return HttpBaseResult( - error="Invalid credentials" - ) - raise + raise + finally: + _accumulate_network_debug(int((time.monotonic() - request_started) * 1000)) def get_date_time_iso(timestamp: int) -> Optional[str]: @@ -253,6 +350,184 @@ def get_date_time_iso(timestamp: int) -> Optional[str]: return datetime.fromtimestamp(timestamp, tz=timezone.utc).isoformat() +def _set_tool_call_timing( + result: BaseResult, + started_monotonic: float, + started_wall_clock: float, + extra_timing: Optional[Dict[str, int]] = None +): + finished_wall_clock = time.time() + duration_ms = int((time.monotonic() - started_monotonic) * 1000) + result.tool_call_started_at = datetime.fromtimestamp(started_wall_clock).isoformat() + result.tool_call_finished_at = datetime.fromtimestamp(finished_wall_clock).isoformat() + result.tool_call_duration_ms = duration_ms + if not _result_debug_enabled: + return + debug = result.debug if isinstance(result.debug, dict) else {} + network = _get_network_debug_snapshot() + debug["network"] = network + debug["cache"] = _get_cache_debug_snapshot() + timing = { + "total_ms": duration_ms, + "network_ms": int(network.get("http_total_ms", 0)), + "non_network_ms": max(0, duration_ms - int(network.get("http_total_ms", 0))), + } + if extra_timing: + timing.update({k: int(v) for k, v in extra_timing.items()}) + debug["timing"] = timing + result.debug = debug + + +def _attach_task_debug(result: BaseResult, task_record: Any): + if not _result_debug_enabled: + return + if not isinstance(result, BaseResult) or task_record is None: + return + debug = result.debug if isinstance(result.debug, dict) else {} + task_debug: Dict[str, int] = {} + if task_record.started_running_at is not None: + task_debug["queue_wait_ms"] = int((task_record.started_running_at - task_record.created_at) * 1000) + end_ts = task_record.finished_at if task_record.finished_at is not None else task_record.last_updated_at + task_debug["run_ms"] = int((end_ts - task_record.started_running_at) * 1000) + task_debug["lifecycle_ms"] = int((task_record.last_updated_at - task_record.created_at) * 1000) + debug["task"] = task_debug + result.debug = debug + + +def _extract_result_format(action: Any, args_dict: Any) -> str: + # result_format is only supported in tool entrypoints with dict args. + if not isinstance(args_dict, dict): + return "auto" + result_format = str(args_dict.get("result_format", "auto")).strip().lower() + if result_format not in {"auto", "dataframe", "raw"}: + return "invalid" + return result_format + + +def normalize_action_args(arguments: Optional[Dict[str, Any]] = None) -> tuple[str, Dict[str, Any]]: + """ + Normalize tool arguments to (action, args) format. + Supports: + - {"action": "x", "args": {"key": "value"}} + - {"action": "x", "key": "value"} (params at top level, merged into args) + - {"arguments": {"action": "x", "args": {...}}} (double-wrapped by client) + Top-level keys other than 'action' and 'args' are merged into args. + Use a single 'arguments' param so the full MCP tool call payload is received + (avoids Pydantic dropping extra fields when using action/args separately). + """ + arguments = arguments or {} + # Unwrap double-nested format: {"arguments": {"action": "x", "args": {...}}} + inner = arguments.get("arguments") + if ( + isinstance(inner, dict) + and len(arguments) == 1 + and ("action" in inner or "args" in inner) + ): + arguments = inner + action = str(arguments.get("action") or "").strip() or "" + args = dict(arguments.get("args") or {}) + for key, value in arguments.items(): + if key not in ("action", "args"): + args[key] = value + return action, args + + +def validate_required_args(action: str, args: Optional[Dict[str, Any]], required: list[str]) -> Optional[BaseResult]: + args = args or {} + missing = [key for key in required if key not in args or args[key] is None] + if not missing: + return None + missing_str = ", ".join(missing) + required_str = ", ".join(required) + return BaseResult( + error=( + f"Missing required args for action '{action}': {missing_str} not found within 'args'. " + f"Required args: {required_str}. Ensure parameters are passed inside the 'args' argument." + ) + ) + + +def validate_non_empty_str_arg( + action: str, args: Optional[Dict[str, Any]], key: str +) -> Optional[BaseResult]: + """Return BaseResult error if args[key] is missing, not a str, or only whitespace.""" + args = args or {} + value = args.get(key) + if not isinstance(value, str) or not value.strip(): + return BaseResult( + error=( + f"Missing required args for action '{action}': {key} must be a non-empty string " + f"within 'args'. Required args: {key}." + ) + ) + return None + + +# Max concurrent sub-calls when executing tool "batch" (help, skills, etc.); additional calls wait on the semaphore. +MAX_BATCH_CONCURRENCY = 100 + + +async def _execute_batch_item_with_limit( + semaphore: asyncio.Semaphore, + call: Any, + process_call: Callable[[Any], Awaitable[BaseResult | list[BaseResult]]], +) -> BaseResult | list[BaseResult]: + async with semaphore: + return await process_call(call) + + +async def execute_batch_calls( + batch_calls: Any, + process_call: Callable[[Any], Awaitable[BaseResult | list[BaseResult]]], + *, + max_concurrency: Optional[int] = None, +) -> BaseResult: + """Run batch sub-calls with asyncio.gather; at most ``max_concurrency`` run at once (default MAX_BATCH_CONCURRENCY). + + When the number of batch items exceeds the limit, extra work waits until a slot is free. + """ + if not isinstance(batch_calls, list) or not batch_calls: + return BaseResult( + error="batch_calls must be a non-empty list of dicts with 'action' and 'args'" + ) + limit = MAX_BATCH_CONCURRENCY if max_concurrency is None else max_concurrency + if limit < 1: + limit = 1 + semaphore = asyncio.Semaphore(limit) + results = await asyncio.gather( + *( + _execute_batch_item_with_limit(semaphore, call, process_call) + for call in batch_calls + ), + return_exceptions=True, + ) + processed_results = [ + r if not isinstance(r, Exception) else BaseResult(error=f"Unhandled exception: {str(r)}") + for r in results + ] + return BaseResult(result=processed_results) + + +async def process_batch_sub_action( + call: Dict[str, Any], + dispatch_sub_action: Callable[[str, Dict[str, Any]], Awaitable[BaseResult | list[BaseResult]]], + support_message: Optional[str] = None, +) -> BaseResult | list[BaseResult]: + sub_action = call.get("action", "") + raw_sub_args = call.get("args", {}) + sub_args = dict(raw_sub_args) if isinstance(raw_sub_args, dict) else {} + force_task_token = _force_task_response_context.set(True) + try: + return await dispatch_sub_action(sub_action, sub_args) + except httpx.HTTPStatusError: + return BaseResult(error=f"HTTP error in sub-action {sub_action}: {traceback.format_exc()}") + except Exception: + suffix = f"\n{support_message}" if support_message else "" + return BaseResult(error=f"Error in sub-action {sub_action}: {traceback.format_exc()}{suffix}") + finally: + _force_task_response_context.reset(force_task_token) + + def get_resources_path(): try: resources_path = resources.files("resources") @@ -290,6 +565,123 @@ def operation_need_confirmation(operation: Operations) -> bool: return False +def _cache_scope_from_instance(instance: Any) -> str: + token = getattr(instance, "token", None) + token_id = getattr(token, "id", "anonymous") + return f"{instance.__class__.__name__}:{token_id}" + + +def _cache_compact_value(value: Any) -> str: + if isinstance(value, (str, int, float, bool)) or value is None: + return repr(value) + if isinstance(value, dict): + keys = sorted(value.keys(), key=lambda x: str(x)) + return "{" + ",".join(f"{k}:{_cache_compact_value(value[k])}" for k in keys) + "}" + if isinstance(value, (list, tuple, set)): + return "[" + ",".join(_cache_compact_value(v) for v in value) + "]" + return repr(value) + + +def _cleanup_expired_cache_entries(now: Optional[float] = None): + current = now if now is not None else time.monotonic() + expired_keys = [key for key, (expires_at, _) in _method_cache.items() if expires_at <= current] + for key in expired_keys: + _method_cache.pop(key, None) + + +def _trim_cache_size_if_needed(): + if len(_method_cache) <= _method_cache_max_entries: + return + # Keep entries with the longest remaining TTL. + sorted_entries = sorted(_method_cache.items(), key=lambda item: item[1][0], reverse=True) + _method_cache.clear() + for key, value in sorted_entries[:_method_cache_max_entries]: + _method_cache[key] = value + + +def ttl_cache_method(ttl_seconds: int = 30): + """ + Async TTL cache decorator for manager instance methods. + Caches successful results only and prevents duplicate concurrent fetches. + """ + + def decorator(func: Callable[..., Awaitable[Any]]): + @functools.wraps(func) + async def wrapper(self, *args, **kwargs): + scope = _cache_scope_from_instance(self) + cache_key = ( + f"{scope}:{func.__module__}.{func.__qualname__}:" + f"args={_cache_compact_value(args)}:kwargs={_cache_compact_value(kwargs)}" + ) + now = time.monotonic() + cache_hit = False + cached_value = None + shared_future: Optional[asyncio.Future] = None + is_owner = False + + lock_wait_started = time.monotonic() + async with _method_cache_lock: + _accumulate_cache_debug("lock_wait_ms", int((time.monotonic() - lock_wait_started) * 1000)) + _cleanup_expired_cache_entries(now) + cached_entry = _method_cache.get(cache_key) + if cached_entry and cached_entry[0] > now: + cache_hit = True + cached_value = cached_entry[1] + else: + shared_future = _method_cache_inflight.get(cache_key) + if shared_future is None: + shared_future = asyncio.get_running_loop().create_future() + _method_cache_inflight[cache_key] = shared_future + is_owner = True + _accumulate_cache_debug("misses", 1) + else: + is_owner = False + + if cache_hit: + _accumulate_cache_debug("hits", 1) + dc_started = time.monotonic() + copied = copy.deepcopy(cached_value) + _accumulate_cache_debug("deepcopy_ms", int((time.monotonic() - dc_started) * 1000)) + return copied + + if not is_owner and shared_future is not None: + shared_wait_started = time.monotonic() + shared_result = await shared_future + _accumulate_cache_debug("shared_wait_ms", int((time.monotonic() - shared_wait_started) * 1000)) + dc_started = time.monotonic() + copied = copy.deepcopy(shared_result) + _accumulate_cache_debug("deepcopy_ms", int((time.monotonic() - dc_started) * 1000)) + return copied + + try: + result = await func(self, *args, **kwargs) + should_cache = not (isinstance(result, BaseResult) and result.error) + dc_started = time.monotonic() + cached_copy = copy.deepcopy(result) if should_cache else None + shared_copy = copy.deepcopy(result) + _accumulate_cache_debug("deepcopy_ms", int((time.monotonic() - dc_started) * 1000)) + + async with _method_cache_lock: + if should_cache: + _method_cache[cache_key] = (time.monotonic() + ttl_seconds, cached_copy) + _trim_cache_size_if_needed() + + current_future = _method_cache_inflight.pop(cache_key, None) + if current_future is not None and not current_future.done(): + current_future.set_result(shared_copy) + return result + except Exception as exc: + async with _method_cache_lock: + current_future = _method_cache_inflight.pop(cache_key, None) + if current_future is not None and not current_future.done(): + current_future.set_exception(exc) + raise + + return wrapper + + return decorator + + def require_confirmation(operation: Operations = Operations.READ, message="This action requires manual confirmation to continue"): confirmation_schema = Confirmation @@ -317,3 +709,239 @@ async def wrapper(self, *args, **kwargs): return wrapper return decorator + + +async def execute_with_task_management( + action_payload: Dict[str, Any], + coro_factory: Callable[[], Awaitable[Any]], + time_to_live_ms: Optional[int] = None, + fast_response_threshold_seconds: float = 5.0 +) -> BaseResult: + wait_started = time.monotonic() + try: + task_id = submit_task(action_payload, coro_factory, time_to_live_ms=time_to_live_ms) + except RuntimeError as exc: + return BaseResult(error=str(exc)) + task_record = get_task_record(task_id) + if not task_record or not task_record.asyncio_task: + return BaseResult(error="Task could not be scheduled.") + + try: + await asyncio.wait_for(asyncio.shield(task_record.asyncio_task), timeout=fast_response_threshold_seconds) + latest_record = get_task_record(task_id) + if not latest_record or latest_record.result is None: + remove_task(task_id) + return BaseResult(error="Task finished without result.") + final_result = latest_record.result + _attach_task_debug(final_result, latest_record) + if isinstance(final_result.debug, dict): + final_result.debug.setdefault("task", {}) + final_result.debug["task"]["sync_wait_ms"] = int((time.monotonic() - wait_started) * 1000) + remove_task(task_id) + return final_result + except asyncio.TimeoutError: + latest_record = get_task_record(task_id) + if not latest_record: + return BaseResult(error="Task was not found after scheduling.") + snapshot = task_snapshot(latest_record, include_result=False) + timeout_result = BaseResult( + result=[snapshot], + info=[ + "Long-running operation accepted. Use blazemeter_tools with action 'tasks_status' to monitor status." + ] + ) + _attach_task_debug(timeout_result, latest_record) + if isinstance(timeout_result.debug, dict): + timeout_result.debug.setdefault("task", {}) + timeout_result.debug["task"]["sync_wait_ms"] = int((time.monotonic() - wait_started) * 1000) + return timeout_result + + +def _serialize_action_value(value: Any) -> Any: + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, dict): + return {str(k): _serialize_action_value(v) for k, v in value.items()} + if isinstance(value, (list, tuple, set)): + return [_serialize_action_value(v) for v in value] + return repr(value) + + +def run_as_task( + time_to_live_ms: Optional[int] = None, + fast_response_threshold_seconds: float = 5.0 +): + def decorator(func: Callable[..., Awaitable[Any]]): + @functools.wraps(func) + async def wrapper(self, *args, **kwargs): + if _task_management_enabled.get(): + return await func(self, *args, **kwargs) + + try: + signature = inspect.signature(func) + bound = signature.bind(self, *args, **kwargs) + bound.apply_defaults() + # Keep all user-provided parameters with names for richer task context. + named_params = { + key: _serialize_action_value(value) + for key, value in bound.arguments.items() + if key != "self" + } + except Exception: + named_params = {} + + action_payload = { + "manager": self.__class__.__name__, + "method": func.__name__, + "args": _serialize_action_value(args), + "kwargs": _serialize_action_value(kwargs), + "params": named_params, + "result_format": _result_format_context.get(), + } + + token = _task_management_enabled.set(True) + try: + effective_fast_response_threshold = ( + 0.0 if _force_task_response_context.get() else fast_response_threshold_seconds + ) + coro_factory = lambda: func(self, *args, **kwargs) + return await execute_with_task_management( + action_payload=action_payload, + coro_factory=coro_factory, + time_to_live_ms=time_to_live_ms, + fast_response_threshold_seconds=effective_fast_response_threshold + ) + finally: + _task_management_enabled.reset(token) + + return wrapper + + return decorator + + +def tool_result(excluded_actions: Optional[set[str]] = None): + excluded = excluded_actions or set() + + def decorator(func: Callable[..., Awaitable[BaseResult]]): + @functools.wraps(func) + async def wrapper(*args, **kwargs) -> ToolResult | CallToolResult: + def _to_tool_result(value: Any) -> ToolResult | CallToolResult: + if isinstance(value, ToolResult): + return value + if isinstance(value, CallToolResult): + return value + if isinstance(value, BaseResult): + return ToolResult.from_base_result(value) + return ToolResult.from_base_result(BaseResult(result=[value])) + + started_monotonic = time.monotonic() + started_wall_clock = time.time() + network_token = _start_network_debug_scope() + cache_token = _start_cache_debug_scope() + result_format_token = _result_format_context.set("auto") + try: + result = await func(*args, **kwargs) + after_func_monotonic = time.monotonic() + if not isinstance(result, BaseResult) or result.error or result.result is None: + if isinstance(result, BaseResult): + _set_tool_call_timing( + result, + started_monotonic, + started_wall_clock, + extra_timing={ + "manager_logic_ms": int((after_func_monotonic - started_monotonic) * 1000), + "postprocess_ms": 0, + }, + ) + return _to_tool_result(result) + + action = kwargs.get("action") + if action is None and len(args) > 0: + action = args[0] + tool_args = kwargs.get("args") + if tool_args is None and len(args) > 1: + tool_args = args[1] + + result_format = _extract_result_format(action, tool_args) + if result_format == "invalid": + invalid = BaseResult( + error="Invalid result_format value. Allowed values: auto, dataframe, raw." + ) + _set_tool_call_timing( + invalid, + started_monotonic, + started_wall_clock, + extra_timing={ + "manager_logic_ms": int((after_func_monotonic - started_monotonic) * 1000), + "postprocess_ms": 0, + }, + ) + return _to_tool_result(invalid) + if isinstance(action, str) and action == "batch": + # Batch envelopes must remain inline results; do not materialize as dataframe. + result_format = "raw" + _result_format_context.reset(result_format_token) + result_format_token = _result_format_context.set(result_format) + + if result_format == "auto" and isinstance(action, str) and action in excluded: + _set_tool_call_timing( + result, + started_monotonic, + started_wall_clock, + extra_timing={ + "manager_logic_ms": int((after_func_monotonic - started_monotonic) * 1000), + "postprocess_ms": 0, + }, + ) + return _to_tool_result(result) + + try: + postprocess_started = time.monotonic() + if result_format == "raw": + final_result = result + else: + final_result = await materialize_large_result_if_needed( + base_result=result, + origin_manager=func.__name__, + origin_action=str(action) if action is not None else "unknown", + force=(result_format == "dataframe"), + ) + _set_tool_call_timing( + final_result, + started_monotonic, + started_wall_clock, + extra_timing={ + "manager_logic_ms": int((after_func_monotonic - started_monotonic) * 1000), + "postprocess_ms": int((time.monotonic() - postprocess_started) * 1000), + }, + ) + return _to_tool_result(final_result) + except Exception as exc: + failure_result = BaseResult( + error=( + f"Large result materialization failed: {exc}. " + "Try reducing the scope or filters and retry." + ) + ) + _set_tool_call_timing( + failure_result, + started_monotonic, + started_wall_clock, + extra_timing={ + "manager_logic_ms": int((after_func_monotonic - started_monotonic) * 1000), + "postprocess_ms": 0, + }, + ) + return _to_tool_result(failure_result) + finally: + _network_debug_context.reset(network_token) + _cache_debug_context.reset(cache_token) + _result_format_context.reset(result_format_token) + + return wrapper + + return decorator + + +# Backward-compatible alias. Prefer using tool_result in new code. +dataframe_result = tool_result diff --git a/tools/workspace_manager.py b/tools/workspace_manager.py index 8c6c62c..7815e5e 100644 --- a/tools/workspace_manager.py +++ b/tools/workspace_manager.py @@ -25,7 +25,8 @@ from models.manager import Manager from models.result import BaseResult from tools import bridge -from tools.utils import api_request, format_sanitized_traceback +from tools.utils import api_request, format_sanitized_traceback, run_as_task, normalize_action_args, tool_result, \ + validate_required_args, ttl_cache_method class WorkspaceManager(Manager): @@ -37,9 +38,9 @@ class WorkspaceManager(Manager): def __init__(self, token: Optional[BzmToken], ctx: Context): super().__init__(token, ctx) - async def read(self, workspace_id: Optional[int]) -> BaseResult: - if not isinstance(workspace_id, int) or workspace_id < 1: - return BaseResult(error="Missing or invalid required argument 'workspace_id'. Expected integer.") + @ttl_cache_method(ttl_seconds=30) + @run_as_task() + async def read(self, workspace_id: int) -> BaseResult: workspace_result = await api_request( self.token, @@ -58,11 +59,8 @@ async def read(self, workspace_id: Optional[int]) -> BaseResult: else: return workspace_result - async def list(self, account_id: Optional[int], limit: int = 50, offset: int = 0) -> BaseResult: - if not isinstance(account_id, int) or account_id < 1: - return BaseResult(error="Missing or invalid required argument 'account_id'. Expected integer.") - if not isinstance(limit, int) or not isinstance(offset, int): - return BaseResult(error="Invalid arguments 'limit'/'offset'. Expected integers.") + @run_as_task() + async def list(self, account_id: int, limit: int = 50, offset: int = 0) -> BaseResult: # Check if it's valid or allowed account_data = await bridge.read_account(self.token, self.ctx, account_id) @@ -84,11 +82,8 @@ async def list(self, account_id: Optional[int], limit: int = 50, offset: int = 0 params=parameters ) - async def read_locations(self, workspace_id: Optional[int], purpose: str = "load") -> BaseResult: - if not isinstance(workspace_id, int) or workspace_id < 1: - return BaseResult(error="Missing or invalid required argument 'workspace_id'. Expected integer.") - if not isinstance(purpose, str) or not purpose.strip(): - return BaseResult(error="Invalid argument 'purpose'. Expected non-empty string.") + @run_as_task() + async def read_locations(self, workspace_id: int, purpose: str = "load") -> BaseResult: locations_result = await api_request( self.token, @@ -108,6 +103,7 @@ async def read_locations(self, workspace_id: Optional[int], purpose: str = "load else: return locations_result + def register(mcp, token: Optional[BzmToken]): @mcp.tool( name=f"{TOOLS_PREFIX}_workspaces", @@ -115,39 +111,54 @@ def register(mcp, token: Optional[BzmToken]): Operations on workspaces. Actions: - read: Read a workspace. Get the detailed information of a workspace. - args(dict): Dictionary with the following required parameters: - workspace_id (int): The id of the workspace. + args(dict): Dictionary with the following parameters: + workspace_id (int, required): The id of the workspace. - list: List all workspaces. - args(dict): Dictionary with the following required parameters: - account_id (int): The id of the account to list the workspaces from - limit (int, default=10, valid=[1 to 50]): The number of workspaces to list. - offset (int, default=0): Number of workspaces to skip. + args(dict): Dictionary with the following parameters: + account_id (int, required): The id of the account to list workspaces from. + limit (int, optional, default=50, valid=[1 to 50 when result_format=auto/raw, 1000 when result_format=dataframe]): Max workspaces to return. + offset (int, optional, default=0): Number of workspaces to skip. - read_locations: get the location list for a given workspace ID. - args(dict): Dictionary with the following required parameters: - workspace_id (int): The id of the workspace. - purpose (str, default="load", valid=["load", "functional", "grid", "mock"]): The purpose filter. + args(dict): Dictionary with the following parameters: + workspace_id (int, required): The id of the workspace. + purpose (str, optional, default="load", valid=["load", "functional", "grid", "mock"]): The purpose filter. Hints: - For available locations and available billing usage use the 'read' action for a particular workspace. +- Optional result formatting in args: `result_format` = `auto` (default), `dataframe` (force dataframe), `raw` (disable dataframe materialization). - **CRITICAL**: Always follow the action schema exactly. If args are required, include args with exact names/types. """ ) + @tool_result() async def workspace( - action: str = Field(description="The action id to execute"), - args: Dict[str, Any] = Field(description="Dictionary with parameters"), + arguments: Dict[str, Any] = Field( + description="Tool arguments: action, args, and any action-specific params", default=None), ctx: Context = Field(description="Context object providing access to MCP capabilities") ) -> BaseResult: - + action, args = normalize_action_args(arguments) + if not action: + return BaseResult(error="Missing required argument 'action' within tool arguments.") workspace_manager = WorkspaceManager(token, ctx) try: match action: case "read": + if validation_error := validate_required_args(action, args, ["workspace_id"]): + return validation_error return await workspace_manager.read(args.get("workspace_id")) case "list": - return await workspace_manager.list( - args.get("account_id"), args.get("limit", 50), args.get("offset", 0) - ) + if validation_error := validate_required_args(action, args, ["account_id"]): + return validation_error + return await workspace_manager.list(args.get("account_id"), args.get("limit", 50), + args.get("offset", 0)) case "read_locations": - return await workspace_manager.read_locations(args.get("workspace_id"), args.get("purpose", "load")) + if validation_error := validate_required_args(action, args, ["workspace_id"]): + return validation_error + purpose_raw = args.get("purpose", "load") + purpose = ( + purpose_raw.strip() + if isinstance(purpose_raw, str) and purpose_raw.strip() + else "load" + ) + return await workspace_manager.read_locations(args.get("workspace_id"), purpose) case _: return BaseResult( error=f"Action {action} not found in workspace manager tool" diff --git a/uv.lock b/uv.lock index 6b6024b..7f790cb 100644 --- a/uv.lock +++ b/uv.lock @@ -59,6 +59,7 @@ dependencies = [ { name = "httpx", extra = ["http2"] }, { name = "lxml" }, { name = "mcp", extra = ["cli"] }, + { name = "polars" }, { name = "pydantic" }, { name = "pydantic-core" }, { name = "pydantic-settings" }, @@ -77,6 +78,7 @@ requires-dist = [ { name = "httpx", extras = ["http2"], specifier = ">=0.28.1" }, { name = "lxml", specifier = ">=5.3.0" }, { name = "mcp", extras = ["cli"], specifier = ">=1.27.0" }, + { name = "polars", specifier = ">=1.40.1" }, { name = "pydantic", specifier = ">=2.11.7" }, { name = "pydantic-core", specifier = ">=2.33.2" }, { name = "pydantic-settings", specifier = ">=2.10.1" }, @@ -567,6 +569,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "polars" +version = "1.40.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "polars-runtime-32" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/8c/bc9bc948058348ed43117cecc3007cd608f395915dae8a00974579a5dab1/polars-1.40.1.tar.gz", hash = "sha256:ab2694134b137596b5a59bfd7b4c54ebbc9b59f9403127f18e32d363777552e8", size = 733574, upload-time = "2026-04-22T19:15:55.507Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/91/74fc60d94488685a92ac9d49d7ec55f3e91fe9b77942a6235a5fa7f249c3/polars-1.40.1-py3-none-any.whl", hash = "sha256:c0f861219d1319cdea45c4ce4d30355a47176b8f98dcedf95ea8269f131b8abd", size = 828723, upload-time = "2026-04-22T19:14:25.452Z" }, +] + +[[package]] +name = "polars-runtime-32" +version = "1.40.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/ba/26d40f039be9f552b5fd7365a621bdfc0f8e912ef77094ae4693491b0bae/polars_runtime_32-1.40.1.tar.gz", hash = "sha256:37f3065615d1bf90d03b5326222df4c5c1f8a5d33e50470aa588e3465e6eb814", size = 2935843, upload-time = "2026-04-22T19:15:57.26Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/46/22c8af5eed68ac2eeb556e0fa3ca8a7b798e984ceff4450888f3b5ac61fd/polars_runtime_32-1.40.1-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:b748ef652270cc49e9e69f99a035e0eb4d5f856d42bcd6ac4d9d80a40142aa1e", size = 52098755, upload-time = "2026-04-22T19:14:28.555Z" }, + { url = "https://files.pythonhosted.org/packages/c6/3e/48599a38009ca60ff82a6f38c8a621ce3c0286aa7397c7d79e741bd9060e/polars_runtime_32-1.40.1-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:d249b3743e05986060cec0a7aaa542d020df6c6b876e556023a310efd581f9be", size = 46367542, upload-time = "2026-04-22T19:14:32.433Z" }, + { url = "https://files.pythonhosted.org/packages/43/e9/384bc069367a1a36ee31c13782c178dbd039b2b873b772d4a0fc23a2373d/polars_runtime_32-1.40.1-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5987b30e7aa1059d069498496e8dda35afd592b0ac3d46ed87e3ff8df1ad652c", size = 50252104, upload-time = "2026-04-22T19:14:35.945Z" }, + { url = "https://files.pythonhosted.org/packages/15/ef/7d57ceb0651af74194e97ed6583e148d352f03d696090221b8059cdfc90b/polars_runtime_32-1.40.1-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d7f42a8b3f16fc66002cc0f6516f7dd7653396886ae0ed362ab95c0b3408b59", size = 56250788, upload-time = "2026-04-22T19:14:39.743Z" }, + { url = "https://files.pythonhosted.org/packages/10/0f/e4b3ffc748827a14a474ec9c42e45c066050e440fec57e914091d9adda75/polars_runtime_32-1.40.1-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e5f7becc237a7ec9d9a10878dc8e54b73bbf4e2d94a2991c37d7a0b38590d8f9", size = 50432590, upload-time = "2026-04-22T19:14:43.388Z" }, + { url = "https://files.pythonhosted.org/packages/d9/0b/b8d95fbed869fa4caabe9c400e4210374913b376e925e96fdcfa9be6416b/polars_runtime_32-1.40.1-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:992d14cf191dde043d36fbdbc98a65e43fbc7e9a5024cecd45f838ac4988c1ee", size = 54155564, upload-time = "2026-04-22T19:14:47.239Z" }, + { url = "https://files.pythonhosted.org/packages/06/d9/d091d8fb5cbed5e9536adfed955c4c89987a4cc3b8e73ae4532402b91c74/polars_runtime_32-1.40.1-cp310-abi3-win_amd64.whl", hash = "sha256:f78bb2abd00101cbb23cc0cb068f7e36e081057a15d2ec2dde3dda280709f030", size = 51829755, upload-time = "2026-04-22T19:14:50.85Z" }, + { url = "https://files.pythonhosted.org/packages/65/ad/b33c3022a394f3eb55c3310597cec615412a8a33880055eee191d154a628/polars_runtime_32-1.40.1-cp310-abi3-win_arm64.whl", hash = "sha256:b5cbfaf6b085b420b4bfcbe24e8f665076d1cccfdb80c0484c02a023ce205537", size = 45822104, upload-time = "2026-04-22T19:14:54.192Z" }, +] + [[package]] name = "pycparser" version = "3.0" From 289207fb409322ab112a4137ec2c8939ec768b3c Mon Sep 17 00:00:00 2001 From: David <3dgiordano@gmail.com> Date: Tue, 28 Apr 2026 21:28:58 -0300 Subject: [PATCH 2/5] Use simple id for dataframes --- tools/async_task_manager.py | 38 +++++++-------- tools/dataframe_manager.py | 97 ++++++++++++++++++++++--------------- tools/utils.py | 19 ++++++-- 3 files changed, 91 insertions(+), 63 deletions(-) diff --git a/tools/async_task_manager.py b/tools/async_task_manager.py index c89c4a2..d164266 100644 --- a/tools/async_task_manager.py +++ b/tools/async_task_manager.py @@ -16,7 +16,6 @@ import asyncio import logging -import secrets import time from dataclasses import dataclass from datetime import datetime @@ -24,6 +23,10 @@ from models.result import BaseResult from tools.dataframe_manager import materialize_large_result_if_needed +from tools.utils import generate_simple_id, SIMPLE_ID_ALPHABET, SIMPLE_ID_LENGTH, normalize_simple_id + +# Crockford-like base32 alphabet used by generate_simple_id / task ids (tests assert against this). +TASK_ID_ALPHABET = SIMPLE_ID_ALPHABET STATUS_WORKING = "working" STATUS_PARKING = "parking" @@ -35,8 +38,7 @@ TERMINAL_STATES = {STATUS_COMPLETED, STATUS_FAILED, STATUS_CANCELLED} ACTIVE_STATES = {STATUS_PARKING, STATUS_WORKING, STATUS_INPUT_REQUIRED} MAX_PARALLEL_TASKS = 10 -TASK_ID_ALPHABET = "0123456789abcdefghjkmnpqrstvwxyz" -TASK_ID_LENGTH = 8 + TASK_ID_MAX_ATTEMPTS = 10 logger = logging.getLogger(__name__) @@ -94,7 +96,6 @@ def set_status(self, status: str, status_message: str): _tasks: Dict[str, TaskRecord] = {} - def _to_iso(timestamp: float) -> str: return datetime.fromtimestamp(timestamp).isoformat() @@ -105,12 +106,8 @@ def _normalize_result(result: Any) -> BaseResult: return BaseResult(result=[result]) -def _normalize_task_id(task_id: str) -> str: - return str(task_id).strip().lower() - - def _generate_task_id() -> str: - return "".join(secrets.choice(TASK_ID_ALPHABET) for _ in range(TASK_ID_LENGTH)) + return generate_simple_id() def _allocate_task_id() -> str: @@ -120,13 +117,14 @@ def _allocate_task_id() -> str: return candidate logger.error( - "Unable to allocate unique 8-char task id after 10 attempts. " - "attempts=%s id_length=%s alphabet=crockford32 active_pool_size=%s", + "Unable to allocate task id. attempts=%s id_length=%s alphabet=crockford32 active_pool_size=%s", TASK_ID_MAX_ATTEMPTS, - TASK_ID_LENGTH, + SIMPLE_ID_LENGTH, len(_tasks), ) - raise RuntimeError("Unable to allocate unique 8-char task id after 10 attempts.") + raise RuntimeError( + f"Unable to allocate unique {SIMPLE_ID_LENGTH}-char task id after {TASK_ID_MAX_ATTEMPTS} attempts." + ) async def _task_runner(task_record: TaskRecord, coro_factory: Callable[[], Awaitable[Any]]): @@ -171,9 +169,9 @@ async def _task_runner(task_record: TaskRecord, coro_factory: Callable[[], Await def submit_task( - action: Dict[str, Any], - coro_factory: Callable[[], Awaitable[Any]], - time_to_live_ms: Optional[int] = None + action: Dict[str, Any], + coro_factory: Callable[[], Awaitable[Any]], + time_to_live_ms: Optional[int] = None ) -> str: now = time.time() task_id = _allocate_task_id() @@ -194,13 +192,11 @@ def submit_task( def get_task_record(task_id: str) -> Optional[TaskRecord]: - normalized_task_id = _normalize_task_id(task_id) - return _tasks.get(normalized_task_id) + return _tasks.get(normalize_simple_id(task_id)) def remove_task(task_id: str) -> bool: - normalized_task_id = _normalize_task_id(task_id) - return _tasks.pop(normalized_task_id, None) is not None + return _tasks.pop(normalize_simple_id(task_id), None) is not None def task_snapshot(task_record: TaskRecord, include_result: bool = False) -> Dict[str, Any]: @@ -241,7 +237,7 @@ def is_active_status(status: str) -> bool: def cancel_task(task_id: str) -> Optional[TaskRecord]: - normalized_task_id = _normalize_task_id(task_id) + normalized_task_id = normalize_simple_id(task_id) task_record = _tasks.get(normalized_task_id) if not task_record: return None diff --git a/tools/dataframe_manager.py b/tools/dataframe_manager.py index c781abe..4d23106 100644 --- a/tools/dataframe_manager.py +++ b/tools/dataframe_manager.py @@ -16,17 +16,24 @@ import asyncio import hashlib import json +import logging import re import uuid from dataclasses import dataclass, asdict -from datetime import datetime +from datetime import datetime, UTC from typing import Any, Dict, List, Optional import polars as pl + from models.result import BaseResult +from tools.utils import generate_simple_id, SIMPLE_ID_LENGTH +logger = logging.getLogger(__name__) DATAFRAME_JSON_SIZE_THRESHOLD = 8000 + +DATAFRAME_ID_MAX_ATTEMPTS = 10 + _dataframes: Dict[str, "DataFrameRecord"] = {} _write_lock = asyncio.Lock() _sql_context = pl.SQLContext() @@ -85,23 +92,23 @@ def build_dataframe_from_result(result: List[Any]) -> pl.DataFrame: # matrix envelope: [{"columns":[...], "rows":[...]}] if ( - isinstance(normalized, list) - and len(normalized) == 1 - and isinstance(normalized[0], dict) - and set(normalized[0].keys()) == {"columns", "rows"} - and isinstance(normalized[0]["columns"], list) - and isinstance(normalized[0]["rows"], list) + isinstance(normalized, list) + and len(normalized) == 1 + and isinstance(normalized[0], dict) + and set(normalized[0].keys()) == {"columns", "rows"} + and isinstance(normalized[0]["columns"], list) + and isinstance(normalized[0]["rows"], list) ): matrix = normalized[0] return pl.DataFrame(matrix["rows"], schema=[str(c) for c in matrix["columns"]], orient="row") # columnar envelope: [{"colA":[...], "colB":[...]}] if ( - isinstance(normalized, list) - and len(normalized) == 1 - and isinstance(normalized[0], dict) - and normalized[0] - and all(isinstance(v, list) for v in normalized[0].values()) + isinstance(normalized, list) + and len(normalized) == 1 + and isinstance(normalized[0], dict) + and normalized[0] + and all(isinstance(v, list) for v in normalized[0].values()) ): col_lengths = {len(v) for v in normalized[0].values()} if len(col_lengths) == 1: @@ -119,9 +126,9 @@ def build_dataframe_from_result(result: List[Any]) -> pl.DataFrame: def auto_flatten_wide( - df: pl.DataFrame, - max_passes: int = 30, - sep: str = "__", + df: pl.DataFrame, + max_passes: int = 30, + sep: str = "__", ) -> pl.DataFrame: """ Flatten nested structures in a DataFrame for SQL queryability. @@ -211,11 +218,11 @@ def _canonicalize_top_schema(schema_rows: List[Dict[str, str]]) -> List[Dict[str async def register_dataframe( - result: List[Any], - origin_manager: str, - origin_action: str, - json_size_chars: int, - flatten: bool = True, + result: List[Any], + origin_manager: str, + origin_action: str, + json_size_chars: int, + flatten: bool = True, ) -> Dict[str, Any]: dataframe = build_dataframe_from_result(result) return await _register_dataframe_instance( @@ -223,24 +230,39 @@ async def register_dataframe( ) +def _allocate_dataframe_id() -> str: + for _ in range(DATAFRAME_ID_MAX_ATTEMPTS): + candidate = generate_simple_id() + if candidate not in _dataframes: + return candidate + + logger.error( + "Unable to allocate dataframe id. attempts=%s id_length=%s active_pool_size=%s", + DATAFRAME_ID_MAX_ATTEMPTS, + SIMPLE_ID_LENGTH, + len(_dataframes), + ) + raise RuntimeError(f"Unable to allocate dataframe id after {DATAFRAME_ID_MAX_ATTEMPTS} attempts.") + + async def _register_dataframe_instance( - dataframe: pl.DataFrame, - origin_manager: str, - origin_action: str, - json_size_chars: int, - flatten: bool = True, + dataframe: pl.DataFrame, + origin_manager: str, + origin_action: str, + json_size_chars: int, + flatten: bool = True, ) -> Dict[str, Any]: if flatten: try: dataframe = auto_flatten_wide(dataframe) except Exception: pass # Keep original dataframe if flattening fails - dataframe_id = str(uuid.uuid4()) - table_name = f"df_{dataframe_id.replace('-', '_')}" + dataframe_id = _allocate_dataframe_id() + table_name = f"df_{dataframe_id}" record = DataFrameRecord( dataframe_id=dataframe_id, table_name=table_name, - created_at=datetime.utcnow().isoformat(), + created_at=datetime.now(UTC).isoformat(), origin_manager=origin_manager, origin_action=origin_action, rows=dataframe.height, @@ -257,19 +279,19 @@ async def _register_dataframe_instance( async def materialize_large_result_if_needed( - base_result: BaseResult, - origin_manager: str, - origin_action: str, - force: bool = False + base_result: BaseResult, + origin_manager: str, + origin_action: str, + force: bool = False ) -> BaseResult: if not isinstance(base_result, BaseResult) or base_result.error or base_result.result is None: return base_result if ( - isinstance(base_result.result, list) - and len(base_result.result) == 1 - and isinstance(base_result.result[0], dict) - and base_result.result[0].get("stored_as_dataframe") is True - and base_result.result[0].get("dataframe_id") + isinstance(base_result.result, list) + and len(base_result.result) == 1 + and isinstance(base_result.result[0], dict) + and base_result.result[0].get("stored_as_dataframe") is True + and base_result.result[0].get("dataframe_id") ): # Avoid rematerializing a payload that is already a dataframe reference. return base_result @@ -774,4 +796,3 @@ def get_sql_capabilities() -> Dict[str, Any]: "https://docs.pola.rs/py-polars/html/reference/sql/set_operations.html" ], } - diff --git a/tools/utils.py b/tools/utils.py index a391679..a312eb3 100644 --- a/tools/utils.py +++ b/tools/utils.py @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. """ +import secrets + """ Simple utilities for BlazeMeter MCP tools. """ @@ -42,10 +44,6 @@ from config.token import BzmToken from config.version import __version__ from models.result import BaseResult, HttpBaseResult, ToolResult -from tools.async_task_manager import submit_task, get_task_record, remove_task, task_snapshot -from tools.dataframe_manager import ( - materialize_large_result_if_needed, -) so = platform.system() # "Windows", "Linux", "Darwin" version = platform.version() # kernel / build version @@ -77,6 +75,14 @@ r")/[^\n\r\t\"']+" ) +SIMPLE_ID_ALPHABET = "0123456789abcdefghjkmnpqrstvwxyz" +SIMPLE_ID_LENGTH = 8 + +def generate_simple_id() -> str: + return "".join(secrets.choice(SIMPLE_ID_ALPHABET) for _ in range(SIMPLE_ID_LENGTH)) + +def normalize_simple_id(simple_id: str) -> str: + return str(simple_id).strip().lower() def sanitize_path(path_value: str) -> str: if not path_value: @@ -717,6 +723,9 @@ async def execute_with_task_management( time_to_live_ms: Optional[int] = None, fast_response_threshold_seconds: float = 5.0 ) -> BaseResult: + # Deferred import avoids circular dependency: utils → async_task_manager → dataframe_manager → utils. + from tools.async_task_manager import submit_task, get_task_record, remove_task, task_snapshot + wait_started = time.monotonic() try: task_id = submit_task(action_payload, coro_factory, time_to_live_ms=time_to_live_ms) @@ -900,6 +909,8 @@ def _to_tool_result(value: Any) -> ToolResult | CallToolResult: if result_format == "raw": final_result = result else: + from tools.dataframe_manager import materialize_large_result_if_needed + final_result = await materialize_large_result_if_needed( base_result=result, origin_manager=func.__name__, From 32e2f4d8fc3028d2e4a35492afc83f37d5f25baa Mon Sep 17 00:00:00 2001 From: David <3dgiordano@gmail.com> Date: Wed, 29 Apr 2026 11:24:50 -0300 Subject: [PATCH 3/5] Improve dataframe sql help with more hints --- tools/dataframe_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/dataframe_manager.py b/tools/dataframe_manager.py index 4d23106..8823ee9 100644 --- a/tools/dataframe_manager.py +++ b/tools/dataframe_manager.py @@ -643,6 +643,7 @@ def get_sql_capabilities() -> Dict[str, Any]: "unsupported_functions": [ {"name": "STRUCT_EXTRACT", "reason": "Not recognized in this SQL context."}, {"name": "TO_JSON", "reason": "Not recognized in this SQL context."}, + {"name": "TYPEOF", "reason": "Not recognized in this SQL context."}, ], "unsupported_or_unstable_patterns": [ "Complex chained nested access with mixed subscript and dot notation in a single expression", @@ -712,6 +713,7 @@ def get_sql_capabilities() -> Dict[str, Any]: "Validate each CTE with a small LIMIT before composing final query", "For multi-dataframe analysis, use schema groups first, then perform targeted per-dataframe inspection only when needed.", "To get the maximum value between a scalar field and all values in a nested list per record, use UNNEST on the list, then GROUP BY and GREATEST(MAX(list.field), MAX(scalar)).", + "Before UNION ALL, normalize each branch to the same concrete type your next step expects (e.g. INTEGER year, not “string then parse after union”).", ], "known_engine_pitfalls": [ "CTE + JOIN resolution may treat same-name keys as ambiguous even when aliases are present; rename join keys in the UNNEST stage (base_*/src_*) to guarantee deterministic resolution", From cb60529849145a52edf92689c4d8eb1db6b94f79 Mon Sep 17 00:00:00 2001 From: David <3dgiordano@gmail.com> Date: Wed, 29 Apr 2026 16:19:56 -0300 Subject: [PATCH 4/5] Improve dataframe sql help with more hints --- tools/dataframe_manager.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tools/dataframe_manager.py b/tools/dataframe_manager.py index 8823ee9..92b2848 100644 --- a/tools/dataframe_manager.py +++ b/tools/dataframe_manager.py @@ -644,6 +644,7 @@ def get_sql_capabilities() -> Dict[str, Any]: {"name": "STRUCT_EXTRACT", "reason": "Not recognized in this SQL context."}, {"name": "TO_JSON", "reason": "Not recognized in this SQL context."}, {"name": "TYPEOF", "reason": "Not recognized in this SQL context."}, + {"name": "generate_series", "reason": "Not recognized in this SQL context."}, ], "unsupported_or_unstable_patterns": [ "Complex chained nested access with mixed subscript and dot notation in a single expression", @@ -665,6 +666,7 @@ def get_sql_capabilities() -> Dict[str, Any]: "Try-fast: attempting the simplest path first and retrying on failure instead of reasoning through the design before executing", "Not considering all values in a nested list when searching for max/min, which can miss important extreme values", "Using only the first element of a nested list instead of aggregating over all its values", + "Using ANSI date literals (DATE '2026-03-30') inside VALUES clause", ], "query_rules": [ "CRITICAL: Before writing queries that combine 2 or more dataframes, run dataframes_schema_groups first to validate schema compatibility across all involved dataframes.", @@ -683,6 +685,9 @@ def get_sql_capabilities() -> Dict[str, Any]: "Multi-dataframe nested flow: dataframes_sql_help -> dataframes_schema_groups -> targeted dataframes_get -> staged CTE (UNNEST -> aggregate -> join-back) -> final query.", "If schema groups returns a CRITICAL variation warning, call dataframes_sql_help again immediately before writing the final query.", "Direct nested access is allowed only when each required nested column has exactly one variation across all relevant dataframes in schema groups.", + "For date literals in VALUES → always use: CAST('YYYY-MM-DD' AS DATE)", + "DATE('YYYY-MM-DD') is also supported and often cleaner", + "Never use: DATE '2026-03-30' inside VALUES", ], "nested_unnest_intro": ( "To query and aggregate data from a list of structs (e.g., override_executions), use UNNEST in a CTE to flatten the list, " @@ -722,6 +727,9 @@ def get_sql_capabilities() -> Dict[str, Any]: "Nested schema drift across tables can break field resolution", "UNION over nested struct/list columns is fragile; normalize to scalar output first", "Direct list aggregation over nested overrides is brittle; UNNEST + MAX + join-back is more reliable", + "VALUES clause is very strict: does not accept DATE '2026-03-30' literal. Must use CAST('2026-03-30' AS DATE) or DATE('2026-03-30')", + "Temporal literals inside VALUES frequently cause 'expects literals' errors", + "CAST to DATE/DATETIME is more reliable than the typed literal syntax in Polars SQL" ], "nested_query_recipe": [ "Base table CTE", From b0850ab6067b1d0935594df53bb9def7ad0430c6 Mon Sep 17 00:00:00 2001 From: David <3dgiordano@gmail.com> Date: Wed, 29 Apr 2026 16:53:00 -0300 Subject: [PATCH 5/5] Improve dataframe sql help with more hints --- tools/dataframe_manager.py | 43 +++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/tools/dataframe_manager.py b/tools/dataframe_manager.py index 92b2848..d07883b 100644 --- a/tools/dataframe_manager.py +++ b/tools/dataframe_manager.py @@ -18,7 +18,6 @@ import json import logging import re -import uuid from dataclasses import dataclass, asdict from datetime import datetime, UTC from typing import Any, Dict, List, Optional @@ -641,10 +640,44 @@ def get_sql_capabilities() -> Dict[str, Any]: "SUM", "TAN", "TAND", "UNNEST", "UPPER", "VARIANCE" ], "unsupported_functions": [ - {"name": "STRUCT_EXTRACT", "reason": "Not recognized in this SQL context."}, - {"name": "TO_JSON", "reason": "Not recognized in this SQL context."}, - {"name": "TYPEOF", "reason": "Not recognized in this SQL context."}, - {"name": "generate_series", "reason": "Not recognized in this SQL context."}, + "GENERATE_SERIES", + "STRING_AGG", + "GROUP_CONCAT", + "LISTAGG", + "PERCENTILE_CONT", + "PERCENTILE_DISC", + "NTILE", + "CUME_DIST", + "PERCENT_RANK", + "WIDTH_BUCKET", + "JSON_EXTRACT", + "JSON_EXTRACT_PATH", + "JSON_EXTRACT_STRING", + "TO_JSON", + "TYPEOF", + "STRUCT_EXTRACT", + "MONTHS_BETWEEN", + "MODE", + "ROLLUP", + "CUBE", + "GROUPING SETS", + ], + "limited_or_unstable_functions": [ + {"name": "LAG", + "reason": "Unstable, especially with complex CTEs or multiple windows. Avoid when possible."}, + {"name": "LEAD", + "reason": "Unstable, especially with complex CTEs or multiple windows. Avoid when possible."}, + {"name": "DATE_TRUNC", + "reason": "Partial and inconsistent support. Better to use DATE_PART combined with CAST or manual date arithmetic."}, + {"name": "RANK", + "reason": "Partial support through window functions. Results may differ from PostgreSQL/BigQuery."}, + {"name": "DENSE_RANK", + "reason": "Partial support through window functions. Results may differ from PostgreSQL/BigQuery."}, + {"name": "ROW_NUMBER", + "reason": "Works in simple cases but can be unstable with complex queries or multiple CTEs."}, + {"name": "FIRST_VALUE", "reason": "Limited support as window function."}, + {"name": "LAST_VALUE", "reason": "Limited support as window function."}, + {"name": "NTH_VALUE", "reason": "Very limited and unstable support."} ], "unsupported_or_unstable_patterns": [ "Complex chained nested access with mixed subscript and dot notation in a single expression",