-
Notifications
You must be signed in to change notification settings - Fork 35
feat: add per-model Vertex AI region support for judge panel #240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
nvola
wants to merge
7
commits into
lightspeed-core:main
Choose a base branch
from
nvola:vertex-per-model-region
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 1 commit
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
e46d5e1
feat: add per-model Vertex AI region support for judge panel
nvola 2369a0b
fix: close race condition and event loop blocking in _vertex_override
nvola 0a37727
Consolidate vertex override locks into single threading.Lock
nvola b59d0a3
fix: hold lock across yield in _vertex_override_async to close race w…
nvola d9addea
test: add timeout to asyncio.gather in deadlock detection test
nvola 8d456a4
test: assert lock release in sync vertex override tests
nvola 5bb0acd
fix: acquire lock before popping kwargs in _vertex_override_async to …
nvola File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,199 @@ | ||
| """Unit tests for litellm_patch vertex override support.""" | ||
|
|
||
| from typing import Any, Callable | ||
|
|
||
| import pytest | ||
| from pytest_mock import MockerFixture | ||
| import litellm | ||
|
|
||
| from lightspeed_evaluation.core.llm import litellm_patch | ||
| from lightspeed_evaluation.core.llm.litellm_patch import _vertex_override | ||
|
|
||
|
|
||
class TestVertexOverrideContextManager:
    """Tests for the _vertex_override context manager.

    The tests check three things: vertex-specific keys are removed from the
    kwargs dict, the matching attributes on the ``litellm`` module are set
    inside the ``with`` body and restored afterwards, and the module-level
    ``litellm_state_lock`` is only taken (and released) when needed.
    """

    def test_no_vertex_params_is_noop(self) -> None:
        """Test that _vertex_override is a no-op when no vertex params present."""
        kwargs: dict[str, Any] = {"model": "gpt-4", "temperature": 0.5}
        original_kwargs = dict(kwargs)
        old_location = getattr(litellm, "vertex_location", None)
        old_project = getattr(litellm, "vertex_project", None)

        with _vertex_override(kwargs):
            # A no-op must leave the module-level settings alone as well.
            assert getattr(litellm, "vertex_location", None) == old_location
            assert getattr(litellm, "vertex_project", None) == old_project

        assert kwargs == original_kwargs
        assert getattr(litellm, "vertex_location", None) == old_location
        assert getattr(litellm, "vertex_project", None) == old_project

    def test_vertex_location_set_and_restored(self) -> None:
        """Test that vertex_location is set on litellm module and restored after."""
        old_value = getattr(litellm, "vertex_location", None)
        kwargs: dict[str, Any] = {"vertex_location": "us-central1"}

        with _vertex_override(kwargs):
            assert litellm.vertex_location == "us-central1"
            assert "vertex_location" not in kwargs

        assert getattr(litellm, "vertex_location", None) == old_value

    def test_vertex_project_set_and_restored(self) -> None:
        """Test that vertex_project is set on litellm module and restored after."""
        old_value = getattr(litellm, "vertex_project", None)
        kwargs: dict[str, Any] = {"vertex_project": "my-project"}

        with _vertex_override(kwargs):
            assert litellm.vertex_project == "my-project"
            assert "vertex_project" not in kwargs

        assert getattr(litellm, "vertex_project", None) == old_value

    def test_both_params_set_and_restored(self) -> None:
        """Test that both vertex params are set and restored."""
        old_location = getattr(litellm, "vertex_location", None)
        old_project = getattr(litellm, "vertex_project", None)
        kwargs: dict[str, Any] = {
            "vertex_location": "europe-west1",
            "vertex_project": "my-project",
            "temperature": 0.5,
        }

        with _vertex_override(kwargs):
            assert litellm.vertex_location == "europe-west1"
            assert litellm.vertex_project == "my-project"
            assert "vertex_location" not in kwargs
            assert "vertex_project" not in kwargs
            # Unrelated keys must survive untouched.
            assert kwargs == {"temperature": 0.5}

        assert getattr(litellm, "vertex_location", None) == old_location
        assert getattr(litellm, "vertex_project", None) == old_project

    def test_params_restored_on_exception(self) -> None:
        """Test that vertex params are restored even when an exception occurs."""
        old_location = getattr(litellm, "vertex_location", None)
        kwargs: dict[str, Any] = {"vertex_location": "us-east1"}

        with pytest.raises(ValueError, match="test error"), _vertex_override(kwargs):
            assert litellm.vertex_location == "us-east1"
            raise ValueError("test error")

        assert getattr(litellm, "vertex_location", None) == old_location

    def test_no_lock_acquired_without_vertex_params(
        self, mocker: MockerFixture
    ) -> None:
        """Test that the lock is not acquired when no vertex params are present."""
        mock_lock = mocker.patch.object(litellm_patch, "litellm_state_lock")
        kwargs: dict[str, Any] = {"temperature": 0.5}

        with _vertex_override(kwargs):
            pass

        mock_lock.__enter__.assert_not_called()
        mock_lock.__exit__.assert_not_called()

    def test_lock_acquired_with_vertex_params(self, mocker: MockerFixture) -> None:
        """Test that the lock is acquired and released when vertex params exist."""
        mock_lock = mocker.MagicMock()
        mocker.patch.object(litellm_patch, "litellm_state_lock", mock_lock)
        kwargs: dict[str, Any] = {"vertex_location": "us-central1"}

        with _vertex_override(kwargs):
            pass

        mock_lock.__enter__.assert_called_once()
        # An acquired-but-never-released lock would block every later call,
        # so the release is verified alongside the acquisition.
        mock_lock.__exit__.assert_called_once()
