Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,33 @@ async def create_stream(
thought: str | None = None
# Determine the content and thought based on what was collected
if full_tool_calls:
# Some providers (e.g. Gemini) send a complete JSON object per streaming
# chunk rather than incremental argument deltas. Simple concatenation then
# produces invalid JSON such as ``{}{"stock":"MSFT"}``. Detect and fix by
# keeping only the last valid complete JSON object in the arguments string.
for fc in full_tool_calls.values():
if fc.arguments:
try:
json.loads(fc.arguments)
except json.JSONDecodeError:
last_brace = fc.arguments.rfind("}")
if last_brace != -1:
depth = 0
for i in range(last_brace, -1, -1):
ch = fc.arguments[i]
if ch == "}":
depth += 1
elif ch == "{":
depth -= 1
if depth == 0:
candidate = fc.arguments[i : last_brace + 1]
try:
json.loads(candidate)
fc.arguments = candidate
except json.JSONDecodeError:
pass
break

# This is a tool call response
content = list(full_tool_calls.values())
if content_deltas:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1670,6 +1670,104 @@ async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGener
assert chunks[-1].thought == "Hello Another Hello Yet Another Hello"


@pytest.mark.asyncio
async def test_tool_calling_with_stream_gemini_style_arguments(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test fix for Gemini-style streaming where each chunk carries a complete JSON object.

Some providers (e.g. Gemini) emit a full JSON object per streaming chunk rather than
incremental argument deltas. Simple string concatenation produces invalid JSON such as
``{}{"input": "task"}`` which later fails to parse. The client must detect and recover
from this by keeping only the last valid complete JSON object.

Regression test for: https://github.com/microsoft/autogen/issues/6843
"""

async def _mock_create_stream_gemini(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
model = resolve_model(kwargs.get("model", "gpt-4o"))
# Simulate Gemini behavior: first chunk has empty-object arguments,
# subsequent chunk has the actual arguments — both are valid standalone JSON.
chunks = [
# First tool-call chunk: empty arguments placeholder
MockChunkDefinition(
chunk_choice=ChunkChoice(
finish_reason=None,
index=0,
delta=ChoiceDelta(
content=None,
role="assistant",
tool_calls=[
ChoiceDeltaToolCall(
index=0,
id="1",
type="function",
function=ChoiceDeltaToolCallFunction(
name="_pass_function",
arguments="{}",
),
)
],
),
),
usage=None,
),
# Second tool-call chunk: actual arguments (Gemini sends full JSON again)
MockChunkDefinition(
chunk_choice=ChunkChoice(
finish_reason="tool_calls",
index=0,
delta=ChoiceDelta(
content=None,
role="assistant",
tool_calls=[
ChoiceDeltaToolCall(
index=0,
id="",
type="function",
function=ChoiceDeltaToolCallFunction(
name="",
arguments=json.dumps({"input": "task"}),
),
)
],
),
),
usage=None,
),
]
for chunk in chunks:
await asyncio.sleep(0.1)
yield ChatCompletionChunk(
id="id",
choices=[chunk.chunk_choice],
created=0,
model=model,
object="chat.completion.chunk",
usage=chunk.usage,
)

async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
stream = kwargs.get("stream", False)
if not stream:
raise ValueError("Stream is not False")
return _mock_create_stream_gemini(*args, **kwargs)

monkeypatch.setattr(AsyncCompletions, "create", _mock_create)

model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="")
pass_tool = FunctionTool(_pass_function, description="pass tool.")
stream = model_client.create_stream(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
result_chunks: List[str | CreateResult] = []
async for chunk in stream:
result_chunks.append(chunk)

final = result_chunks[-1]
assert isinstance(final, CreateResult)
# The concatenated raw string would be `{}{"input": "task"}` (invalid JSON).
# After recovery the arguments must be the last valid complete JSON object.
assert final.content == [FunctionCall(id="1", arguments='{"input": "task"}', name="_pass_function")]
assert final.finish_reason == "function_calls"


@pytest.mark.asyncio
async def test_tool_calls_assistant_message_content_field(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that AssistantMessage with tool calls includes required content field.
Expand Down