diff --git a/tests/test_environment_extra.py b/tests/test_environment_extra.py index 196ad8737..3ad859e08 100644 --- a/tests/test_environment_extra.py +++ b/tests/test_environment_extra.py @@ -331,6 +331,7 @@ async def run_group( sampling_args, max_retries, state_columns, + **kwargs, ): assert isinstance(client_config, ClientConfig) self.client_urls_per_group.append(str(client_config.api_base_url)) @@ -426,6 +427,7 @@ async def run_group( sampling_args, max_retries, state_columns, + **kwargs, ): assert isinstance(client_config, ClientConfig) self.client_url = str(client_config.api_base_url) @@ -485,6 +487,7 @@ async def run_rollout( sampling_args, max_retries, state_columns, + **kwargs, ): assert isinstance(client_config, ClientConfig) self.client_url = str(client_config.api_base_url) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index f2d0a9636..a93843724 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -673,6 +673,7 @@ async def run_rollout( max_retries: int = 0, state_columns: list[str] | None = None, env_client: EnvClient | None = None, + training_context: dict | None = None, ) -> RolloutOutput: """Generate and, optionally, score a rollout.""" @@ -693,10 +694,13 @@ async def run_rollout( sampling_args, max_retries, state_columns, + training_context=training_context, ) resolved_client = resolve_client(client) + self.rubric.training_context = training_context + async def run_rollout_attempt() -> State: state = await self.rollout( input, @@ -730,6 +734,7 @@ async def run_group( max_retries: int = 0, state_columns: list[str] | None = None, env_client: EnvClient | None = None, + training_context: dict | None = None, **kwargs, ) -> list[RolloutOutput]: """Generate and, optionally, score one group.""" @@ -751,10 +756,13 @@ async def run_group( sampling_args, max_retries, state_columns, + training_context=training_context, ) resolved_client = resolve_client(client) + self.rubric.training_context = training_context + async def run_group_attempt() -> list[State]: rollout_tasks = [ self.rollout( diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 273f95765..eda92bf9b 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -48,6 +48,10 @@ def __init__( self.parser = parser or vf.Parser() + # Training context set by the orchestrator before scoring. + # Contains metadata like {"step": int, "ckpt_step": int}. + self.training_context: dict | None = None + # class objects for reward functions self.class_objects = {} if self.parser: diff --git a/verifiers/rubrics/rubric_group.py b/verifiers/rubrics/rubric_group.py index 8de1d7f77..39127ea35 100644 --- a/verifiers/rubrics/rubric_group.py +++ b/verifiers/rubrics/rubric_group.py @@ -19,6 +19,16 @@ def __init__(self, rubrics: list[Rubric], **kwargs): self.rubrics = rubrics self.logger.debug(f"Initialized RubricGroup with {len(rubrics)} rubrics") + @property # type: ignore[override] + def training_context(self) -> dict | None: + return self._training_context + + @training_context.setter + def training_context(self, value: dict | None) -> None: + self._training_context = value + for rubric in getattr(self, "rubrics", []): + rubric.training_context = value + def _get_reward_func_names(self) -> list[str]: names = [] for rubric in self.rubrics: diff --git a/verifiers/serve/client/env_client.py b/verifiers/serve/client/env_client.py index 8649fb246..c41b46116 100644 --- a/verifiers/serve/client/env_client.py +++ b/verifiers/serve/client/env_client.py @@ -50,6 +50,7 @@ async def run_rollout( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + training_context: dict | None = None, ) -> RolloutOutput: resolved_client_config = resolve_client_config(client_config) request = RunRolloutRequest( @@ -59,6 +60,7 @@ async def run_rollout( sampling_args=sampling_args, max_retries=max_retries, state_columns=state_columns, + training_context=training_context, ) response = await self.handle_run_rollout_request(request, timeout=None) assert response.output is not None @@ -72,6 +74,7 @@ async def run_group( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + training_context: dict | None = None, ) -> list[RolloutOutput]: resolved_client_config = resolve_client_config(client_config) request = RunGroupRequest( @@ -81,6 +84,7 @@ async def run_group( sampling_args=sampling_args, max_retries=max_retries, state_columns=state_columns, + training_context=training_context, ) response = await self.handle_run_group_request(request, timeout=None) assert response.outputs is not None diff --git a/verifiers/serve/server/env_worker.py b/verifiers/serve/server/env_worker.py index 70f77cf30..344325413 100644 --- a/verifiers/serve/server/env_worker.py +++ b/verifiers/serve/server/env_worker.py @@ -138,6 +138,7 @@ async def resolve_client(self, client_config: ClientConfig) -> Client: async def handle_run_rollout( self, request: RunRolloutRequest ) -> RunRolloutResponse: + self.env.rubric.training_context = request.training_context client = await self.resolve_client(request.client_config) output = await self.env.run_rollout( input=request.input, @@ -150,6 +151,7 @@ async def handle_run_rollout( return RunRolloutResponse(output=output) async def handle_run_group(self, request: RunGroupRequest) -> RunGroupResponse: + self.env.rubric.training_context = request.training_context client = await self.resolve_client(request.client_config) outputs = await self.env.run_group( group_inputs=request.group_inputs, diff --git a/verifiers/serve/types.py b/verifiers/serve/types.py index 834ce1f25..7932b9323 100644 --- a/verifiers/serve/types.py +++ b/verifiers/serve/types.py @@ -51,6 +51,7 @@ class RunRolloutRequest(BaseRequest): sampling_args: SamplingArgs max_retries: int state_columns: list[str] | None + training_context: dict | None = None class RunRolloutResponse(BaseResponse): @@ -68,6 +69,7 @@ class RunGroupRequest(BaseRequest): sampling_args: SamplingArgs max_retries: int state_columns: list[str] | None + training_context: dict | None = None class RunGroupResponse(BaseResponse):