Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions config/system.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ llm_pool:
# Add, remove or override model specific parameters
temperature: null # Removes temperature from default
max_completion_tokens: 2048 # Overrides default
# Vertex AI example with per-model region:
# judge_vertex_gemini:
# provider: vertex_ai
# model: gemini-2.0-flash
# parameters:
# vertex_location: us-central1 # Region for this model
# # vertex_project: my-gcp-project # Optional: override GCP project
# judge_vertex_llama:
# provider: vertex_ai
# model: meta/llama-3.3-70b-instruct-maas
# parameters:
# vertex_location: europe-west1 # Different region for this model

# Judge Panel: multiple judges from the pool
# Combine their scores. First judge in judges is the fallback when the full panel is not used for a metric.
Expand Down
115 changes: 96 additions & 19 deletions src/lightspeed_evaluation/core/llm/litellm_patch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""LiteLLM configuration for token tracking and Ragas 0.4 compatibility.
"""LiteLLM configuration for token tracking, Ragas 0.4 compatibility, and Vertex AI support.

This module configures litellm for two purposes:
This module configures litellm for three purposes:

1. TOKEN TRACKING: Wraps litellm.completion, litellm.acompletion, litellm.embedding,
and litellm.aembedding to track token usage for all LLM and embedding calls.
Expand All @@ -14,14 +14,22 @@

We replace the LoggingWorker with a no-op implementation to avoid this.
This is safe because we don't use litellm's built-in observability features.

3. VERTEX AI PER-MODEL REGION SUPPORT: litellm.drop_params=True (set by
DeepEval) silently strips vertex_project and vertex_location from
completion kwargs. The completion wrappers intercept these params and
temporarily set them as litellm module-level attributes, which litellm
checks as a fallback in its vertex_ai handler.
"""

import asyncio
import logging
import os
import threading
import warnings
from contextlib import asynccontextmanager, contextmanager
from functools import wraps
from typing import Any
from typing import Any, AsyncGenerator, Generator

import litellm

Expand Down Expand Up @@ -89,6 +97,85 @@ def clear_queue(self) -> None:
litellm.suppress_debug_info = True


# =============================================================================
# GLOBAL STATE LOCK
# =============================================================================
# Single lock for ALL litellm global state mutations (cache, ssl_verify,
# vertex_project, vertex_location). Import this lock in any module that
# reads/writes litellm global state to prevent race conditions between
# concurrent pipelines. Both sync and async code paths share this lock;
# async callers use asyncio.to_thread so the event loop is never blocked.
litellm_state_lock = threading.Lock()


# =============================================================================
# VERTEX AI PER-MODEL REGION SUPPORT
# =============================================================================
# litellm.drop_params=True (set by DeepEval) silently strips vertex_project
# and vertex_location from completion kwargs. We intercept these params and
# temporarily set them as litellm module-level attributes, which litellm
# checks as a fallback in its vertex_ai handler.


@contextmanager
def _vertex_override(kwargs: dict[str, Any]) -> Generator[None, None, None]:
"""Pop vertex_project/vertex_location from kwargs and set as litellm module attrs.

Always acquires litellm_state_lock to prevent concurrent reads of partially
updated globals, even when no vertex params are present in kwargs.
"""
with litellm_state_lock:
vp = kwargs.pop("vertex_project", None)
vl = kwargs.pop("vertex_location", None)
if vp is None and vl is None:
yield
return
old_vp = getattr(litellm, "vertex_project", None)
old_vl = getattr(litellm, "vertex_location", None)
try:
if vp is not None:
litellm.vertex_project = vp
if vl is not None:
litellm.vertex_location = vl
yield
finally:
litellm.vertex_project = old_vp
litellm.vertex_location = old_vl
Comment thread
coderabbitai[bot] marked this conversation as resolved.


@asynccontextmanager
async def _vertex_override_async(
kwargs: dict[str, Any],
) -> AsyncGenerator[None, None]:
"""Async version of _vertex_override using asyncio.to_thread.

Acquires litellm_state_lock before mutating globals and holds it across the
yield so no concurrent caller can see partially-updated state. Lock
acquire/release use asyncio.to_thread to avoid blocking the event loop.
Uses the same lock as the synchronous path to prevent races between sync
and async callers.
"""
vp = kwargs.pop("vertex_project", None)
vl = kwargs.pop("vertex_location", None)
if vp is None and vl is None:
yield
return
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Async no-op path skips lock, creating race condition with sync callers.

