Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 113 additions & 4 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import inspect
import json
import os
import traceback
import typing as T
import uuid
from collections.abc import Sequence
from collections.abc import Set as AbstractSet

import mcp

Expand All @@ -26,6 +29,7 @@
SEND_MESSAGE_TO_USER_TOOL,
)
from astrbot.core.cron.events import CronMessageEvent
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
CommandResult,
MessageChain,
Expand All @@ -35,9 +39,110 @@
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.register import llm_tools
from astrbot.core.utils.history_saver import persist_agent_history
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings


class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
# Lower-case file extensions (leading dot included) that
# `_is_supported_image_ref` accepts for local paths and file:/// URIs.
# The lookup there lower-cases the candidate extension first, so every
# entry here must stay lower-case.
_ALLOWED_IMAGE_EXTENSIONS = {
".png",
".jpg",
".jpeg",
".gif",
".webp",
".bmp",
".tif",
".tiff",
".svg",
".heic",
}

@classmethod
def _is_supported_image_ref(cls, image_ref: str) -> bool:
if not image_ref:
return False
lowered = image_ref.lower()
if lowered.startswith(("http://", "https://", "base64://")):
return True
file_path = image_ref[8:] if lowered.startswith("file:///") else image_ref
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
ext = os.path.splitext(file_path)[1].lower()
if ext in cls._ALLOWED_IMAGE_EXTENSIONS:
return True
# Keep support for extension-less temp files returned by image converters.
return ext == "" and os.path.exists(file_path)

@classmethod
def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]:
candidates: list[str] = []
if image_urls_raw is None:
pass
elif isinstance(image_urls_raw, str):
candidates.append(image_urls_raw)
elif isinstance(image_urls_raw, (Sequence, AbstractSet)) and not isinstance(
image_urls_raw, (str, bytes, bytearray)
):
non_string_count = 0
for item in image_urls_raw:
if isinstance(item, str):
candidates.append(item)
else:
non_string_count += 1
if non_string_count > 0:
logger.warning(
"Dropped %d non-string image_urls entries in handoff tool args.",
non_string_count,
)
else:
logger.warning(
"Unsupported image_urls type in handoff tool args: %s",
type(image_urls_raw).__name__,
)
return candidates

@classmethod
async def _collect_image_urls_from_message(
    cls, run_context: ContextWrapper[AstrAgentContext]
) -> list[str]:
    """Convert ``Image`` components of the current event message to file paths.

    The message is looked up as ``run_context.context.event.message_obj.message``
    with every hop tolerated as missing.

    Args:
        run_context: Wrapper whose ``context`` may expose the triggering event.

    Returns:
        The truthy paths returned by ``convert_to_file_path`` for each
        ``Image`` component, in message order. Components whose conversion
        raises are logged and skipped.
    """
    event = getattr(run_context.context, "event", None)
    message_obj = getattr(event, "message_obj", None)
    components = getattr(message_obj, "message", None)

    collected: list[str] = []
    if not components:
        return collected

    for position, part in enumerate(components):
        if isinstance(part, Image):
            try:
                file_path = await part.convert_to_file_path()
            except Exception as exc:
                # Best effort: one broken attachment must not abort the handoff.
                logger.error(
                    "Failed to convert handoff image component at index %d: %s",
                    position,
                    exc,
                    exc_info=True,
                )
            else:
                if file_path:
                    collected.append(file_path)
    return collected

@classmethod
async def _collect_handoff_image_urls(
    cls,
    run_context: ContextWrapper[AstrAgentContext],
    image_urls_raw: T.Any,
) -> list[str]:
    """Build the final image reference list for a subagent handoff.

    References from the tool arguments come first, followed by those
    derived from the current event message. The combined list is passed
    through ``normalize_and_dedupe_strings`` and then filtered with
    ``_is_supported_image_ref``.

    Args:
        run_context: Wrapper giving access to the triggering event.
        image_urls_raw: Raw ``image_urls`` value from the tool arguments.

    Returns:
        The references that survive normalization and filtering.
    """
    from_args = cls._collect_image_urls_from_args(image_urls_raw)
    from_message = await cls._collect_image_urls_from_message(run_context)

    unique_refs = normalize_and_dedupe_strings([*from_args, *from_message])
    supported = list(filter(cls._is_supported_image_ref, unique_refs))

    rejected = len(unique_refs) - len(supported)
    if rejected > 0:
        logger.warning(
            "Dropped %d invalid image_urls entries in handoff tool args.",
            rejected,
        )
    return supported

