Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,51 @@
logger = logging.getLogger("kagent_adk." + __name__)

_USER_ID_CONTEXT_KEY = "x-user-id"
_HEADERS_CONTEXT_KEY = "headers"
_SOURCE_HEADER = "x-kagent-source"
_SOURCE_SUBAGENT = "agent"


class _SubagentInterceptor(ClientCallInterceptor):
"""
Injects the authenticated user's ID as an ``x-user-id`` HTTP header and
marks the request as originating from an agent call via
``x-kagent-source: agent`` on every outgoing A2A request.
Injects the authenticated user's ID as an ``x-user-id`` HTTP header,
forwards the parent ``Authorization`` header when available, and marks
the request as originating from an agent call via ``x-kagent-source:
agent`` on every outgoing A2A request.

Only ``Authorization`` is promoted from parent session headers; all
other session headers remain context-only.
"""

async def intercept(self, method_name, request_payload, http_kwargs, agent_card, context):
headers = dict(http_kwargs.get("headers", {}))
# Always mark requests from a parent agent tool as subagent-originated
headers[_SOURCE_HEADER] = _SOURCE_SUBAGENT
if context and _USER_ID_CONTEXT_KEY in context.state:
headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY]
if context:
if _USER_ID_CONTEXT_KEY in context.state:
headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY]

request_headers = context.state.get(_HEADERS_CONTEXT_KEY, {})
if isinstance(request_headers, dict):
for key, value in request_headers.items():
if key.lower() == "authorization":
headers = {k: v for k, v in headers.items() if k.lower() != "authorization"}
headers[key] = value
break
Comment thread
towsif-rahman marked this conversation as resolved.
http_kwargs["headers"] = headers
return request_payload, http_kwargs


def _build_subagent_call_context(tool_context: ToolContext) -> ClientCallContext:
"""Build A2A call context for requests delegated to sub-agents."""
ctx_state = {_USER_ID_CONTEXT_KEY: tool_context.session.user_id}
session_state = getattr(tool_context.session, "state", {}) or {}
session_headers = session_state.get(_HEADERS_CONTEXT_KEY, {}) if isinstance(session_state, dict) else {}
if session_headers:
ctx_state[_HEADERS_CONTEXT_KEY] = session_headers
return ClientCallContext(state=ctx_state)


def _extract_text_from_task(task: Task) -> str:
"""Extract text content from a completed task's artifacts or status message."""
# Prefer artifacts (the canonical result)
Expand Down Expand Up @@ -239,7 +263,7 @@ async def _handle_first_call(self, args: dict[str, Any], tool_context: ToolConte

# Forward the authenticated user ID so the subagent session is scoped
# to the same user as the parent agent session.
call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id})
call_context = _build_subagent_call_context(tool_context)

task: Optional[Task] = None
try:
Expand Down Expand Up @@ -381,7 +405,7 @@ async def _handle_resume(self, tool_context: ToolContext) -> Any:
)

client = await self._ensure_client()
call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id})
call_context = _build_subagent_call_context(tool_context)
task: Optional[Task] = None
try:
async for response in client.send_message(request=decision_message, context=call_context):
Expand Down
128 changes: 126 additions & 2 deletions python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
from a2a.client.middleware import ClientCallContext
from a2a.types import (
DataPart,
Role,
Expand All @@ -26,6 +27,7 @@
KAgentRemoteA2ATool,
KAgentRemoteA2AToolset,
SubagentSessionProvider,
_SubagentInterceptor,
)

# ---------------------------------------------------------------------------
Expand All @@ -38,8 +40,9 @@
class _MockSession:
"""Minimal session mock providing user_id."""

def __init__(self, user_id: str = _DEFAULT_USER_ID):
def __init__(self, user_id: str = _DEFAULT_USER_ID, state: dict[str, Any] | None = None):
self.user_id = user_id
self.state = state or {}


class MockToolContext:
Expand All @@ -49,11 +52,12 @@ def __init__(
self,
tool_confirmation: ToolConfirmation | None = None,
user_id: str = _DEFAULT_USER_ID,
session_state: dict[str, Any] | None = None,
):
self.state: dict[str, Any] = {}
self.function_call_id = "outer_fc_1"
self.tool_confirmation = tool_confirmation
self.session = _MockSession(user_id)
self.session = _MockSession(user_id, session_state)
self._confirmations: dict[str, ToolConfirmation] = {}

