Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions tests/test_judges_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Tests for the Judges View capture path.

JudgeRubric should record each judge call into state["judges"], and
state_to_output should propagate that list into the RolloutOutput so it
reaches the platform without an opt-in state_columns declaration.
"""

import json

import pytest

from verifiers.rubrics.judge_rubric import JudgeRubric
from verifiers.utils.save_utils import (
make_serializable,
states_to_outputs,
)


class _FakeMessage:
    """Bare-bones stand-in for an OpenAI chat message object."""

    def __init__(self, content: str) -> None:
        # The code under test only ever reads ``.content``.
        self.content = content


class _FakeChoice:
    """Wraps a message the way an OpenAI chat-completion choice does."""

    def __init__(self, content: str) -> None:
        # Mirrors ``choice.message.content`` access on the real SDK objects.
        self.message = _FakeMessage(content)


class _FakeResponse:
    """Mimics a chat-completion response carrying exactly one choice."""

    def __init__(self, content: str) -> None:
        # The judge path reads ``response.choices[0].message.content``.
        self.choices = [_FakeChoice(content)]


class _FakeChatCompletions:
    """Records every ``create`` call and always answers with canned text."""

    def __init__(self, response_text: str) -> None:
        self._response_text = response_text
        # One kwargs dict per create() invocation, preserved in call order
        # so tests can inspect exactly what the judge sent.
        self.calls: list[dict] = []

    async def create(self, **kwargs):
        """Async stand-in for ``chat.completions.create``."""
        self.calls.append(kwargs)
        return _FakeResponse(self._response_text)


class _FakeChat:
    """Namespace object mirroring the ``client.chat`` attribute of the SDK."""

    def __init__(self, completions: _FakeChatCompletions) -> None:
        self.completions = completions


class _FakeJudgeClient:
    """Duck-typed AsyncOpenAI replacement that returns a fixed verdict."""

    def __init__(self, response_text: str = "yes") -> None:
        # Only the ``chat.completions.create`` path is exercised by JudgeRubric.
        self.chat = _FakeChat(_FakeChatCompletions(response_text))


@pytest.mark.asyncio
async def test_judge_records_appended_to_state(make_state):
    """A single judge() call appends exactly one record to state["judges"]."""
    judge_client = _FakeJudgeClient(response_text="yes")
    rubric = JudgeRubric(judge_client=judge_client, judge_model="fake-judge")
    state = make_state(
        prompt=[{"role": "user", "content": "What is 2+2?"}],
        completion=[{"role": "assistant", "content": "4"}],
        answer="4",
    )

    verdict = await rubric.judge(
        prompt=state["prompt"],
        completion=state["completion"],
        answer=state["answer"],
        state=state,
    )
    assert verdict == "yes"

    recorded = state.get("judges")
    assert isinstance(recorded, list)
    assert len(recorded) == 1

    entry = recorded[0]
    assert entry["judge_output"] == "yes"
    assert entry["model"] == "fake-judge"
    assert entry["rubric"] == "JudgeRubric"
    assert isinstance(entry["judge_input"], list)
    first_message = entry["judge_input"][0]
    assert first_message["role"] == "user"
    assert "What is 2+2?" in first_message["content"]


@pytest.mark.asyncio
async def test_judge_records_distinguish_named_rubrics(make_state):
    """Two named rubrics writing to one state stay distinguishable downstream."""
    correctness = JudgeRubric(
        judge_client=_FakeJudgeClient("yes"),
        judge_model="judge-a",
        name="correctness_judge",
    )
    style = JudgeRubric(
        judge_client=_FakeJudgeClient("no"),
        judge_model="judge-b",
        name="style_judge",
    )
    state = make_state(
        prompt=[{"role": "user", "content": "Q"}],
        completion=[{"role": "assistant", "content": "A"}],
        answer="A",
    )

    # Run both judges against the same state, in a fixed order.
    for rubric in (correctness, style):
        await rubric.judge(state["prompt"], state["completion"], state["answer"], state)

    records = state["judges"]
    assert [entry["rubric"] for entry in records] == ["correctness_judge", "style_judge"]
    assert [entry["model"] for entry in records] == ["judge-a", "judge-b"]


@pytest.mark.asyncio
async def test_state_to_output_propagates_judges(make_state):
    """states_to_outputs surfaces state["judges"] without opt-in state_columns."""
    rubric = JudgeRubric(judge_client=_FakeJudgeClient("yes"), judge_model="fake-judge")
    state = make_state(
        prompt=[{"role": "user", "content": "Q"}],
        completion=[{"role": "assistant", "content": "A"}],
        answer="A",
    )
    await rubric.judge(state["prompt"], state["completion"], state["answer"], state)

    (output,) = states_to_outputs([state], state_columns=[])
    assert "judges" in output
    # Round-trip through JSON to prove the records serialize cleanly.
    round_tripped = json.loads(json.dumps(output, default=make_serializable))
    assert round_tripped["judges"][0]["judge_output"] == "yes"


def test_state_to_output_omits_judges_when_absent(make_state):
    """When no judge ran, the serialized output carries no "judges" key."""
    bare_state = make_state()
    (output,) = states_to_outputs([bare_state], state_columns=[])
    assert "judges" not in output
31 changes: 29 additions & 2 deletions verifiers/rubrics/judge_rubric.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import time
from typing import Any

from openai import APIError, APITimeoutError, AsyncOpenAI, RateLimitError

from verifiers.parsers.parser import Parser
from verifiers.rubrics.rubric import Rubric
from verifiers.types import Messages, State
from verifiers.types import JudgeRecord, Messages, State
from verifiers.utils.async_utils import maybe_await

DEFAULT_JUDGE_PROMPT = """Given a ground truth answer \
Expand Down Expand Up @@ -37,12 +38,14 @@ def __init__(
judge_model: str = "gpt-4.1-nano",
judge_sampling_args: dict[str, Any] | None = None,
judge_prompt: str = DEFAULT_JUDGE_PROMPT,
name: str | None = None,
):
super().__init__(parser=parser)
self.judge_client = judge_client if judge_client is not None else AsyncOpenAI()
self.judge_model = judge_model
self.judge_prompt = judge_prompt
self.judge_sampling_args = judge_sampling_args or {}
self.name = name or self.__class__.__name__
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing documentation update for new user-facing features

