Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 103 additions & 5 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import inspect
import json
import os
import traceback
import typing as T
import uuid
from collections.abc import Sequence
from collections.abc import Set as AbstractSet

import mcp

Expand All @@ -26,6 +29,7 @@
SEND_MESSAGE_TO_USER_TOOL,
)
from astrbot.core.cron.events import CronMessageEvent
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
CommandResult,
MessageChain,
Expand All @@ -35,9 +39,98 @@
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.register import llm_tools
from astrbot.core.utils.history_saver import persist_agent_history
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings


class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
_ALLOWED_IMAGE_EXTENSIONS = {
".png",
".jpg",
".jpeg",
".gif",
".webp",
".bmp",
".tif",
".tiff",
".svg",
".heic",
}

@classmethod
def _is_supported_image_ref(cls, image_ref: str) -> bool:
if not image_ref:
return False
lowered = image_ref.lower()
if lowered.startswith(("http://", "https://", "base64://")):
return True
file_path = image_ref[8:] if lowered.startswith("file:///") else image_ref
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
ext = os.path.splitext(file_path)[1].lower()
return ext in cls._ALLOWED_IMAGE_EXTENSIONS

@classmethod
def _coerce_image_urls(cls, image_urls: T.Any) -> list[T.Any]:
if image_urls is None:
return []
if isinstance(image_urls, str):
return [image_urls]
if isinstance(image_urls, (Sequence, AbstractSet)) and not isinstance(
image_urls, (str, bytes, bytearray)
):
return list(image_urls)
logger.warning(
"Unsupported image_urls type in handoff tool args: %s",
type(image_urls).__name__,
)
return []

@classmethod
def _filter_supported_image_urls(cls, candidates: list[T.Any]) -> list[str]:
normalized = normalize_and_dedupe_strings(candidates)
sanitized = [item for item in normalized if cls._is_supported_image_ref(item)]
dropped_count = len(normalized) - len(sanitized)
if dropped_count > 0:
logger.warning(
"Dropped %d invalid image_urls entries in handoff tool args.",
dropped_count,
)
return sanitized

@classmethod
async def _iter_event_image_paths(
cls, run_context: ContextWrapper[AstrAgentContext]
) -> list[str]:
paths: list[str] = []
event = getattr(run_context.context, "event", None)
message_obj = getattr(event, "message_obj", None)
message = getattr(message_obj, "message", None)
if message:
for idx, component in enumerate(message):
if not isinstance(component, Image):
continue
try:
path = await component.convert_to_file_path()
if path and cls._is_supported_image_ref(path):
paths.append(path)
except Exception as e:
logger.error(
"Failed to convert handoff image component at index %d: %s",
idx,
e,
exc_info=True,
)
return paths

@classmethod
async def _prepare_handoff_image_urls(
cls,
run_context: ContextWrapper[AstrAgentContext],
image_urls: T.Any,
) -> list[str]:
candidates = cls._coerce_image_urls(image_urls)
event_paths = await cls._iter_event_image_paths(run_context)
candidates.extend(event_paths)
return cls._filter_supported_image_urls(candidates)

@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Expand All @@ -58,7 +151,7 @@ async def execute(cls, tool, run_context, **tool_args):
):
yield r
return
async for r in cls._execute_handoff(tool, run_context, **tool_args):
async for r in cls._execute_handoff(tool, run_context, tool_args):
yield r
return

Expand Down Expand Up @@ -161,10 +254,14 @@ async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
tool_args: dict[str, T.Any],
):
input_ = tool_args.get("input")
image_urls = tool_args.get("image_urls")
image_urls = await cls._prepare_handoff_image_urls(
run_context,
tool_args.get("image_urls"),
)
tool_args["image_urls"] = image_urls

