Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions config/system.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ llm_pool:
# Add, remove or override model specific parameters
temperature: null # Removes temperature from default
max_completion_tokens: 2048 # Overrides default
# Vertex AI example with per-model region:
# judge_vertex_gemini:
# provider: vertex_ai
# model: gemini-2.0-flash
# parameters:
# vertex_location: us-central1 # Region for this model
# # vertex_project: my-gcp-project # Optional: override GCP project
# judge_vertex_llama:
# provider: vertex_ai
# model: meta/llama-3.3-70b-instruct-maas
# parameters:
# vertex_location: europe-west1 # Different region for this model

# Judge Panel: multiple judges from the pool
# Combine their scores. First judge in judges is the fallback when the full panel is not used for a metric.
Expand Down
79 changes: 60 additions & 19 deletions src/lightspeed_evaluation/core/llm/litellm_patch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""LiteLLM configuration for token tracking and Ragas 0.4 compatibility.
"""LiteLLM configuration for token tracking, Ragas 0.4 compatibility, and Vertex AI support.

This module configures litellm for two purposes:
This module configures litellm for three purposes:

1. TOKEN TRACKING: Wraps litellm.completion, litellm.acompletion, litellm.embedding,
and litellm.aembedding to track token usage for all LLM and embedding calls.
Expand All @@ -14,14 +14,21 @@

We replace the LoggingWorker with a no-op implementation to avoid this.
This is safe because we don't use litellm's built-in observability features.

3. VERTEX AI PER-MODEL REGION SUPPORT: litellm.drop_params=True (set by
DeepEval) silently strips vertex_project and vertex_location from
completion kwargs. The completion wrappers intercept these params and
temporarily set them as litellm module-level attributes, which litellm
checks as a fallback in its vertex_ai handler.
"""

import logging
import os
import threading
import warnings
from contextlib import contextmanager
from functools import wraps
from typing import Any
from typing import Any, Generator

import litellm

Expand Down Expand Up @@ -89,6 +96,50 @@ def clear_queue(self) -> None:
litellm.suppress_debug_info = True


# =============================================================================
# GLOBAL STATE LOCK
# =============================================================================
# Single lock for ALL litellm global state mutations (cache, ssl_verify,
# vertex_project, vertex_location). Import this lock in any module that
# reads/writes litellm global state to prevent race conditions between
# concurrent pipelines.
litellm_state_lock = threading.Lock()


# =============================================================================
# VERTEX AI PER-MODEL REGION SUPPORT
# =============================================================================
# litellm.drop_params=True (set by DeepEval) silently strips vertex_project
# and vertex_location from completion kwargs. We intercept these params and
# temporarily set them as litellm module-level attributes, which litellm
# checks as a fallback in its vertex_ai handler.


@contextmanager
def _vertex_override(kwargs: dict[str, Any]) -> Generator[None, None, None]:
"""Pop vertex_project/vertex_location from kwargs and set as litellm module attrs.

When neither key is present the context manager is a no-op (no lock acquired).
"""
vp = kwargs.pop("vertex_project", None)
vl = kwargs.pop("vertex_location", None)
if vp is None and vl is None:
yield
return
with litellm_state_lock:
old_vp = getattr(litellm, "vertex_project", None)
old_vl = getattr(litellm, "vertex_location", None)
try:
if vp is not None:
litellm.vertex_project = vp
if vl is not None:
litellm.vertex_location = vl
yield
finally:
litellm.vertex_project = old_vp
litellm.vertex_location = old_vl
Comment thread
coderabbitai[bot] marked this conversation as resolved.


# =============================================================================
# TOKEN TRACKING: Wrap completion and embedding functions
# =============================================================================
Expand All @@ -101,11 +152,11 @@ def clear_queue(self) -> None:
_original_aembedding = litellm.aembedding


