Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/test_eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,27 @@ def test_print_results_includes_usage(capsys, make_metadata, make_output):
assert "Usage:" in captured.out
assert "input_tokens (avg): 8.000" in captured.out
assert "output_tokens (avg): 3.000" in captured.out


def test_print_results_handles_heterogeneous_metrics(
capsys, make_metadata, make_output
):
from verifiers.utils.eval_utils import print_results

outputs = [
make_output(example_id=0, reward=1.0, metrics={"rlm_turns": 3.0}),
make_output(
example_id=1,
reward=0.0,
metrics={"rlm_compactions_count": 1.0, "rlm_turns": 2.0},
),
]
metadata = make_metadata(num_examples=2, rollouts_per_example=1)

results = GenerateOutputs(outputs=outputs, metadata=metadata)
print_results(results)
captured = capsys.readouterr()

assert "rlm_compactions_count: avg - 1.000" in captured.out
assert "r1: [1.0]" in captured.out
assert "rlm_turns: avg - 2.500" in captured.out
23 changes: 16 additions & 7 deletions verifiers/utils/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,11 +524,12 @@ def filter_inputs(
return filtered_inputs


def to_col_order(list_of_dicts: list[Mapping[str, float]]) -> dict[str, list[float]]:
"""Convert a list of mappings to a dictionary of lists, ordered by the keys of the first mapping."""
if not list_of_dicts:
return {}
return {k: [m[k] for m in list_of_dicts] for k in list_of_dicts[0].keys()}
def to_col_order(
list_of_dicts: list[Mapping[str, float]],
) -> dict[str, list[float | None]]:
"""Convert a list of mappings to a dictionary of lists."""
keys = sorted({key for mapping in list_of_dicts for key in mapping})
return {key: [mapping.get(key) for mapping in list_of_dicts] for key in keys}


def output_env_id(output: Mapping[str, Any]) -> str:
Expand Down Expand Up @@ -585,9 +586,17 @@ def print_rewards(results: GenerateOutputs):
metrics_col = to_col_order(metrics)
for k in metrics_col.keys():
v = metrics_col[k]
print(f"{k}: avg - {sum(v) / len(v):.3f}, std - {np.std(v):.3f}")
present_values = [value for value in v if value is not None]
print(
f"{k}: avg - {sum(present_values) / len(present_values):.3f}, "
f"std - {np.std(present_values):.3f}"
)
for i in range(r):
trials = [round(v[i + (j * r)], 3) for j in range(n)]
trials = [
round(value, 3)
for j in range(n)
if (value := v[i + (j * r)]) is not None
]
out = f"r{i + 1}: {trials}"
print(out)

Expand Down
Loading