# Build handoff toolset from registered tools plus runtime computer tools.
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
Expand Down Expand Up @@ -263,8 +360,9 @@ async def _do_handoff_background(
) -> None:
"""Run the subagent handoff and, on completion, wake the main agent."""
result_text = ""
prepared_tool_args = dict(tool_args)
try:
async for r in cls._execute_handoff(tool, run_context, **tool_args):
async for r in cls._execute_handoff(tool, run_context, prepared_tool_args):
if isinstance(r, mcp.types.CallToolResult):
for content in r.content:
if isinstance(content, mcp.types.TextContent):
Expand All @@ -281,7 +379,7 @@ async def _do_handoff_background(
task_id=task_id,
tool_name=tool.name,
result_text=result_text,
tool_args=tool_args,
tool_args=prepared_tool_args,
note=(
event.get_extra("background_note")
or f"Background task for subagent '{tool.agent.name}' finished."
Expand Down
122 changes: 122 additions & 0 deletions tests/unit/test_astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from types import SimpleNamespace

import mcp
import pytest

from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
from astrbot.core.message.components import Image


class _DummyEvent:
def __init__(self, message_components: list[object] | None = None) -> None:
self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session"
self.message_obj = SimpleNamespace(message=message_components or [])

def get_extra(self, _key: str):
return None


class _DummyTool:
def __init__(self) -> None:
self.name = "transfer_to_subagent"
self.agent = SimpleNamespace(name="subagent")


def _build_run_context(message_components: list[object] | None = None):
event = _DummyEvent(message_components=message_components)
ctx = SimpleNamespace(event=event, context=SimpleNamespace())
return ContextWrapper(context=ctx)


@pytest.mark.asyncio
async def test_prepare_handoff_image_urls_normalizes_filters_and_appends_event_image(
monkeypatch: pytest.MonkeyPatch,
):
async def _fake_convert_to_file_path(self):
return "/tmp/event_image.png"

monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): 建议增加一个当 convert_to_file_path 抛出异常时的错误处理测试用例。

目前的测试只覆盖了转换成功的路径。请新增一个测试,通过 monkeypatch 让 convert_to_file_path 抛出异常(例如 RuntimeError),并验证 _prepare_handoff_image_urls

  • 本身不会抛出异常;并且
  • 会跳过将失败的图片添加到 image_urls 中。

这样可以锁定在图片转换场景下预期的容错行为。

Original comment in English

suggestion (testing): Consider a test case for error handling when convert_to_file_path raises.

Right now the tests only cover the successful conversion path. Please add a test that monkeypatches convert_to_file_path to raise (e.g., RuntimeError) and verifies that _prepare_handoff_image_urls:

  • Does not raise, and
  • Skips adding the failed image to image_urls.

This will lock in the intended failure-tolerant behavior around image conversion.


run_context = _build_run_context([Image(file="file:///tmp/original.png")])
image_urls_input = (
" https://example.com/a.png ",
"/tmp/not_an_image.txt",
"/tmp/local.webp",
123,
)

image_urls = await FunctionToolExecutor._prepare_handoff_image_urls(
run_context,
image_urls_input,
)

assert image_urls == [
"https://example.com/a.png",
"/tmp/local.webp",
"/tmp/event_image.png",
]


@pytest.mark.asyncio
async def test_prepare_handoff_image_urls_skips_failed_event_image_conversion(
monkeypatch: pytest.MonkeyPatch,
):
async def _fake_convert_to_file_path(self):
raise RuntimeError("boom")

monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path)

run_context = _build_run_context([Image(file="file:///tmp/original.png")])
image_urls = await FunctionToolExecutor._prepare_handoff_image_urls(
run_context,
["https://example.com/a.png"],
)

assert image_urls == ["https://example.com/a.png"]


@pytest.mark.asyncio
async def test_do_handoff_background_reports_prepared_image_urls(
monkeypatch: pytest.MonkeyPatch,
):
captured: dict = {}

async def _unexpected_prepare(cls, run_context, image_urls):
raise AssertionError("background path should not pre-prepare image urls")

async def _fake_execute_handoff(cls, tool, run_context, tool_args):
tool_args["image_urls"] = ["prepared://image.png"]
yield mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text="ok")]
)

async def _fake_wake(cls, run_context, **kwargs):
captured.update(kwargs)

monkeypatch.setattr(
FunctionToolExecutor,
"_prepare_handoff_image_urls",
classmethod(_unexpected_prepare),
)
monkeypatch.setattr(
FunctionToolExecutor,
"_execute_handoff",
classmethod(_fake_execute_handoff),
)
monkeypatch.setattr(
FunctionToolExecutor,
"_wake_main_agent_for_background_result",
classmethod(_fake_wake),
)

run_context = _build_run_context()
await FunctionToolExecutor._do_handoff_background(
tool=_DummyTool(),
run_context=run_context,
task_id="task-id",
input="hello",
image_urls="https://example.com/raw.png",
)

assert captured["tool_args"]["image_urls"] == ["prepared://image.png"]