The sync _vertex_override holds litellm_state_lock for the entire duration including the no-op case (lines 127-132), but the async version checks vp is None and vl is None before acquiring the lock and yields without it. This allows an async no-op caller to proceed while a sync caller is modifying litellm.vertex_project/litellm.vertex_location, causing the async caller's underlying litellm call to observe partially-updated or temporary override values.

Move the lock acquisition before the None check to match the sync behavior:

🔧 Proposed fix
 async def _vertex_override_async(
     kwargs: dict[str, Any],
 ) -> AsyncGenerator[None, None]:
     ...
-    vp = kwargs.pop("vertex_project", None)
-    vl = kwargs.pop("vertex_location", None)
-    if vp is None and vl is None:
-        yield
-        return
-
     await asyncio.to_thread(litellm_state_lock.acquire)
-    old_vp = getattr(litellm, "vertex_project", None)
-    old_vl = getattr(litellm, "vertex_location", None)
     try:
+        vp = kwargs.pop("vertex_project", None)
+        vl = kwargs.pop("vertex_location", None)
+        if vp is None and vl is None:
+            yield
+            return
+        old_vp = getattr(litellm, "vertex_project", None)
+        old_vl = getattr(litellm, "vertex_location", None)
         if vp is not None:
             litellm.vertex_project = vp
         if vl is not None:
             litellm.vertex_location = vl
-        yield
-    finally:
-        litellm.vertex_project = old_vp
-        litellm.vertex_location = old_vl
+        try:
+            yield
+        finally:
+            litellm.vertex_project = old_vp
+            litellm.vertex_location = old_vl
+    finally:
         await asyncio.to_thread(litellm_state_lock.release)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/lightspeed_evaluation/core/llm/litellm_patch.py` around lines 158 - 162,
The async _vertex_override context manager currently checks vp/vl before
acquiring litellm_state_lock, causing a race with sync callers; change the async
implementation to acquire litellm_state_lock first (using the same async
lock/context as the sync path), then pop vp/vl and perform the None check while
holding the lock, and only yield after the lock-protected override setup so that
litellm.vertex_project and litellm.vertex_location are never observed in a
partially-updated state; ensure the lock is released correctly on exit just like
the sync _vertex_override.


await asyncio.to_thread(litellm_state_lock.acquire)
old_vp = getattr(litellm, "vertex_project", None)
old_vl = getattr(litellm, "vertex_location", None)
try:
if vp is not None:
litellm.vertex_project = vp
if vl is not None:
litellm.vertex_location = vl
yield
finally:
litellm.vertex_project = old_vp
litellm.vertex_location = old_vl
await asyncio.to_thread(litellm_state_lock.release)


# =============================================================================
# TOKEN TRACKING: Wrap completion and embedding functions
# =============================================================================
Expand All @@ -101,11 +188,11 @@ def clear_queue(self) -> None:
_original_aembedding = litellm.aembedding


# Patch litellm's completion functions to include token tracking
@wraps(_original_completion)
def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.completion that tracks tokens."""
response = _original_completion(*args, **kwargs)
"""Wrapper around litellm.completion that tracks tokens and handles Vertex params."""
with _vertex_override(kwargs):
response = _original_completion(*args, **kwargs)
try:
track_judge_tokens(response)
except Exception as e: # pylint: disable=broad-exception-caught
Expand All @@ -115,16 +202,16 @@ def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:

@wraps(_original_acompletion)
async def _acompletion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.acompletion that tracks tokens."""
response = await _original_acompletion(*args, **kwargs)
"""Wrapper around litellm.acompletion that tracks tokens and handles Vertex params."""
async with _vertex_override_async(kwargs):
response = await _original_acompletion(*args, **kwargs)
try:
track_judge_tokens(response)
except Exception as e: # pylint: disable=broad-exception-caught
logger.exception("Failed to track tokens for acompletion: %s", e)
return response


# Patch litellm's embedding functions to include token tracking
@wraps(_original_embedding)
def _embedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.embedding that tracks tokens."""
Expand All @@ -147,22 +234,12 @@ async def _aembedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
return response


# Patch litellm's completion and embedding functions to include token tracking
litellm.completion = _completion_with_token_tracking
litellm.acompletion = _acompletion_with_token_tracking
litellm.embedding = _embedding_with_token_tracking
litellm.aembedding = _aembedding_with_token_tracking


# =============================================================================
# GLOBAL STATE LOCK
# =============================================================================
# Single lock for ALL litellm global state mutations (cache, ssl_verify).
# Import this lock in any module that reads/writes litellm.cache or
# litellm.ssl_verify to prevent race conditions between concurrent pipelines.
litellm_state_lock = threading.Lock()


# =============================================================================
# SSL CONFIGURATION UTILITY
# =============================================================================
Expand Down
Loading
Loading