# Patch litellm's completion functions to include token tracking
@wraps(_original_completion)
def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.completion that tracks tokens."""
response = _original_completion(*args, **kwargs)
"""Wrapper around litellm.completion that tracks tokens and handles Vertex params."""
with _vertex_override(kwargs):
response = _original_completion(*args, **kwargs)
try:
track_judge_tokens(response)
except Exception as e: # pylint: disable=broad-exception-caught
Expand All @@ -115,16 +166,16 @@ def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:

@wraps(_original_acompletion)
async def _acompletion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.acompletion that tracks tokens."""
response = await _original_acompletion(*args, **kwargs)
"""Wrapper around litellm.acompletion that tracks tokens and handles Vertex params."""
with _vertex_override(kwargs):
response = await _original_acompletion(*args, **kwargs)
try:
track_judge_tokens(response)
except Exception as e: # pylint: disable=broad-exception-caught
logger.exception("Failed to track tokens for acompletion: %s", e)
return response


# Patch litellm's embedding functions to include token tracking
@wraps(_original_embedding)
def _embedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.embedding that tracks tokens."""
Expand All @@ -147,22 +198,12 @@ async def _aembedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
return response


# Patch litellm's completion and embedding functions to include token tracking
litellm.completion = _completion_with_token_tracking
litellm.acompletion = _acompletion_with_token_tracking
litellm.embedding = _embedding_with_token_tracking
litellm.aembedding = _aembedding_with_token_tracking


# =============================================================================
# GLOBAL STATE LOCK
# =============================================================================
# Single lock for ALL litellm global state mutations (cache, ssl_verify).
# Import this lock in any module that reads/writes litellm.cache or
# litellm.ssl_verify to prevent race conditions between concurrent pipelines.
litellm_state_lock = threading.Lock()


# =============================================================================
# SSL CONFIGURATION UTILITY
# =============================================================================
Expand Down
199 changes: 199 additions & 0 deletions tests/unit/core/llm/test_litellm_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""Unit tests for litellm_patch vertex override support."""

from typing import Any, Callable

import pytest
from pytest_mock import MockerFixture
import litellm

from lightspeed_evaluation.core.llm import litellm_patch
from lightspeed_evaluation.core.llm.litellm_patch import _vertex_override


class TestVertexOverrideContextManager:
"""Tests for the _vertex_override context manager."""

def test_no_vertex_params_is_noop(self) -> None:
"""Test that _vertex_override is a no-op when no vertex params present."""
kwargs: dict[str, Any] = {"model": "gpt-4", "temperature": 0.5}
original_kwargs = dict(kwargs)

with _vertex_override(kwargs):
pass

assert kwargs == original_kwargs

def test_vertex_location_set_and_restored(self) -> None:
"""Test that vertex_location is set on litellm module and restored after."""
old_value = getattr(litellm, "vertex_location", None)
kwargs: dict[str, Any] = {"vertex_location": "us-central1"}

with _vertex_override(kwargs):
assert litellm.vertex_location == "us-central1"
assert "vertex_location" not in kwargs

assert getattr(litellm, "vertex_location", None) == old_value

def test_vertex_project_set_and_restored(self) -> None:
"""Test that vertex_project is set on litellm module and restored after."""
old_value = getattr(litellm, "vertex_project", None)
kwargs: dict[str, Any] = {"vertex_project": "my-project"}

with _vertex_override(kwargs):
assert litellm.vertex_project == "my-project"
assert "vertex_project" not in kwargs

assert getattr(litellm, "vertex_project", None) == old_value

def test_both_params_set_and_restored(self) -> None:
"""Test that both vertex params are set and restored."""
old_location = getattr(litellm, "vertex_location", None)
old_project = getattr(litellm, "vertex_project", None)
kwargs: dict[str, Any] = {
"vertex_location": "europe-west1",
"vertex_project": "my-project",
"temperature": 0.5,
}

with _vertex_override(kwargs):
assert litellm.vertex_location == "europe-west1"
assert litellm.vertex_project == "my-project"
assert "vertex_location" not in kwargs
assert "vertex_project" not in kwargs
assert kwargs == {"temperature": 0.5}

assert getattr(litellm, "vertex_location", None) == old_location
assert getattr(litellm, "vertex_project", None) == old_project

def test_params_restored_on_exception(self) -> None:
"""Test that vertex params are restored even when an exception occurs."""
old_location = getattr(litellm, "vertex_location", None)
kwargs: dict[str, Any] = {"vertex_location": "us-east1"}