def request_confirmation(self, *, hint: str = "", payload: dict | None = None) -> None:
Expand Down Expand Up @@ -237,6 +241,48 @@ async def capture(*, request, context=None, **kw):

assert captured_contexts[0].state["x-user-id"] == "alice@example.com"

async def test_session_headers_forwarded_in_call_context(self):
"""The parent session's request headers are forwarded via ClientCallContext."""
tool = _make_tool()
task = _make_task(TaskState.completed, text="ok")
captured_contexts: list = []
session_headers = {
"Authorization": "Bearer user-token",
"x-not-forwarded-by-interceptor": "kept-in-context-only",
}

async def capture(*, request, context=None, **kw):
captured_contexts.append(context)
yield (task, None)

p, _ = _patch_client(tool, capture)
try:
ctx = MockToolContext(session_state={"headers": session_headers})
await tool.run_async(args={"request": "go"}, tool_context=ctx)
finally:
p.stop()

assert captured_contexts[0].state["headers"] == session_headers

async def test_call_context_omits_headers_when_session_has_none(self):
"""Missing session headers do not prevent building a valid ClientCallContext."""
tool = _make_tool()
task = _make_task(TaskState.completed, text="ok")
captured_contexts: list = []

async def capture(*, request, context=None, **kw):
captured_contexts.append(context)
yield (task, None)

p, _ = _patch_client(tool, capture)
try:
ctx = MockToolContext(user_id="alice@example.com")
await tool.run_async(args={"request": "go"}, tool_context=ctx)
finally:
p.stop()

assert captured_contexts[0].state == {"x-user-id": "alice@example.com"}


# ---------------------------------------------------------------------------
# HITL input_required tests
Expand Down Expand Up @@ -404,6 +450,84 @@ async def test_resume_input_required_chains(self):
assert ctx.function_call_id in ctx._confirmations
assert "restart_pod" in ctx._confirmations[ctx.function_call_id].hint

async def test_session_headers_forwarded_in_resume_call_context(self):
"""Resume calls also forward parent session request headers via ClientCallContext."""
tool = _make_tool()
task = _make_task(TaskState.completed, text="ok")
captured_contexts: list = []
session_headers = {"Authorization": "Bearer resumed-user-token"}

async def capture(*, request, context=None, **kw):
captured_contexts.append(context)
yield (task, None)

p, _ = _patch_client(tool, capture)
try:
ctx = _approval_ctx(
confirmed=True,
payload=_RESUME_PAYLOAD,
session_state={"headers": session_headers},
)
await tool.run_async(args={}, tool_context=ctx)
finally:
p.stop()

assert captured_contexts[0].state["headers"] == session_headers


# ---------------------------------------------------------------------------
# Subagent interceptor tests
# ---------------------------------------------------------------------------


class TestSubagentInterceptor:
async def test_forwards_authorization_from_context_headers(self):
interceptor = _SubagentInterceptor()
context = ClientCallContext(
state={
"x-user-id": "alice@example.com",
"headers": {
"Authorization": "Bearer user-token",
"x-secret-header": "should-not-forward",
},
}
)

_, http_kwargs = await interceptor.intercept(
"send_message",
{},
{"headers": {"authorization": "Bearer stale-token", "accept": "application/json"}},
None,
context,
)

assert http_kwargs["headers"]["Authorization"] == "Bearer user-token"
assert http_kwargs["headers"]["x-user-id"] == "alice@example.com"
assert http_kwargs["headers"]["x-kagent-source"] == "agent"
assert http_kwargs["headers"]["accept"] == "application/json"
assert "authorization" not in http_kwargs["headers"]
assert "x-secret-header" not in http_kwargs["headers"]

async def test_ignores_context_headers_without_authorization(self):
interceptor = _SubagentInterceptor()
context = ClientCallContext(
state={
"headers": {
"x-secret-header": "should-not-forward",
},
}
)

_, http_kwargs = await interceptor.intercept(
"send_message",
{},
{"headers": {}},
None,
context,
)

assert http_kwargs["headers"] == {"x-kagent-source": "agent"}


# ---------------------------------------------------------------------------
# Toolset lifecycle tests
Expand Down