@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Expand All @@ -58,7 +163,7 @@ async def execute(cls, tool, run_context, **tool_args):
):
yield r
return
async for r in cls._execute_handoff(tool, run_context, **tool_args):
async for r in cls._execute_handoff(tool, run_context, tool_args):
yield r
return

Expand Down Expand Up @@ -161,10 +266,14 @@ async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
tool_args: dict[str, T.Any],
):
input_ = tool_args.get("input")
image_urls = tool_args.get("image_urls")
image_urls = await cls._collect_handoff_image_urls(
run_context,
tool_args.get("image_urls"),
)
tool_args["image_urls"] = image_urls

# Build handoff toolset from registered tools plus runtime computer tools.
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
Expand Down Expand Up @@ -264,7 +373,7 @@ async def _do_handoff_background(
"""Run the subagent handoff and, on completion, wake the main agent."""
result_text = ""
try:
async for r in cls._execute_handoff(tool, run_context, **tool_args):
async for r in cls._execute_handoff(tool, run_context, tool_args):
if isinstance(r, mcp.types.CallToolResult):
for content in r.content:
if isinstance(content, mcp.types.TextContent):
Expand Down
135 changes: 135 additions & 0 deletions tests/unit/test_astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from types import SimpleNamespace

import mcp
import pytest

from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
from astrbot.core.message.components import Image


class _DummyEvent:
def __init__(self, message_components: list[object] | None = None) -> None:
self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session"
self.message_obj = SimpleNamespace(message=message_components or [])

def get_extra(self, _key: str):
return None


class _DummyTool:
    """Bare handoff-tool stand-in carrying a tool name and an agent name."""

    def __init__(self) -> None:
        self.agent = SimpleNamespace(name="subagent")
        self.name = "transfer_to_subagent"


def _build_run_context(message_components: list[object] | None = None):
    """Wrap a dummy event carrying *message_components* in a ``ContextWrapper``."""
    agent_context = SimpleNamespace(
        event=_DummyEvent(message_components=message_components),
        context=SimpleNamespace(),
    )
    return ContextWrapper(context=agent_context)


@pytest.mark.asyncio
async def test_collect_handoff_image_urls_normalizes_filters_and_appends_event_image(
Comment on lines +32 to +33
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): 增加对 base64:// 和 file:/// 图像引用以及其他受支持扩展名的测试,以覆盖 _is_supported_image_ref 的行为。

当前测试覆盖了 HTTP URL、一个有效的图像扩展名、一个非图像扩展名以及无扩展名的事件文件。由于 _is_supported_image_ref 还支持 base64:// URI、file:///... 路径和其他扩展名(如 .heic 和 .svg),请添加一个参数化测试,将这些引用的混合输入传给 _collect_handoff_image_urls(或 _collect_image_urls_from_args),并断言只返回受支持的引用。这样可以让测试与 _is_supported_image_ref 的行为更加一致,并防止在处理不同图像 scheme 和扩展名时出现回归。

建议实现:

def _build_run_context(message_components: list[object] | None = None):
    event = _DummyEvent(message_components=message_components)
    ctx = SimpleNamespace(event=event, context=SimpleNamespace())
    return ContextWrapper(context=ctx)


@pytest.mark.asyncio
async def test_collect_handoff_image_urls_normalizes_filters_and_appends_event_image(
    monkeypatch: pytest.MonkeyPatch,
):
    async def _fake_convert_to_file_path(self):
        return "/tmp/event_image.png"

    monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path)

    run_context = _build_run_context([Image(file="file:///tmp/original.png")])
    image_urls_input = (
        " https://example.com/a.png ",
        "/tmp/not_an_image.txt",
        "/tmp/local.webp",
    )


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "image_refs, expected_supported_refs",
    [
        pytest.param(
            (
                # supported HTTP(S) URL
                "https://example.com/valid.png",
                # supported base64 URI
                "base64://iVBORw0KGgoAAAANSUhEUgAAAAUA",
                # supported file:// paths with different extensions
                "file:///tmp/photo.heic",
                "file:///tmp/vector.svg",
                # unsupported refs that should be filtered out
                "file:///tmp/not-image.txt",
                "mailto:user@example.com",
                "random-string-without-scheme-or-extension",
            ),
            {
                "https://example.com/valid.png",
                "base64://iVBORw0KGgoAAAANSUhEUgAAAAUA",
                "file:///tmp/photo.heic",
                "file:///tmp/vector.svg",
            },
            id="mixed_supported_and_unsupported_schemes_and_extensions",
        ),
    ],
)
async def test_collect_handoff_image_urls_filters_supported_schemes_and_extensions(
    monkeypatch: pytest.MonkeyPatch,
    image_refs: tuple[str, ...],
    expected_supported_refs: set[str],
):
    # 确保事件图片处理是确定性的,并且不会影响过滤行为
    async def _fake_convert_to_file_path(self):
        return "/tmp/event_image.png"

    monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path)

    # 对于该测试,不传入事件图片:我们只关注函数如何过滤引用
    run_context = _build_run_context([])

    # 调用与 handoff 图像收集测试中相同的 helper,传入混合了受支持和不受支持的图像引用。
    result = await _collect_handoff_image_urls(run_context, *image_refs)

    # 函数应只返回受支持的引用(HTTP(S)、base64://、具有受支持图像扩展名的 file://),并过滤掉其余的。
    assert expected_supported_refs.issubset(set(result))

