diff --git a/packages/lmi/src/lmi/llms.py b/packages/lmi/src/lmi/llms.py index 5c43d48d..65e534ae 100644 --- a/packages/lmi/src/lmi/llms.py +++ b/packages/lmi/src/lmi/llms.py @@ -1040,6 +1040,7 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult cost=cost, system_fingerprint=completions.system_fingerprint, reasoning_content=reasoning_content, + finish_reason=choice.finish_reason, ) ) return results @@ -1079,6 +1080,7 @@ async def acompletion_iter( role = None reasoning_content = [] used_model = None + choice = None async for completion in stream_completions: if not used_model: used_model = completion.model or self.name @@ -1106,6 +1108,11 @@ async def acompletion_iter( except Exception as e: logger.warning(f"Failed to calculate cost for {used_model}: {e}") + # Extract finish_reason from the last completion chunk + finish_reason = ( + getattr(choice, "finish_reason", None) if choice else None + ) + result = LLMResult( model=used_model, text=text, @@ -1119,6 +1126,7 @@ async def acompletion_iter( cache_read_tokens=cache_read, cache_creation_tokens=cache_creation, cost=cost, + finish_reason=finish_reason, ) if text: diff --git a/packages/lmi/src/lmi/types.py b/packages/lmi/src/lmi/types.py index d01ba526..c3ce7097 100644 --- a/packages/lmi/src/lmi/types.py +++ b/packages/lmi/src/lmi/types.py @@ -123,6 +123,10 @@ class LLMResult(BaseModel): reasoning_content: str | None = Field( default=None, description="Reasoning content from LLMs such as DeepSeek-R1." ) + finish_reason: str | None = Field( + default=None, + description="The reason the model stopped generating tokens (e.g., 'stop', 'length', 'tool_calls', 'refusal').", + ) def __str__(self) -> str: return self.text or "" diff --git a/packages/lmi/tests/test_llms.py b/packages/lmi/tests/test_llms.py index cd2bd9fd..4cd08b10 100644 --- a/packages/lmi/tests/test_llms.py +++ b/packages/lmi/tests/test_llms.py @@ -1201,10 +1201,137 @@ def mock_router_method(_self, _override_config=None): assert results.text == "I'm sorry, but I can't assist with that request." assert results.model == CommonLLMNames.GPT_41.value + assert results.finish_reason == "stop" assert "the llm request was refused" in caplog.text.lower() assert "attempting to fallback" in caplog.text.lower() +@pytest.mark.asyncio +async def test_finish_reason_stored_in_result() -> None: + """Test that finish_reason is properly stored in LLMResult for different scenarios.""" + llm = LiteLLMModel(name=CommonLLMNames.GPT_4O.value) + + messages = [Message(content="Say 'hello'")] + + # Mock the router to simulate different finish reasons + mock_router_obj = Mock() + + # Test with "stop" finish reason (normal completion) + mock_completion = Mock() + mock_message = Mock(content="hello", reasoning_content="") + mock_message.model_dump.return_value = { + "role": "assistant", + "content": "hello", + } + mock_completion.choices = [ + Mock( + finish_reason="stop", + message=mock_message, + ) + ] + mock_completion.usage = Mock(prompt_tokens=5, completion_tokens=1) + mock_completion.model = CommonLLMNames.GPT_4O.value + mock_completion.system_fingerprint = None + + mock_router_obj.acompletion = AsyncMock(return_value=mock_completion) + + def mock_router_method(_self, _override_config=None): + return mock_router_obj + + with patch.object(LiteLLMModel, "get_router", new=mock_router_method): + results = await llm.call_single(messages) + + assert results.finish_reason == "stop" + assert results.text == "hello" + + # Test with "length" finish reason (hit token limit) + mock_message_length = Mock(content="truncated text", reasoning_content="") + mock_message_length.model_dump.return_value = { + "role": "assistant", + "content": "truncated text", + } + mock_completion_length = Mock() + mock_completion_length.choices = [ + Mock( + finish_reason="length", + message=mock_message_length, + ) + ] + mock_completion_length.usage = Mock(prompt_tokens=5, completion_tokens=100) + mock_completion_length.model = CommonLLMNames.GPT_4O.value + mock_completion_length.system_fingerprint = None + + mock_router_obj.acompletion = AsyncMock(return_value=mock_completion_length) + + with patch.object(LiteLLMModel, "get_router", new=mock_router_method): + results_length = await llm.call_single(messages) + + assert results_length.finish_reason == "length" + assert results_length.text == "truncated text" + + +@pytest.mark.asyncio +async def test_finish_reason_in_streaming() -> None: + """Test that finish_reason is properly captured in streaming completions.""" + model = LiteLLMModel(name=CommonLLMNames.OPENAI_TEST.value) + messages = [Message(content="Say hello")] + + def _build_mock_completion( + delta_content: str = "", + delta_role: str = "assistant", + finish_reason: str | None = None, + usage: Any = None, + ) -> Mock: + # Create delta with spec to prevent auto-creation of attributes + mock_delta = Mock(spec=['content', 'role']) + mock_delta.content = delta_content + mock_delta.role = delta_role + + mock_choice = Mock() + mock_choice.finish_reason = finish_reason + mock_choice.logprobs = None + mock_choice.delta = mock_delta + + mock_completion = Mock() + mock_completion.model = "test-model" + mock_completion.choices = [mock_choice] + mock_completion.usage = usage + + return mock_completion + + # Mock the router to simulate streaming with finish_reason + with patch.object(model, "_router") as mock_router: + # Create mock completions - finish_reason typically only in last chunk + mock_chunk1 = _build_mock_completion(delta_content="Hello") + mock_chunk2 = _build_mock_completion(delta_content=" world") + mock_chunk_final = _build_mock_completion( + delta_content="!", + finish_reason="stop", + usage=Mock(prompt_tokens=5, completion_tokens=3), + ) + + # Create async generator + async def mock_stream(): # noqa: RUF029 + async def mock_stream_iter(): # noqa: RUF029 + yield mock_chunk1 + yield mock_chunk2 + yield mock_chunk_final + + return mock_stream_iter() + + mock_router.acompletion.return_value = mock_stream() + + # Test streaming + async_iterable = await model.acompletion_iter(messages) + results = [result async for result in async_iterable] + + # Verify finish_reason is captured + assert len(results) == 1 + result = results[0] + assert result.finish_reason == "stop" + assert result.text == "Hello world!" + + @pytest.mark.asyncio @pytest.mark.parametrize( ("model_name", "expected_tool_role_count"),