diff --git a/assets/lab/environments/AGENTS.md b/assets/lab/environments/AGENTS.md index 131eb5205..136e7f340 100644 --- a/assets/lab/environments/AGENTS.md +++ b/assets/lab/environments/AGENTS.md @@ -786,7 +786,25 @@ combined = vf.EnvGroup( ) ``` -The group concatenates all sub-environment datasets, tagging each row with a `task` column that routes rollouts to the appropriate environment for generation and scoring. Metrics from all environments are tracked together. +The group concatenates all sub-environment datasets, tagging each row with a `task` column that routes rollouts to the appropriate environment for generation and scoring. Metrics from all environments are tracked together. + +Passing an unknown task name to `get_env_for_task` raises a `ValueError` listing the available task names, making misconfiguration immediately visible rather than silently misrouting to the first sub-environment. + +An `EnvGroup` can itself be used as a sub-environment inside another `EnvGroup`. The outer group automatically inherits the inner group's task names and routes through both levels, so each inner task name maps correctly to its environment: + +```python +inner = vf.EnvGroup( + envs=[math_env, code_env], + env_names=["math", "code"], +) + +# prime-rl wraps user envs this way; task names are preserved through both levels +outer = vf.EnvGroup(envs=[inner], env_names=["my_env"]) +# outer.get_env_for_task("math") returns inner; inner.get_env_for_task("math") returns math_env +# outer.get_env_for_task("code") returns inner; inner.get_env_for_task("code") returns code_env +``` + +Task names must be unique across all levels of nesting; a collision raises `ValueError` at construction time. 
## Performance diff --git a/docs/environments.md b/docs/environments.md index 70e709b0e..1765d2df2 100644 --- a/docs/environments.md +++ b/docs/environments.md @@ -780,7 +780,25 @@ combined = vf.EnvGroup( ) ``` -The group concatenates all sub-environment datasets, tagging each row with a `task` column that routes rollouts to the appropriate environment for generation and scoring. Metrics from all environments are tracked together. +The group concatenates all sub-environment datasets, tagging each row with a `task` column that routes rollouts to the appropriate environment for generation and scoring. Metrics from all environments are tracked together. + +Passing an unknown task name to `get_env_for_task` raises a `ValueError` listing the available task names, making misconfiguration immediately visible rather than silently misrouting to the first sub-environment. + +An `EnvGroup` can itself be used as a sub-environment inside another `EnvGroup`. The outer group automatically inherits the inner group's task names and routes through both levels, so each inner task name maps correctly to its environment: + +```python +inner = vf.EnvGroup( + envs=[math_env, code_env], + env_names=["math", "code"], +) + +# prime-rl wraps user envs this way; task names are preserved through both levels +outer = vf.EnvGroup(envs=[inner], env_names=["my_env"]) +# outer.get_env_for_task("math") returns inner; inner.get_env_for_task("math") returns math_env +# outer.get_env_for_task("code") returns inner; inner.get_env_for_task("code") returns code_env +``` + +Task names must be unique across all levels of nesting; a collision raises `ValueError` at construction time. 
## Performance diff --git a/docs/reference.md b/docs/reference.md index d20012cac..b66321437 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -471,11 +471,15 @@ OpenEnv integration that runs OpenEnv projects in Prime Sandboxes using a prebui ```python env_group = vf.EnvGroup( envs=[env1, env2, env3], - names=["math", "code", "qa"] # optional + env_names=["math", "code", "qa"] # optional ) ``` -Combines multiple environments for mixed-task training. +Combines multiple environments for mixed-task training. Each row in the concatenated dataset is tagged with a `task` column that routes rollouts and scoring to the correct sub-environment. + +`get_env_for_task(task)` raises `ValueError` for unknown task names (listing available tasks) rather than silently falling back to `envs[0]`. + +An `EnvGroup` can be nested inside another `EnvGroup`. The outer group automatically inherits the inner group's task names so routing works through both levels. Task names must be unique across all nesting levels; a collision raises `ValueError` at construction time. --- diff --git a/environments/AGENTS.md b/environments/AGENTS.md index b8fd5bdde..633c508f5 100644 --- a/environments/AGENTS.md +++ b/environments/AGENTS.md @@ -786,7 +786,25 @@ combined = vf.EnvGroup( ) ``` -The group concatenates all sub-environment datasets, tagging each row with a `task` column that routes rollouts to the appropriate environment for generation and scoring. Metrics from all environments are tracked together. +The group concatenates all sub-environment datasets, tagging each row with a `task` column that routes rollouts to the appropriate environment for generation and scoring. Metrics from all environments are tracked together. + +Passing an unknown task name to `get_env_for_task` raises a `ValueError` listing the available task names, making misconfiguration immediately visible rather than silently misrouting to the first sub-environment. 
+ +An `EnvGroup` can itself be used as a sub-environment inside another `EnvGroup`. The outer group automatically inherits the inner group's task names and routes through both levels, so each inner task name maps correctly to its environment: + +```python +inner = vf.EnvGroup( + envs=[math_env, code_env], + env_names=["math", "code"], +) + +# prime-rl wraps user envs this way; task names are preserved through both levels +outer = vf.EnvGroup(envs=[inner], env_names=["my_env"]) +# outer.get_env_for_task("math") returns inner; inner.get_env_for_task("math") returns math_env +# outer.get_env_for_task("code") returns inner; inner.get_env_for_task("code") returns code_env +``` + +Task names must be unique across all levels of nesting; a collision raises `ValueError` at construction time. ## Performance diff --git a/tests/test_env_group.py b/tests/test_env_group.py index ff7363a16..78722a792 100644 --- a/tests/test_env_group.py +++ b/tests/test_env_group.py @@ -353,8 +353,62 @@ def test_get_env_for_task(self, mock_client): assert env_group.get_env_for_task("math") == env1 assert env_group.get_env_for_task("code") == env2 - # Unknown task returns first environment as fallback - assert env_group.get_env_for_task("unknown") == env1 + # Unknown task should raise rather than silently misroute + with pytest.raises(ValueError, match="No environment found for task"): + env_group.get_env_for_task("unknown") + + def test_nested_env_group_preserves_inner_tasks(self, mock_client): + """Wrapping an EnvGroup in another EnvGroup must preserve inner task names.""" + env1 = SingleTurnEnv( + client=mock_client, + model="test-model", + dataset=Dataset.from_dict({"question": ["q1"], "answer": ["a1"]}), + rubric=Rubric(), + ) + env2 = SingleTurnEnv( + client=mock_client, + model="test-model", + dataset=Dataset.from_dict({"question": ["q2"], "answer": ["a2"]}), + rubric=Rubric(), + ) + + inner_group = EnvGroup(envs=[env1, env2], env_names=["math", "code"]) + outer_group = EnvGroup(envs=[inner_group], env_names=["my_env"]) + + # Inner task names should be present in the outer 
env_map + assert outer_group.get_env_for_task("math") is inner_group + assert outer_group.get_env_for_task("code") is inner_group + + # Stale outer name should be removed so it does not cause misleading lookups + with pytest.raises(ValueError, match="No environment found for task"): + outer_group.get_env_for_task("my_env") + + # Dataset should retain the inner task labels + dataset = outer_group.get_dataset() + tasks = set(dataset["task"]) + assert "math" in tasks + assert "code" in tasks + assert "my_env" not in tasks + + def test_nested_env_group_name_collision_raises(self, mock_client): + """A nested EnvGroup whose inner task name collides with a sibling env raises.""" + env1 = SingleTurnEnv( + client=mock_client, + model="test-model", + dataset=Dataset.from_dict({"question": ["q1"], "answer": ["a1"]}), + rubric=Rubric(), + ) + env2 = SingleTurnEnv( + client=mock_client, + model="test-model", + dataset=Dataset.from_dict({"question": ["q2"], "answer": ["a2"]}), + rubric=Rubric(), + ) + + inner_group = EnvGroup(envs=[env1], env_names=["math"]) + # "math" is both an inner task name and the name of a sibling env + with pytest.raises(ValueError, match="conflicts with an existing task name"): + EnvGroup(envs=[inner_group, env2], env_names=["group", "math"]) @pytest.mark.asyncio async def test_env_group_generate(self, mock_client, make_input): diff --git a/verifiers/envs/env_group.py b/verifiers/envs/env_group.py index 76b6a3ce1..f8bd28e53 100644 --- a/verifiers/envs/env_group.py +++ b/verifiers/envs/env_group.py @@ -173,24 +173,54 @@ def add_task(example): return add_task + def _register_nested_env_map(inner_env: "EnvGroup", outer_name: str) -> None: + """Expand a nested EnvGroup's task names into the outer env_map. + + Registers each inner task name pointing to inner_env, then removes the + outer_name entry unless that name is also an inner task name (which would + mean the pop would undo the registration we just did). 
+ Raises ValueError if an inner task name conflicts with a pre-existing + sibling env entry. + """ + for inner_name in inner_env.env_map: + if ( + inner_name in self.env_map + and self.env_map[inner_name] is not inner_env + ): + raise ValueError( + f"Inner task name '{inner_name}' from nested EnvGroup " + f"'{outer_name}' conflicts with an existing task name in the " + f"outer EnvGroup. Use unique task names across all levels." + ) + self.env_map[inner_name] = inner_env + if outer_name not in inner_env.env_map: + self.env_map.pop(outer_name, None) + for env, name in zip(self.envs, self.env_names): add_task = make_add_task_fn(name) # Build dataset if using DatasetBuilder, returns None if not available env_dataset = env.build_dataset() if env_dataset is not None: - # override task column to use env_name for routing - if "task" in env_dataset.column_names: - env_dataset = env_dataset.remove_columns(["task"]) - env_dataset = env_dataset.map(add_task, **map_kwargs) + if isinstance(env, EnvGroup): + # Preserve inner task names so routing works through both levels. 
+ _register_nested_env_map(env, name) + else: + # override task column to use env_name for routing + if "task" in env_dataset.column_names: + env_dataset = env_dataset.remove_columns(["task"]) + env_dataset = env_dataset.map(add_task, **map_kwargs) datasets.append(env_dataset) # Build eval_dataset if using DatasetBuilder, returns None if not available env_eval_dataset = env.build_eval_dataset() if env_eval_dataset is not None: - # override task column to use env_name for routing - if "task" in env_eval_dataset.column_names: - env_eval_dataset = env_eval_dataset.remove_columns(["task"]) - env_eval_dataset = env_eval_dataset.map(add_task, **map_kwargs) + if isinstance(env, EnvGroup): + _register_nested_env_map(env, name) + else: + # override task column to use env_name for routing + if "task" in env_eval_dataset.column_names: + env_eval_dataset = env_eval_dataset.remove_columns(["task"]) + env_eval_dataset = env_eval_dataset.map(add_task, **map_kwargs) eval_datasets.append(env_eval_dataset) dataset = concatenate_datasets(datasets) if datasets else None eval_dataset = concatenate_datasets(eval_datasets) if eval_datasets else None @@ -320,7 +350,13 @@ async def rollout( return await env.rollout(input, client, model, sampling_args) def get_env_for_task(self, task: str) -> vf.Environment: - return self.env_map.get(task, self.envs[0]) + env = self.env_map.get(task) + if env is None: + available = list(self.env_map.keys()) + raise ValueError( + f"No environment found for task '{task}'. Available tasks: {available}" + ) + return env def set_max_seq_len(self, max_seq_len: int | None) -> None: """Set the max_seq_len value for this environment group and all sub-environments."""