这里有一些假设:

  1. _collect_handoff_image_urls 是该测试文件中其他地方使用的 helper,并接受 (run_context, *image_refs) 作为参数。
    如果它的签名不同(例如期望 image_urls_input=image_refs 或单个可迭代对象),请相应调整调用方式:

    • 示例:result = await _collect_handoff_image_urls(run_context, image_urls_input=image_refs)
    • 示例:result = await _collect_handoff_image_urls(run_context, image_refs)
  2. 该函数可能会追加其他 URL(例如事件图片),因此断言使用 issubset 而非严格相等。
    如果在空 run_context 的情况下可以保证函数只返回过滤后的 image_refs 而不会追加额外 URL,则可以将断言收紧为:

    • assert set(result) == expected_supported_refs
  3. 如果你的测试套件更倾向于直接使用 _collect_image_urls_from_args,可以将对 _collect_handoff_image_urls 的调用替换为 _collect_image_urls_from_args,同时保持相同的参数和断言。

Original comment in English

suggestion (testing): Add tests for base64:// and file:/// image refs plus other supported extensions to match _is_supported_image_ref behavior.

Current tests cover HTTP URLs, one valid image extension, a non-image extension, and extensionless event files. Since _is_supported_image_ref also supports base64:// URIs, file:///... paths, and other extensions like .heic and .svg, please add a parameterized test that passes a mix of these refs into _collect_handoff_image_urls (or _collect_image_urls_from_args) and asserts that only supported refs are returned. This will better align the tests with _is_supported_image_ref and guard against regressions in handling different image schemes and extensions.

Suggested implementation:

def _build_run_context(message_components: list[object] | None = None):
    event = _DummyEvent(message_components=message_components)
    ctx = SimpleNamespace(event=event, context=SimpleNamespace())
    return ContextWrapper(context=ctx)