with pytest.raises(ValueError, match="test error"):
with _vertex_override(kwargs):
assert litellm.vertex_location == "us-east1"
raise ValueError("test error")

assert getattr(litellm, "vertex_location", None) == old_location

def test_no_lock_acquired_without_vertex_params(
self, mocker: MockerFixture
) -> None:
"""Test that the lock is not acquired when no vertex params are present."""
mock_lock = mocker.patch.object(litellm_patch, "litellm_state_lock")
kwargs: dict[str, Any] = {"temperature": 0.5}

with _vertex_override(kwargs):
pass

mock_lock.__enter__.assert_not_called()

def test_lock_acquired_with_vertex_params(self, mocker: MockerFixture) -> None:
"""Test that the lock is acquired when vertex params are present."""
mock_lock = mocker.MagicMock()
mocker.patch.object(litellm_patch, "litellm_state_lock", mock_lock)
kwargs: dict[str, Any] = {"vertex_location": "us-central1"}

with _vertex_override(kwargs):
pass

mock_lock.__enter__.assert_called_once()


class TestCompletionWithVertexOverride:
"""Test litellm.completion integration with vertex override."""

def test_completion_with_vertex_location(
self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any]
) -> None:
"""Test that vertex_location is handled during completion calls."""
mock_completion = mocker.patch(f"{litellm_patch.__name__}._original_completion")
mock_completion.return_value = mock_judge_llm_response(
prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok"
)

old_location = getattr(litellm, "vertex_location", None)

litellm.completion(
model="vertex_ai/gemini-pro",
messages=[{"role": "user", "content": "test"}],
vertex_location="us-central1",
)

mock_completion.assert_called_once()
call_kwargs = mock_completion.call_args[1]
assert "vertex_location" not in call_kwargs
assert getattr(litellm, "vertex_location", None) == old_location

def test_completion_without_vertex_params_unchanged(
self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any]
) -> None:
"""Test that completion works normally without vertex params."""
mock_completion = mocker.patch(f"{litellm_patch.__name__}._original_completion")
mock_completion.return_value = mock_judge_llm_response(
prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok"
)

litellm.completion(
model="gpt-4",
messages=[{"role": "user", "content": "test"}],
temperature=0.5,
)

call_kwargs = mock_completion.call_args[1]
assert call_kwargs["temperature"] == 0.5
assert call_kwargs["model"] == "gpt-4"

@pytest.mark.asyncio
async def test_acompletion_with_vertex_location(
self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any]
) -> None:
"""Test that vertex_location is handled during async completion calls."""
mock_acompletion = mocker.patch(
f"{litellm_patch.__name__}._original_acompletion"
)
mock_acompletion.return_value = mock_judge_llm_response(
prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok"
)

old_location = getattr(litellm, "vertex_location", None)

await litellm.acompletion(
model="vertex_ai/gemini-pro",
messages=[{"role": "user", "content": "test"}],
vertex_location="europe-west1",
)

mock_acompletion.assert_called_once()
call_kwargs = mock_acompletion.call_args[1]
assert "vertex_location" not in call_kwargs
assert getattr(litellm, "vertex_location", None) == old_location

@pytest.mark.asyncio
async def test_acompletion_with_both_vertex_params(
self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any]
) -> None:
"""Test that both vertex params are handled during async completion."""
mock_acompletion = mocker.patch(
f"{litellm_patch.__name__}._original_acompletion"
)
mock_acompletion.return_value = mock_judge_llm_response(
prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok"
)

old_location = getattr(litellm, "vertex_location", None)
old_project = getattr(litellm, "vertex_project", None)

await litellm.acompletion(
model="vertex_ai/gemini-pro",
messages=[{"role": "user", "content": "test"}],
vertex_location="us-central1",
vertex_project="my-project",
)

call_kwargs = mock_acompletion.call_args[1]
assert "vertex_location" not in call_kwargs
assert "vertex_project" not in call_kwargs
assert getattr(litellm, "vertex_location", None) == old_location
assert getattr(litellm, "vertex_project", None) == old_project
Loading