diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index a80e912534ab..802fc7dcc5de 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -1029,6 +1029,41 @@ async def create_stream(
         thought: str | None = None
         # Determine the content and thought based on what was collected
         if full_tool_calls:
+            # Some providers (e.g. Gemini) send a complete JSON object per streaming
+            # chunk rather than incremental argument deltas. Simple concatenation then
+            # produces invalid JSON such as ``{}{"stock":"MSFT"}``. Detect and fix by
+            # keeping only the last complete JSON object in the arguments string.
+            #
+            # The string is decoded value by value from the start with
+            # ``JSONDecoder.raw_decode`` instead of counting braces, so braces inside
+            # string values cannot confuse the scan, and a genuinely truncated
+            # payload (whose first value never completes) is left untouched rather
+            # than being replaced by one of its nested objects.
+            decoder = json.JSONDecoder()
+            for fc in full_tool_calls.values():
+                if not fc.arguments:
+                    continue
+                try:
+                    json.loads(fc.arguments)
+                except json.JSONDecodeError:
+                    pos = 0
+                    last_object: str | None = None
+                    while pos < len(fc.arguments):
+                        if fc.arguments[pos] in " \t\n\r":
+                            # raw_decode does not skip leading whitespace itself.
+                            pos += 1
+                            continue
+                        try:
+                            value, end = decoder.raw_decode(fc.arguments, pos)
+                        except json.JSONDecodeError:
+                            # Trailing partial or malformed data: stop scanning.
+                            break
+                        if isinstance(value, dict):
+                            last_object = fc.arguments[pos:end]
+                        pos = end
+                    if last_object is not None:
+                        fc.arguments = last_object
+
             # This is a tool call response
             content = list(full_tool_calls.values())
             if content_deltas:
diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
index ba79795d1ed7..d0c9a7b08640 100644
--- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py
+++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
@@ -1670,6 +1670,119 @@ async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGener
     assert chunks[-1].thought == "Hello Another Hello Yet Another Hello"
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "final_arguments",
+    [
+        {"input": "task"},
+        # Braces inside a string value must not confuse the recovery logic.
+        {"input": "a } b { c"},
+    ],
+)
+async def test_tool_calling_with_stream_gemini_style_arguments(
+    monkeypatch: pytest.MonkeyPatch, final_arguments: dict[str, str]
+) -> None:
+    """Test recovery from Gemini-style streaming where each chunk carries a complete JSON object.
+
+    Some providers (e.g. Gemini) emit a full JSON object per streaming chunk rather than
+    incremental argument deltas. Simple string concatenation produces invalid JSON such as
+    ``{}{"input": "task"}`` which later fails to parse. The client must detect and recover
+    from this by keeping only the last valid complete JSON object.
+
+    The parametrized case with braces inside a string value guards against recovery logic
+    that counts braces without understanding JSON strings.
+
+    Regression test for: https://github.com/microsoft/autogen/issues/6843
+    """
+    expected_arguments = json.dumps(final_arguments)
+
+    async def _mock_create_stream_gemini(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
+        model = resolve_model(kwargs.get("model", "gpt-4o"))
+        # Simulate Gemini behavior: first chunk has empty-object arguments,
+        # subsequent chunk has the actual arguments — both are valid standalone JSON.
+        chunks = [
+            # First tool-call chunk: empty arguments placeholder
+            MockChunkDefinition(
+                chunk_choice=ChunkChoice(
+                    finish_reason=None,
+                    index=0,
+                    delta=ChoiceDelta(
+                        content=None,
+                        role="assistant",
+                        tool_calls=[
+                            ChoiceDeltaToolCall(
+                                index=0,
+                                id="1",
+                                type="function",
+                                function=ChoiceDeltaToolCallFunction(
+                                    name="_pass_function",
+                                    arguments="{}",
+                                ),
+                            )
+                        ],
+                    ),
+                ),
+                usage=None,
+            ),
+            # Second tool-call chunk: actual arguments (Gemini sends full JSON again)
+            MockChunkDefinition(
+                chunk_choice=ChunkChoice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    delta=ChoiceDelta(
+                        content=None,
+                        role="assistant",
+                        tool_calls=[
+                            ChoiceDeltaToolCall(
+                                index=0,
+                                id="",
+                                type="function",
+                                function=ChoiceDeltaToolCallFunction(
+                                    name="",
+                                    arguments=expected_arguments,
+                                ),
+                            )
+                        ],
+                    ),
+                ),
+                usage=None,
+            ),
+        ]
+        for chunk in chunks:
+            # Yield control to the event loop without slowing the suite down.
+            await asyncio.sleep(0)
+            yield ChatCompletionChunk(
+                id="id",
+                choices=[chunk.chunk_choice],
+                created=0,
+                model=model,
+                object="chat.completion.chunk",
+                usage=chunk.usage,
+            )
+
+    async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
+        stream = kwargs.get("stream", False)
+        if not stream:
+            raise ValueError("Stream is not True")
+        return _mock_create_stream_gemini(*args, **kwargs)
+
+    monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
+
+    model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="")
+    pass_tool = FunctionTool(_pass_function, description="pass tool.")
+    stream = model_client.create_stream(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
+    result_chunks: List[str | CreateResult] = []
+    async for chunk in stream:
+        result_chunks.append(chunk)
+
+    final = result_chunks[-1]
+    assert isinstance(final, CreateResult)
+    # The concatenated raw string is `{}` followed by the real arguments (invalid JSON).
+    # After recovery the arguments must be the last valid complete JSON object.
+    assert final.content == [FunctionCall(id="1", arguments=expected_arguments, name="_pass_function")]
+    assert final.finish_reason == "function_calls"
+
+
 @pytest.mark.asyncio
 async def test_tool_calls_assistant_message_content_field(monkeypatch: pytest.MonkeyPatch) -> None:
     """Test that AssistantMessage with tool calls includes required content field.