diff --git a/tests/test_judges_view.py b/tests/test_judges_view.py new file mode 100644 index 000000000..1a712da55 --- /dev/null +++ b/tests/test_judges_view.py @@ -0,0 +1,128 @@ +"""Tests for the Judges View capture path. + +JudgeRubric should record each judge call into state["judges"], and +state_to_output should propagate that list into the RolloutOutput so it +reaches the platform without an opt-in state_columns declaration. +""" + +import json + +import pytest + +from verifiers.rubrics.judge_rubric import JudgeRubric +from verifiers.utils.save_utils import ( + make_serializable, + states_to_outputs, +) + + +class _FakeMessage: + def __init__(self, content: str) -> None: + self.content = content + + +class _FakeChoice: + def __init__(self, content: str) -> None: + self.message = _FakeMessage(content) + + +class _FakeResponse: + def __init__(self, content: str) -> None: + self.choices = [_FakeChoice(content)] + + +class _FakeChatCompletions: + def __init__(self, response_text: str) -> None: + self._response_text = response_text + self.calls: list[dict] = [] + + async def create(self, **kwargs): + self.calls.append(kwargs) + return _FakeResponse(self._response_text) + + +class _FakeChat: + def __init__(self, completions: _FakeChatCompletions) -> None: + self.completions = completions + + +class _FakeJudgeClient: + def __init__(self, response_text: str = "yes") -> None: + self.chat = _FakeChat(_FakeChatCompletions(response_text)) + + +@pytest.mark.asyncio +async def test_judge_records_appended_to_state(make_state): + client = _FakeJudgeClient(response_text="yes") + rubric = JudgeRubric(judge_client=client, judge_model="fake-judge") + state = make_state( + prompt=[{"role": "user", "content": "What is 2+2?"}], + completion=[{"role": "assistant", "content": "4"}], + answer="4", + ) + + out = await rubric.judge( + prompt=state["prompt"], + completion=state["completion"], + answer=state["answer"], + state=state, + ) + assert out == "yes" + + judges = 
state.get("judges") + assert isinstance(judges, list) and len(judges) == 1 + record = judges[0] + assert record["judge_output"] == "yes" + assert record["model"] == "fake-judge" + assert record["rubric"] == "JudgeRubric" + assert isinstance(record["judge_input"], list) + assert record["judge_input"][0]["role"] == "user" + assert "What is 2+2?" in record["judge_input"][0]["content"] + + +@pytest.mark.asyncio +async def test_judge_records_distinguish_named_rubrics(make_state): + correctness = JudgeRubric( + judge_client=_FakeJudgeClient("yes"), + judge_model="judge-a", + name="correctness_judge", + ) + style = JudgeRubric( + judge_client=_FakeJudgeClient("no"), + judge_model="judge-b", + name="style_judge", + ) + state = make_state( + prompt=[{"role": "user", "content": "Q"}], + completion=[{"role": "assistant", "content": "A"}], + answer="A", + ) + + await correctness.judge(state["prompt"], state["completion"], state["answer"], state) + await style.judge(state["prompt"], state["completion"], state["answer"], state) + + judges = state["judges"] + assert [r["rubric"] for r in judges] == ["correctness_judge", "style_judge"] + assert [r["model"] for r in judges] == ["judge-a", "judge-b"] + + +@pytest.mark.asyncio +async def test_state_to_output_propagates_judges(make_state): + rubric = JudgeRubric(judge_client=_FakeJudgeClient("yes"), judge_model="fake-judge") + state = make_state( + prompt=[{"role": "user", "content": "Q"}], + completion=[{"role": "assistant", "content": "A"}], + answer="A", + ) + await rubric.judge(state["prompt"], state["completion"], state["answer"], state) + + output = states_to_outputs([state], state_columns=[])[0] + assert "judges" in output + serialized = json.loads(json.dumps(output, default=make_serializable)) + assert serialized["judges"][0]["judge_output"] == "yes" + + +def test_state_to_output_omits_judges_when_absent(make_state): + state = make_state() + output = states_to_outputs([state], state_columns=[])[0] + assert "judges" not in 
output diff --git a/verifiers/rubrics/judge_rubric.py b/verifiers/rubrics/judge_rubric.py index 9a8b73706..55911cca8 100644 --- a/verifiers/rubrics/judge_rubric.py +++ b/verifiers/rubrics/judge_rubric.py @@ -1,10 +1,11 @@ +import time from typing import Any from openai import APIError, APITimeoutError, AsyncOpenAI, RateLimitError from verifiers.parsers.parser import Parser from verifiers.rubrics.rubric import Rubric -from verifiers.types import Messages, State +from verifiers.types import JudgeRecord, Messages, State from verifiers.utils.async_utils import maybe_await DEFAULT_JUDGE_PROMPT = """Given a ground truth answer \ @@ -37,12 +38,14 @@ def __init__( judge_model: str = "gpt-4.1-nano", judge_sampling_args: dict[str, Any] | None = None, judge_prompt: str = DEFAULT_JUDGE_PROMPT, + name: str | None = None, ): super().__init__(parser=parser) self.judge_client = judge_client if judge_client is not None else AsyncOpenAI() self.judge_model = judge_model self.judge_prompt = judge_prompt self.judge_sampling_args = judge_sampling_args or {} + self.name = name or self.__class__.__name__ self.class_objects = { "parser": self.parser, "judge": self.judge, @@ -73,7 +76,9 @@ async def judge( ) cached = state.get("judge_response") if state else None if isinstance(cached, dict) and judge_prompt in cached: - return cached[judge_prompt] + cached_response = cached[judge_prompt] + self._record_judge_call(state, judge_prompt, cached_response) + return cached_response # Normalize judge sampling args for chat API judge_args = dict(self.judge_sampling_args or {}) if "max_tokens" in judge_args: @@ -138,4 +143,26 @@ async def judge( cached = {} cached[judge_prompt] = judge_response state["judge_response"] = cached + self._record_judge_call(state, judge_prompt, judge_response) return judge_response + + def _record_judge_call( + self, state: State, judge_prompt: str, judge_response: str + ) -> None: + """Append a JudgeRecord to ``state["judges"]`` so the platform can render + it in the 
rollout view. Recorded on every call (including cache hits) so + that two rubrics sharing a prompt are still distinguishable downstream. + """ + judges = state.get("judges") + if not isinstance(judges, list): + judges = [] + record: JudgeRecord = { + "judge_input": [{"role": "user", "content": judge_prompt}], + "judge_output": judge_response, + "rubric": self.name, + "model": self.judge_model, + "score": None, + "timestamp": time.time(), + } + judges.append(record) + state["judges"] = judges diff --git a/verifiers/types.py b/verifiers/types.py index 70fbc2b53..24ab2ac46 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -324,6 +324,23 @@ class ErrorInfo(TypedDict): error_chain_str: str +class JudgeRecord(TypedDict, total=False): + """One LLM-as-judge call captured during rollout scoring. + + Recorded by ``JudgeRubric.judge`` so the platform can render the exact + input/output of each judge invocation in the rollout view. All fields are + optional (``total=False``); producers always populate ``judge_input`` and + ``judge_output``, while older callers may omit the metadata fields. + """ + + judge_input: list[dict[str, Any]] | str + judge_output: str + rubric: str + model: str + score: float | None + timestamp: float + + class RolloutOutput(dict): """Serialized output from a rollout (mirrors RolloutInput). 
@@ -334,7 +351,7 @@ class RolloutOutput(dict): Required fields: example_id, task, prompt, completion, reward, timing, is_completed, is_truncated, metrics Optional fields: answer, info, error, stop_condition, trajectory, tool_defs, - token_usage + token_usage, judges Additional fields: arbitrary serializable state_columns """ @@ -356,6 +373,7 @@ class RolloutOutput(dict): trajectory: list["TrajectoryStep"] tool_defs: list[Tool] token_usage: TokenUsage + judges: list[JudgeRecord] class State(dict): diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index c3afa1234..0e9b8871f 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -249,6 +249,11 @@ def state_to_output( state_metrics = state.get("metrics") or {} for k, v in state_metrics.items(): output[k] = v + # propagate judge records (LLM-as-judge calls captured by JudgeRubric). + # Auto-emitted so env authors don't need to register a state column. + judges = state.get("judges") + if judges: + output["judges"] = judges # add state columns (must be serializable) for col in state_columns or []: value = state.get(col)