-
Notifications
You must be signed in to change notification settings - Fork 35
feat: add per-model Vertex AI region support for judge panel #240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
e46d5e1
2369a0b
0a37727
b59d0a3
d9addea
8d456a4
5bb0acd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| """LiteLLM configuration for token tracking and Ragas 0.4 compatibility. | ||
| """LiteLLM configuration for token tracking, Ragas 0.4 compatibility, and Vertex AI support. | ||
|
|
||
| This module configures litellm for two purposes: | ||
| This module configures litellm for three purposes: | ||
|
|
||
| 1. TOKEN TRACKING: Wraps litellm.completion, litellm.acompletion, litellm.embedding, | ||
| and litellm.aembedding to track token usage for all LLM and embedding calls. | ||
|
|
@@ -14,14 +14,22 @@ | |
|
|
||
| We replace the LoggingWorker with a no-op implementation to avoid this. | ||
| This is safe because we don't use litellm's built-in observability features. | ||
|
|
||
| 3. VERTEX AI PER-MODEL REGION SUPPORT: litellm.drop_params=True (set by | ||
| DeepEval) silently strips vertex_project and vertex_location from | ||
| completion kwargs. The completion wrappers intercept these params and | ||
| temporarily set them as litellm module-level attributes, which litellm | ||
| checks as a fallback in its vertex_ai handler. | ||
| """ | ||
|
|
||
| import asyncio | ||
| import logging | ||
| import os | ||
| import threading | ||
| import warnings | ||
| from contextlib import asynccontextmanager, contextmanager | ||
| from functools import wraps | ||
| from typing import Any | ||
| from typing import Any, AsyncGenerator, Generator | ||
|
|
||
| import litellm | ||
|
|
||
|
|
@@ -89,6 +97,85 @@ def clear_queue(self) -> None: | |
| litellm.suppress_debug_info = True | ||
|
|
||
|
|
||
| # ============================================================================= | ||
| # GLOBAL STATE LOCK | ||
| # ============================================================================= | ||
| # Single lock for ALL litellm global state mutations (cache, ssl_verify, | ||
| # vertex_project, vertex_location). Import this lock in any module that | ||
| # reads/writes litellm global state to prevent race conditions between | ||
| # concurrent pipelines. Both sync and async code paths share this lock; | ||
| # async callers use asyncio.to_thread so the event loop is never blocked. | ||
| litellm_state_lock = threading.Lock() | ||
|
|
||
|
|
||
| # ============================================================================= | ||
| # VERTEX AI PER-MODEL REGION SUPPORT | ||
| # ============================================================================= | ||
| # litellm.drop_params=True (set by DeepEval) silently strips vertex_project | ||
| # and vertex_location from completion kwargs. We intercept these params and | ||
| # temporarily set them as litellm module-level attributes, which litellm | ||
| # checks as a fallback in its vertex_ai handler. | ||
|
|
||
|
|
||
| @contextmanager | ||
| def _vertex_override(kwargs: dict[str, Any]) -> Generator[None, None, None]: | ||
| """Pop vertex_project/vertex_location from kwargs and set as litellm module attrs. | ||
|
|
||
| Always acquires litellm_state_lock to prevent concurrent reads of partially | ||
| updated globals, even when no vertex params are present in kwargs. | ||
| """ | ||
| with litellm_state_lock: | ||
| vp = kwargs.pop("vertex_project", None) | ||
| vl = kwargs.pop("vertex_location", None) | ||
| if vp is None and vl is None: | ||
| yield | ||
| return | ||
| old_vp = getattr(litellm, "vertex_project", None) | ||
| old_vl = getattr(litellm, "vertex_location", None) | ||
| try: | ||
| if vp is not None: | ||
| litellm.vertex_project = vp | ||
| if vl is not None: | ||
| litellm.vertex_location = vl | ||
| yield | ||
| finally: | ||
| litellm.vertex_project = old_vp | ||
| litellm.vertex_location = old_vl | ||
|
|
||
|
|
||
| @asynccontextmanager | ||
| async def _vertex_override_async( | ||
| kwargs: dict[str, Any], | ||
| ) -> AsyncGenerator[None, None]: | ||
| """Async version of _vertex_override using asyncio.to_thread. | ||
|
|
||
| Acquires litellm_state_lock before mutating globals and holds it across the | ||
| yield so no concurrent caller can see partially-updated state. Lock | ||
| acquire/release use asyncio.to_thread to avoid blocking the event loop. | ||
| Uses the same lock as the synchronous path to prevent races between sync | ||
| and async callers. | ||
| """ | ||
| vp = kwargs.pop("vertex_project", None) | ||
| vl = kwargs.pop("vertex_location", None) | ||
| if vp is None and vl is None: | ||
| yield | ||
| return | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Async no-op path skips lock, creating race condition with sync callers. The sync Move the lock acquisition before the None check to match the sync behavior: 🔧 Proposed fix async def _vertex_override_async(
kwargs: dict[str, Any],
) -> AsyncGenerator[None, None]:
...
- vp = kwargs.pop("vertex_project", None)
- vl = kwargs.pop("vertex_location", None)
- if vp is None and vl is None:
- yield
- return
-
await asyncio.to_thread(litellm_state_lock.acquire)
- old_vp = getattr(litellm, "vertex_project", None)
- old_vl = getattr(litellm, "vertex_location", None)
try:
+ vp = kwargs.pop("vertex_project", None)
+ vl = kwargs.pop("vertex_location", None)
+ if vp is None and vl is None:
+ yield
+ return
+ old_vp = getattr(litellm, "vertex_project", None)
+ old_vl = getattr(litellm, "vertex_location", None)
if vp is not None:
litellm.vertex_project = vp
if vl is not None:
litellm.vertex_location = vl
- yield
- finally:
- litellm.vertex_project = old_vp
- litellm.vertex_location = old_vl
+ try:
+ yield
+ finally:
+ litellm.vertex_project = old_vp
+ litellm.vertex_location = old_vl
+ finally:
await asyncio.to_thread(litellm_state_lock.release)🤖 Prompt for AI Agents |
||
|
|
||
| await asyncio.to_thread(litellm_state_lock.acquire) | ||
| old_vp = getattr(litellm, "vertex_project", None) | ||
| old_vl = getattr(litellm, "vertex_location", None) | ||
| try: | ||
| if vp is not None: | ||
| litellm.vertex_project = vp | ||
| if vl is not None: | ||
| litellm.vertex_location = vl | ||
| yield | ||
| finally: | ||
| litellm.vertex_project = old_vp | ||
| litellm.vertex_location = old_vl | ||
| await asyncio.to_thread(litellm_state_lock.release) | ||
|
|
||
|
|
||
| # ============================================================================= | ||
| # TOKEN TRACKING: Wrap completion and embedding functions | ||
| # ============================================================================= | ||
|
|
@@ -101,11 +188,11 @@ def clear_queue(self) -> None: | |
| _original_aembedding = litellm.aembedding | ||
|
|
||
|
|
||
| # Patch litellm's completion functions to include token tracking | ||
| @wraps(_original_completion) | ||
| def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: | ||
| """Wrapper around litellm.completion that tracks tokens.""" | ||
| response = _original_completion(*args, **kwargs) | ||
| """Wrapper around litellm.completion that tracks tokens and handles Vertex params.""" | ||
| with _vertex_override(kwargs): | ||
| response = _original_completion(*args, **kwargs) | ||
| try: | ||
| track_judge_tokens(response) | ||
| except Exception as e: # pylint: disable=broad-exception-caught | ||
|
|
@@ -115,16 +202,16 @@ def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: | |
|
|
||
| @wraps(_original_acompletion) | ||
| async def _acompletion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: | ||
| """Wrapper around litellm.acompletion that tracks tokens.""" | ||
| response = await _original_acompletion(*args, **kwargs) | ||
| """Wrapper around litellm.acompletion that tracks tokens and handles Vertex params.""" | ||
| async with _vertex_override_async(kwargs): | ||
| response = await _original_acompletion(*args, **kwargs) | ||
| try: | ||
| track_judge_tokens(response) | ||
| except Exception as e: # pylint: disable=broad-exception-caught | ||
| logger.exception("Failed to track tokens for acompletion: %s", e) | ||
| return response | ||
|
|
||
|
|
||
| # Patch litellm's embedding functions to include token tracking | ||
| @wraps(_original_embedding) | ||
| def _embedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any: | ||
| """Wrapper around litellm.embedding that tracks tokens.""" | ||
|
|
@@ -147,22 +234,12 @@ async def _aembedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any: | |
| return response | ||
|
|
||
|
|
||
| # Patch litellm's completion and embedding functions to include token tracking | ||
| litellm.completion = _completion_with_token_tracking | ||
| litellm.acompletion = _acompletion_with_token_tracking | ||
| litellm.embedding = _embedding_with_token_tracking | ||
| litellm.aembedding = _aembedding_with_token_tracking | ||
|
|
||
|
|
||
| # ============================================================================= | ||
| # GLOBAL STATE LOCK | ||
| # ============================================================================= | ||
| # Single lock for ALL litellm global state mutations (cache, ssl_verify). | ||
| # Import this lock in any module that reads/writes litellm.cache or | ||
| # litellm.ssl_verify to prevent race conditions between concurrent pipelines. | ||
| litellm_state_lock = threading.Lock() | ||
|
|
||
|
|
||
| # ============================================================================= | ||
| # SSL CONFIGURATION UTILITY | ||
| # ============================================================================= | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.