@pytest.mark.asyncio
async def test_collect_handoff_image_urls_normalizes_filters_and_appends_event_image(
    monkeypatch: pytest.MonkeyPatch,
):
    async def _fake_convert_to_file_path(self):
        return "/tmp/event_image.png"

    monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path)

    run_context = _build_run_context([Image(file="file:///tmp/original.png")])
    image_urls_input = (
        " https://example.com/a.png ",
        "/tmp/not_an_image.txt",
        "/tmp/local.webp",
    )


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "image_refs, expected_supported_refs",
    [
        pytest.param(
            (
                # supported HTTP(S) URL
                "https://example.com/valid.png",
                # supported base64 URI
                "base64://iVBORw0KGgoAAAANSUhEUgAAAAUA",
                # supported file:// paths with different extensions
                "file:///tmp/photo.heic",
                "file:///tmp/vector.svg",
                # unsupported refs that should be filtered out
                "file:///tmp/not-image.txt",
                "mailto:user@example.com",
                "random-string-without-scheme-or-extension",
            ),
            {
                "https://example.com/valid.png",
                "base64://iVBORw0KGgoAAAANSUhEUgAAAAUA",
                "file:///tmp/photo.heic",
                "file:///tmp/vector.svg",
            },
            id="mixed_supported_and_unsupported_schemes_and_extensions",
        ),
    ],
)
async def test_collect_handoff_image_urls_filters_supported_schemes_and_extensions(
    monkeypatch: pytest.MonkeyPatch,
    image_refs: tuple[str, ...],
    expected_supported_refs: set[str],
):
    # Ensure event image handling is deterministic and does not affect filtering behavior
    async def _fake_convert_to_file_path(self):
        return "/tmp/event_image.png"

    monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path)

    # No event images for this test: we only care about how the function filters the refs
    run_context = _build_run_context([])

    # Call the same helper used by handoff image collection tests,
    # passing in a mix of supported and unsupported image references.
    result = await _collect_handoff_image_urls(run_context, *image_refs)

    # The function should only return the supported refs (HTTP(S), base64://, file://
    # with supported image extensions) and filter out the rest.
    assert expected_supported_refs.issubset(set(result))

This change assumes:

  1. _collect_handoff_image_urls is the helper used elsewhere in this test file and accepts (run_context, *image_refs) as arguments.
    If its signature differs (e.g., it expects image_urls_input=image_refs or a single iterable), adjust the call accordingly:

    • Example: result = await _collect_handoff_image_urls(run_context, image_urls_input=image_refs) or
    • Example: result = await _collect_handoff_image_urls(run_context, image_refs).
  2. The function may append additional URLs (such as event images), so the assertion uses issubset instead of strict equality.
    If the function is guaranteed to only return filtered image_refs without extra URLs for an empty run_context, you can tighten the assertion to:

    • assert set(result) == expected_supported_refs.
  3. If your test suite prefers using _collect_image_urls_from_args directly, replace the call to _collect_handoff_image_urls with _collect_image_urls_from_args while keeping the same arguments and assertions.

monkeypatch: pytest.MonkeyPatch,
):
async def _fake_convert_to_file_path(self):
return "/tmp/event_image.png"

monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): 建议增加一个当 convert_to_file_path 抛出异常时的错误处理测试用例。

目前的测试只覆盖了转换成功的路径。请新增一个测试,通过 monkeypatch 让 convert_to_file_path 抛出异常(例如 RuntimeError),并验证 _prepare_handoff_image_urls:

  • 本身不会抛出异常;并且
  • 会跳过将失败的图片添加到 image_urls 中。

这样可以锁定在图片转换场景下预期的容错行为。

Original comment in English

suggestion (testing): Consider a test case for error handling when convert_to_file_path raises.

Right now the tests only cover the successful conversion path. Please add a test that monkeypatches convert_to_file_path to raise (e.g., RuntimeError) and verifies that _prepare_handoff_image_urls:

  • Does not raise, and
  • Skips adding the failed image to image_urls.

This will lock in the intended failure-tolerant behavior around image conversion.


run_context = _build_run_context([Image(file="file:///tmp/original.png")])
image_urls_input = (
" https://example.com/a.png ",
"/tmp/not_an_image.txt",
"/tmp/local.webp",
123,
)

image_urls = await FunctionToolExecutor._collect_handoff_image_urls(
Comment on lines +41 to +49
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): 覆盖 image_urls 为 None 且只应收集事件图片的场景。

当前测试总是传入非 None 的 image_urls(字符串/元组/列表)。请添加一个测试,将 image_urls_raw 设为 None,并在消息中包含一个或多个 Image 组件,以验证事件派生的图像路径仍然会被收集,并确保这种默认行为不会出现回归。