|
|
||
|
|
||
class TestCompletionWithVertexOverride:
    """Test litellm.completion integration with vertex override.

    Each test replaces the module's ``_original_completion`` /
    ``_original_acompletion`` with a mock and calls the public ``litellm``
    entry points. The vertex tests record the ``litellm`` module attributes
    while the mocked call runs, so they check that the override is applied
    for the duration of the call, not only that the kwargs are stripped and
    the globals restored afterwards.
    """

    def test_completion_with_vertex_location(
        self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any]
    ) -> None:
        """Test that vertex_location is handled during completion calls."""
        response = mock_judge_llm_response(
            prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok"
        )
        seen: dict[str, Any] = {}

        def _record_state(*_args: Any, **_kwargs: Any) -> Any:
            # Snapshot the module-level setting while the wrapped call runs.
            seen["vertex_location"] = getattr(litellm, "vertex_location", None)
            return response

        mock_completion = mocker.patch(f"{litellm_patch.__name__}._original_completion")
        mock_completion.side_effect = _record_state

        old_location = getattr(litellm, "vertex_location", None)

        litellm.completion(
            model="vertex_ai/gemini-pro",
            messages=[{"role": "user", "content": "test"}],
            vertex_location="us-central1",
        )

        mock_completion.assert_called_once()
        call_kwargs = mock_completion.call_args.kwargs
        assert "vertex_location" not in call_kwargs
        assert seen["vertex_location"] == "us-central1"
        assert getattr(litellm, "vertex_location", None) == old_location

    def test_completion_without_vertex_params_unchanged(
        self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any]
    ) -> None:
        """Test that completion works normally without vertex params."""
        mock_completion = mocker.patch(f"{litellm_patch.__name__}._original_completion")
        mock_completion.return_value = mock_judge_llm_response(
            prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok"
        )

        litellm.completion(
            model="gpt-4",
            messages=[{"role": "user", "content": "test"}],
            temperature=0.5,
        )

        mock_completion.assert_called_once()
        call_kwargs = mock_completion.call_args.kwargs
        assert call_kwargs["temperature"] == 0.5
        assert call_kwargs["model"] == "gpt-4"

    @pytest.mark.asyncio
    async def test_acompletion_with_vertex_location(
        self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any]
    ) -> None:
        """Test that vertex_location is handled during async completion calls."""
        response = mock_judge_llm_response(
            prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok"
        )
        seen: dict[str, Any] = {}

        def _record_state(*_args: Any, **_kwargs: Any) -> Any:
            # Snapshot the module-level setting while the wrapped call runs.
            seen["vertex_location"] = getattr(litellm, "vertex_location", None)
            return response

        mock_acompletion = mocker.patch(
            f"{litellm_patch.__name__}._original_acompletion"
        )
        mock_acompletion.side_effect = _record_state

        old_location = getattr(litellm, "vertex_location", None)

        await litellm.acompletion(
            model="vertex_ai/gemini-pro",
            messages=[{"role": "user", "content": "test"}],
            vertex_location="europe-west1",
        )

        mock_acompletion.assert_called_once()
        call_kwargs = mock_acompletion.call_args.kwargs
        assert "vertex_location" not in call_kwargs
        assert seen["vertex_location"] == "europe-west1"
        assert getattr(litellm, "vertex_location", None) == old_location

    @pytest.mark.asyncio
    async def test_acompletion_with_both_vertex_params(
        self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any]
    ) -> None:
        """Test that both vertex params are handled during async completion."""
        response = mock_judge_llm_response(
            prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok"
        )
        seen: dict[str, Any] = {}

        def _record_state(*_args: Any, **_kwargs: Any) -> Any:
            # Snapshot both module-level settings while the wrapped call runs.
            seen["vertex_location"] = getattr(litellm, "vertex_location", None)
            seen["vertex_project"] = getattr(litellm, "vertex_project", None)
            return response

        mock_acompletion = mocker.patch(
            f"{litellm_patch.__name__}._original_acompletion"
        )
        mock_acompletion.side_effect = _record_state

        old_location = getattr(litellm, "vertex_location", None)
        old_project = getattr(litellm, "vertex_project", None)

        await litellm.acompletion(
            model="vertex_ai/gemini-pro",
            messages=[{"role": "user", "content": "test"}],
            vertex_location="us-central1",
            vertex_project="my-project",
        )

        mock_acompletion.assert_called_once()
        call_kwargs = mock_acompletion.call_args.kwargs
        assert "vertex_location" not in call_kwargs
        assert "vertex_project" not in call_kwargs
        assert seen == {
            "vertex_location": "us-central1",
            "vertex_project": "my-project",
        }
        assert getattr(litellm, "vertex_location", None) == old_location
        assert getattr(litellm, "vertex_project", None) == old_project
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.