Low Severity

This PR adds the name= constructor parameter to JudgeRubric and a new judges field on RolloutOutput, both user-facing. The existing docs in docs/environments.md (which shows JudgeRubric usage) and docs/reference.md (which lists RolloutOutput fields and mentions JudgeRubric) are not updated to reflect these additions, violating the documentation update rule.

Additional Locations (1)
Fix in Cursor Fix in Web

Triggered by project rule: BugBot Instructions

Reviewed by Cursor Bugbot for commit 8b58330. Configure here.

self.class_objects = {
"parser": self.parser,
"judge": self.judge,
Expand Down Expand Up @@ -73,7 +76,9 @@ async def judge(
)
cached = state.get("judge_response") if state else None
if isinstance(cached, dict) and judge_prompt in cached:
return cached[judge_prompt]
cached_response = cached[judge_prompt]
self._record_judge_call(state, judge_prompt, cached_response)
return cached_response
# Normalize judge sampling args for chat API
judge_args = dict(self.judge_sampling_args or {})
if "max_tokens" in judge_args:
Expand Down Expand Up @@ -138,4 +143,26 @@ async def judge(
cached = {}
cached[judge_prompt] = judge_response
state["judge_response"] = cached
self._record_judge_call(state, judge_prompt, judge_response)
return judge_response

def _record_judge_call(
    self, state: State, judge_prompt: str, judge_response: str
) -> None:
    """Append a JudgeRecord to ``state["judges"]`` so the platform can render
    it in the rollout view. Recorded on every call (including cache hits) so
    that two rubrics sharing a prompt are still distinguishable downstream.
    """
    existing = state.get("judges")
    # Start fresh when the key is missing or holds something that is not a
    # list (e.g. a corrupted or foreign value) — never crash on append.
    records = existing if isinstance(existing, list) else []
    entry: JudgeRecord = {
        "judge_input": [{"role": "user", "content": judge_prompt}],
        "judge_output": judge_response,
        "rubric": self.name,
        "model": self.judge_model,
        "score": None,
        "timestamp": time.time(),
    }
    records.append(entry)
    # Re-assign so the branch that created a new list also lands in state.
    state["judges"] = records
20 changes: 19 additions & 1 deletion verifiers/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,23 @@ class ErrorInfo(TypedDict):
error_chain_str: str


class JudgeRecord(TypedDict, total=False):
"""One LLM-as-judge call captured during rollout scoring.

Recorded by ``JudgeRubric.judge`` so the platform can render the exact
input/output of each judge invocation in the rollout view. Fields are all
optional except ``judge_input`` and ``judge_output`` so older callers can
omit metadata without breaking serialization.
"""

judge_input: list[dict[str, Any]] | str
judge_output: str
rubric: str
model: str
score: float | None
timestamp: float
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JudgeRecord total=False makes all fields optional unexpectedly

Low Severity

JudgeRecord uses total=False which makes every field optional at the type level, but the docstring explicitly states that judge_input and judge_output are required. This mismatch means type checkers won't enforce the presence of these two critical fields, silently allowing incomplete records. Using Required[] from typing on those two fields (or splitting into a base TypedDict with total=True) would correctly express the intended contract.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 8b58330. Configure here.



class RolloutOutput(dict):
"""Serialized output from a rollout (mirrors RolloutInput).

Expand All @@ -334,7 +351,7 @@ class RolloutOutput(dict):
Required fields: example_id, task, prompt, completion, reward, timing,
is_completed, is_truncated, metrics
Optional fields: answer, info, error, stop_condition, trajectory, tool_defs,
token_usage
token_usage, judges
Additional fields: arbitrary serializable state_columns
"""

Expand All @@ -356,6 +373,7 @@ class RolloutOutput(dict):
trajectory: list["TrajectoryStep"]
tool_defs: list[Tool]
token_usage: TokenUsage
judges: list[JudgeRecord]


class State(dict):
Expand Down
5 changes: 5 additions & 0 deletions verifiers/utils/save_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@ def state_to_output(
state_metrics = state.get("metrics") or {}
for k, v in state_metrics.items():
output[k] = v
# propagate judge records (LLM-as-judge calls captured by JudgeRubric).
# Auto-emitted so env authors don't need to register a state column.
judges = state.get("judges")
if judges:
output["judges"] = judges
# add state columns (must be serializable)
for col in state_columns or []:
value = state.get(col)
Expand Down
Loading