Original comment in English

suggestion (testing): Cover the case where image_urls is None and only event images should be collected.

Current tests always pass a non-None image_urls value (string/tuple/list). Please add a test where image_urls_raw is None and the message includes one or more Image components, to verify that event-derived image paths are still collected and this default behavior doesn’t regress.

run_context,
image_urls_input,
)

assert image_urls == [
"https://example.com/a.png",
"/tmp/local.webp",
"/tmp/event_image.png",
]


@pytest.mark.asyncio
async def test_collect_handoff_image_urls_skips_failed_event_image_conversion(
    monkeypatch: pytest.MonkeyPatch,
):
    """A failing event-image conversion neither raises nor adds an entry."""

    async def _raise_on_convert(self):
        raise RuntimeError("boom")

    monkeypatch.setattr(Image, "convert_to_file_path", _raise_on_convert)

    event_components = [Image(file="file:///tmp/original.png")]
    explicit_urls = ["https://example.com/a.png"]
    collected = await FunctionToolExecutor._collect_handoff_image_urls(
        _build_run_context(event_components),
        explicit_urls,
    )

    # Only the explicitly passed URL survives; the broken event image is skipped.
    assert collected == ["https://example.com/a.png"]


@pytest.mark.asyncio
async def test_do_handoff_background_reports_prepared_image_urls(
    monkeypatch: pytest.MonkeyPatch,
):
    """The wake-up call receives the ``image_urls`` list set during handoff.

    ``_execute_handoff`` is stubbed to write a list into ``tool_args``; the
    test passes a bare string in and checks that the stubbed wake-up hook
    sees the list in its ``tool_args`` keyword.
    """
    wake_kwargs: dict = {}

    async def _stub_execute_handoff(cls, tool, run_context, tool_args):
        tool_args["image_urls"] = ["https://example.com/raw.png"]
        yield mcp.types.CallToolResult(
            content=[mcp.types.TextContent(type="text", text="ok")]
        )

    async def _stub_wake(cls, run_context, **kwargs):
        wake_kwargs.update(kwargs)

    replacements = {
        "_execute_handoff": _stub_execute_handoff,
        "_wake_main_agent_for_background_result": _stub_wake,
    }
    for attr_name, stub in replacements.items():
        monkeypatch.setattr(FunctionToolExecutor, attr_name, classmethod(stub))

    await FunctionToolExecutor._do_handoff_background(
        tool=_DummyTool(),
        run_context=_build_run_context(),
        task_id="task-id",
        input="hello",
        image_urls="https://example.com/raw.png",
    )

    assert wake_kwargs["tool_args"]["image_urls"] == ["https://example.com/raw.png"]


@pytest.mark.asyncio
async def test_collect_handoff_image_urls_keeps_extensionless_existing_event_file(
monkeypatch: pytest.MonkeyPatch,
):
async def _fake_convert_to_file_path(self):
return "/tmp/astrbot-handoff-image"

monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path)
monkeypatch.setattr(
"astrbot.core.astr_agent_tool_exec.os.path.exists", lambda _: True
)
Comment on lines +228 to +240
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): 为不存在的无扩展名路径补充一个测试,断言它们会被过滤掉。

为了覆盖 _is_supported_image_ref 的反向分支,可以将 convert_to_file_path 打补丁为返回一个无扩展名路径,并让 os.path.exists 返回 False,然后断言 image_urls 为空(或者至少不包含该路径)。这样可以锁定现有与不存在的无扩展名文件之间在行为上的预期差异。

Original comment in English

suggestion (testing): Add a complementary test for extensionless paths that do not exist to assert they are filtered out.

To exercise the opposite branch of _is_supported_image_ref, patch convert_to_file_path to return an extensionless path and os.path.exists to return False, then assert that image_urls is empty (or at least does not include that path). This will lock in the intended behavior difference between existing and non-existing extensionless files.


run_context = _build_run_context([Image(file="file:///tmp/original.png")])
image_urls = await FunctionToolExecutor._collect_handoff_image_urls(
run_context,
[],
)

assert image_urls == ["/tmp/astrbot-handoff-image"]