Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions cli/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,7 @@ def generate_project(
shutil.copytree(
root / ".github",
output_dir / ".github",
ignore=shutil.ignore_patterns(
"auto-tag-release.yaml",
"check-pr-title.yaml",
"cli-integration.yaml",
),
ignore=shutil.ignore_patterns("cli-integration.yaml"),
Comment thread
BradenBug marked this conversation as resolved.
)

# Create empty tests directory
Expand Down
3 changes: 2 additions & 1 deletion src/benchmark_service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,12 @@ def _current_service_version(self) -> str | None:
service_override = self.service.get_service_version()
return service_override or self._service_version

async def _version(self) -> VersionResponse:
async def _version(self, dataset: str | None = None) -> VersionResponse:
return VersionResponse(
framework_version=_framework_version,
service_name=self._service_name,
service_version=self._current_service_version(),
dataset_version=self.service.get_dataset_version(dataset),
)

async def _authorize_websocket(self, websocket: WebSocket) -> str | None:
Expand Down
17 changes: 15 additions & 2 deletions src/benchmark_service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any, Self
from pathlib import Path
from typing import Any, ClassVar, Self

from benchmark_service.auth import (
LEGACY_TENANT_SENTINEL,
check_benchmark_service_auth,
load_allowlist,
resolve_caller_tenant,
)
from benchmark_service.dataset_versioning import DatasetVersionEntry, load_verified_dataset_versions
from benchmark_service.sandbox import Sandbox
from benchmark_service.schemas import (
EvaluateResponseRequest,
Expand All @@ -36,10 +38,20 @@ class BenchmarkService(ABC):

datasets: dict[str, dict[str, Any]]

# When set, dataset versions are loaded and content-verified at startup and
# served by get_dataset_version(); a checksum mismatch aborts startup.
dataset_versions_file: ClassVar[Path | None] = None
dataset_versions: dict[str, DatasetVersionEntry]

@classmethod
async def create(cls) -> Self:
"""Factory method to create and initialize a benchmark service."""
instance = cls.__new__(cls)
instance.dataset_versions = (
load_verified_dataset_versions(cls.dataset_versions_file)
if cls.dataset_versions_file is not None
else {}
)
instance.datasets = await instance.load_datasets()
return instance

Expand Down Expand Up @@ -104,7 +116,8 @@ def get_service_version(self) -> str | None:

def get_dataset_version(self, dataset: str | None = None) -> str | None:
"""Return the version for `dataset`, if this benchmark tracks one."""
return None
entry = self.dataset_versions.get(dataset or "default")
return entry.version if entry is not None else None

def get_dataset(self, dataset: str | None = None) -> dict[str, Any]:
"""Get a specific dataset by name. Defaults to 'default'."""
Expand Down
78 changes: 78 additions & 0 deletions src/benchmark_service/dataset_versioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Hash-guarded dataset version tracking.

A dataset_versions.yaml next to the dataset files maps each dataset name to a
human-assigned semver and a sha256 of its content file. Version semantics:
major = scores not comparable, minor = additive, patch = non-scoring fixes.
"""

import hashlib
import sys
from pathlib import Path

import yaml
from pydantic import BaseModel


class DatasetVersionError(Exception):
"""Dataset content does not match its declared version entry."""


class DatasetVersionEntry(BaseModel):
file: str
version: str
sha256: str


def compute_checksum(path: Path) -> str:
return hashlib.sha256(path.read_bytes()).hexdigest()


def load_dataset_versions(versions_file: Path) -> dict[str, DatasetVersionEntry]:
data = yaml.safe_load(versions_file.read_text())
return {name: DatasetVersionEntry.model_validate(entry) for name, entry in data.items()}
Comment thread
BradenBug marked this conversation as resolved.
Outdated


def load_verified_dataset_versions(versions_file: Path) -> dict[str, DatasetVersionEntry]:
"""Load entries and verify every dataset file matches its declared checksum.

Raises DatasetVersionError on any mismatch: content that does not match its
declared version must never be served.
"""
entries = load_dataset_versions(versions_file)
data_dir = versions_file.parent
mismatches: list[str] = []
for name, entry in entries.items():
actual = compute_checksum(data_dir / entry.file)
if actual != entry.sha256:
mismatches.append(f"{name} ({entry.file}): declared {entry.sha256}, actual {actual}")
if mismatches:
raise DatasetVersionError(
"dataset content does not match dataset_versions.yaml — bump the version, then run "
"`python -m benchmark_service.dataset_versioning update <file>`:\n " + "\n ".join(mismatches)
)
return entries


def main(argv: list[str]) -> int:
if len(argv) != 2 or argv[0] not in ("check", "update"):
print("usage: python -m benchmark_service.dataset_versioning {check|update} <dataset_versions.yaml>")
return 2
command, versions_file = argv[0], Path(argv[1])
if command == "check":
try:
load_verified_dataset_versions(versions_file)
except DatasetVersionError as exc:
print(exc)
return 1
print("dataset checksums OK")
return 0
raw = yaml.safe_load(versions_file.read_text())
for entry in raw.values():
entry["sha256"] = compute_checksum(versions_file.parent / entry["file"])
versions_file.write_text(yaml.safe_dump(raw, sort_keys=False))
print(f"updated {versions_file}")
return 0


if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
1 change: 1 addition & 0 deletions src/benchmark_service/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class VersionResponse(BaseModel):
framework_version: str
service_name: str | None = None
service_version: str | None = None
dataset_version: str | None = None


class StreamMessageChunk(BaseModel):
Expand Down
10 changes: 8 additions & 2 deletions templates/pyproject.toml.jinja
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "{{ benchmark_name }}-benchmark-service"
version = "0.1.0"
dynamic = ["version"]
readme = "README.md"
requires-python = ">=3.12"

Expand All @@ -17,9 +17,15 @@ dev = [
]

[build-system]
requires = ["hatchling"]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[tool.hatch.version]
source = "vcs"
# A freshly scaffolded project has no git tags (or no repo yet); fall back so it
# still builds. Real versions come from tags once the repo is set up.
fallback-version = "0.0.0"

[tool.hatch.build.targets.wheel]
packages = ["src/{{ benchmark_package }}"]

Expand Down
1 change: 1 addition & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def test_sandbox_config_rejects_unknown_provider(monkeypatch: pytest.MonkeyPatch
"framework_version": "0.7.4",
"service_name": "legal-research-benchmark-service",
"service_version": "1.2.3",
"dataset_version": "3.0.0",
},
),
(
Expand Down
76 changes: 76 additions & 0 deletions tests/test_dataset_versioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Tests for hash-guarded dataset version tracking."""

from pathlib import Path

import pytest
import yaml

from benchmark_service.dataset_versioning import (
DatasetVersionError,
compute_checksum,
load_verified_dataset_versions,
main,
)

from tests.conftest import StubBenchmark


def _write_fixture(tmp_path: Path, content: bytes = b'{"tests": []}') -> Path:
data_file = tmp_path / "validation.json"
data_file.write_bytes(content)
versions_file = tmp_path / "dataset_versions.yaml"
versions_file.write_text(
yaml.safe_dump(
{
"validation": {
"file": "validation.json",
"version": "1.0.0",
"sha256": compute_checksum(data_file),
}
}
)
)
return versions_file


def test_load_verified_returns_entries_when_content_matches(tmp_path: Path) -> None:
versions_file = _write_fixture(tmp_path)
entries = load_verified_dataset_versions(versions_file)
assert entries["validation"].version == "1.0.0"


def test_load_verified_raises_on_content_mismatch(tmp_path: Path) -> None:
versions_file = _write_fixture(tmp_path)
(tmp_path / "validation.json").write_bytes(b'{"tests": [1]}')
with pytest.raises(DatasetVersionError, match="validation"):
load_verified_dataset_versions(versions_file)


def test_check_command_fails_on_mismatch_and_update_repairs(tmp_path: Path) -> None:
versions_file = _write_fixture(tmp_path)
(tmp_path / "validation.json").write_bytes(b'{"tests": [1]}')
assert main(["check", str(versions_file)]) == 1
assert main(["update", str(versions_file)]) == 0
assert main(["check", str(versions_file)]) == 0


async def test_service_startup_verifies_and_serves_dataset_versions(tmp_path: Path) -> None:
versions_file = _write_fixture(tmp_path)

class VersionedBenchmark(StubBenchmark):
dataset_versions_file = versions_file

service = await VersionedBenchmark.create()
assert service.get_dataset_version("validation") == "1.0.0"
assert service.get_dataset_version("unknown") is None


async def test_service_startup_fails_on_checksum_mismatch(tmp_path: Path) -> None:
versions_file = _write_fixture(tmp_path)
(tmp_path / "validation.json").write_bytes(b"tampered")

class VersionedBenchmark(StubBenchmark):
dataset_versions_file = versions_file

with pytest.raises(DatasetVersionError):
await VersionedBenchmark.create()
6 changes: 3 additions & 3 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,17 @@ def test_generates_project_structure() -> None:
assert (output_dir / ".github" / "workflows").is_dir()


def test_generated_project_excludes_framework_versioning_workflows(tmp_path: Path) -> None:
def test_generated_project_includes_release_workflows_but_not_cli_integration(tmp_path: Path) -> None:
output_dir = tmp_path / "swebench-benchmark-service"
generate_project("swebench", output_dir)

workflows_dir = output_dir / ".github" / "workflows"
assert (workflows_dir / "test.yaml").exists()
assert (workflows_dir / "style.yaml").exists()
assert (workflows_dir / "typecheck.yaml").exists()
assert (workflows_dir / "auto-tag-release.yaml").exists()
assert (workflows_dir / "check-pr-title.yaml").exists()
assert not (workflows_dir / "cli-integration.yaml").exists()
assert not (workflows_dir / "auto-tag-release.yaml").exists()
assert not (workflows_dir / "check-pr-title.yaml").exists()


def test_generated_benchmark_service_implements_task_listing(tmp_path: Path) -> None:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any
from unittest.mock import patch

Expand All @@ -18,6 +19,8 @@
StreamChunk,
)

from tests.test_dataset_versioning import _write_fixture # pyright: ignore[reportPrivateUsage]


def test_version_is_importable_and_well_formed() -> None:
assert isinstance(benchmark_service.__version__, str)
Expand Down Expand Up @@ -100,3 +103,26 @@ def test_version_endpoint_prefers_service_version_hook() -> None:

assert response.status_code == 200
assert response.json()["service_version"] == "service-hook-1.2.3"


def test_version_endpoint_reports_dataset_version_key() -> None:
app = BenchmarkServiceApp(_FakeService)
with TestClient(app) as client:
response = client.get("/version", params={"dataset": "default"})

assert response.status_code == 200
assert "dataset_version" in response.json()


def test_version_endpoint_reports_tracked_dataset_version(tmp_path: Path) -> None:
versions_file = _write_fixture(tmp_path)

class _DatasetVersionedService(_FakeService):
dataset_versions_file = versions_file

app = BenchmarkServiceApp(_DatasetVersionedService)
with TestClient(app) as client:
response = client.get("/version", params={"dataset": "validation"})

assert response.status_code == 200
assert response.json()["dataset_version"] == "1.0.0"
Loading