Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/test_environment_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ async def run_group(
sampling_args,
max_retries,
state_columns,
**kwargs,
):
assert isinstance(client_config, ClientConfig)
self.client_urls_per_group.append(str(client_config.api_base_url))
Expand Down Expand Up @@ -426,6 +427,7 @@ async def run_group(
sampling_args,
max_retries,
state_columns,
**kwargs,
):
assert isinstance(client_config, ClientConfig)
self.client_url = str(client_config.api_base_url)
Expand Down Expand Up @@ -485,6 +487,7 @@ async def run_rollout(
sampling_args,
max_retries,
state_columns,
**kwargs,
):
assert isinstance(client_config, ClientConfig)
self.client_url = str(client_config.api_base_url)
Expand Down
8 changes: 8 additions & 0 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ async def run_rollout(
max_retries: int = 0,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
training_context: dict | None = None,
) -> RolloutOutput:
"""Generate and, optionally, score a rollout."""

Expand All @@ -693,10 +694,13 @@ async def run_rollout(
sampling_args,
max_retries,
state_columns,
training_context=training_context,
)

resolved_client = resolve_client(client)

self.rubric.training_context = training_context

async def run_rollout_attempt() -> State:
state = await self.rollout(
input,
Expand Down Expand Up @@ -730,6 +734,7 @@ async def run_group(
max_retries: int = 0,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
training_context: dict | None = None,
**kwargs,
) -> list[RolloutOutput]:
"""Generate and, optionally, score one group."""
Expand All @@ -751,10 +756,13 @@ async def run_group(
sampling_args,
max_retries,
state_columns,
training_context=training_context,
)

resolved_client = resolve_client(client)

self.rubric.training_context = training_context

async def run_group_attempt() -> list[State]:
rollout_tasks = [
self.rollout(
Expand Down
4 changes: 4 additions & 0 deletions verifiers/rubrics/rubric.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def __init__(

self.parser = parser or vf.Parser()

# Training context set by the orchestrator before scoring.
# Contains metadata like {"step": int, "ckpt_step": int}.
self.training_context: dict | None = None
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing documentation for new training context feature

Low Severity

This PR adds the user-facing training_context attribute to Rubric and new training_context parameters to Environment.run_rollout and Environment.run_group, but no corresponding updates were made to docs/reference.md or docs/environments.md, both of which document these classes and methods. Per project rules, PRs modifying core user-facing functionality described in docs/ must update the relevant documentation.

Additional Locations (1)
Fix in Cursor Fix in Web

Triggered by project rule: BugBot Instructions

Reviewed by Cursor Bugbot for commit 475a603. Configure here.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Acknowledged — will add documentation in a follow-up once the API stabilizes through review. The feature is intentionally minimal right now (optional dict, defaults to None) so existing code is unaffected.


# class objects for reward functions
self.class_objects = {}
if self.parser:
Expand Down
10 changes: 10 additions & 0 deletions verifiers/rubrics/rubric_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ def __init__(self, rubrics: list[Rubric], **kwargs):
self.rubrics = rubrics
self.logger.debug(f"Initialized RubricGroup with {len(rubrics)} rubrics")

@property # type: ignore[override]
def training_context(self) -> dict | None:
return self._training_context

@training_context.setter
def training_context(self, value: dict | None) -> None:
self._training_context = value
for rubric in getattr(self, "rubrics", []):
rubric.training_context = value

def _get_reward_func_names(self) -> list[str]:
names = []
for rubric in self.rubrics:
Expand Down
4 changes: 4 additions & 0 deletions verifiers/serve/client/env_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ async def run_rollout(
sampling_args: SamplingArgs,
max_retries: int = 0,
state_columns: list[str] | None = None,
training_context: dict | None = None,
) -> RolloutOutput:
resolved_client_config = resolve_client_config(client_config)
request = RunRolloutRequest(
Expand All @@ -59,6 +60,7 @@ async def run_rollout(
sampling_args=sampling_args,
max_retries=max_retries,
state_columns=state_columns,
training_context=training_context,
)
response = await self.handle_run_rollout_request(request, timeout=None)
assert response.output is not None
Expand All @@ -72,6 +74,7 @@ async def run_group(
sampling_args: SamplingArgs,
max_retries: int = 0,
state_columns: list[str] | None = None,
training_context: dict | None = None,
) -> list[RolloutOutput]:
resolved_client_config = resolve_client_config(client_config)
request = RunGroupRequest(
Expand All @@ -81,6 +84,7 @@ async def run_group(
sampling_args=sampling_args,
max_retries=max_retries,
state_columns=state_columns,
training_context=training_context,
)
response = await self.handle_run_group_request(request, timeout=None)
assert response.outputs is not None
Expand Down
2 changes: 2 additions & 0 deletions verifiers/serve/server/env_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ async def resolve_client(self, client_config: ClientConfig) -> Client:
async def handle_run_rollout(
self, request: RunRolloutRequest
) -> RunRolloutResponse:
self.env.rubric.training_context = request.training_context
client = await self.resolve_client(request.client_config)
output = await self.env.run_rollout(
input=request.input,
Expand All @@ -150,6 +151,7 @@ async def handle_run_rollout(
return RunRolloutResponse(output=output)

async def handle_run_group(self, request: RunGroupRequest) -> RunGroupResponse:
self.env.rubric.training_context = request.training_context
client = await self.resolve_client(request.client_config)
outputs = await self.env.run_group(
group_inputs=request.group_inputs,
Expand Down
2 changes: 2 additions & 0 deletions verifiers/serve/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class RunRolloutRequest(BaseRequest):
sampling_args: SamplingArgs
max_retries: int
state_columns: list[str] | None
training_context: dict | None = None


class RunRolloutResponse(BaseResponse):
Expand All @@ -68,6 +69,7 @@ class RunGroupRequest(BaseRequest):
sampling_args: SamplingArgs
max_retries: int
state_columns: list[str] | None
training_context: dict | None = None


class RunGroupResponse(BaseResponse):
Expand Down