Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions cli/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,7 @@ def generate_project(
shutil.copytree(
root / ".github",
output_dir / ".github",
ignore=shutil.ignore_patterns(
"auto-tag-release.yaml",
"check-pr-title.yaml",
"cli-integration.yaml",
),
ignore=shutil.ignore_patterns("cli-integration.yaml"),
Comment thread
BradenBug marked this conversation as resolved.
)

# Create empty tests directory
Expand Down
3 changes: 2 additions & 1 deletion src/benchmark_service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,12 @@ def _current_service_version(self) -> str | None:
service_override = self.service.get_service_version()
return service_override or self._service_version

async def _version(self) -> VersionResponse:
async def _version(self, dataset: str | None = None) -> VersionResponse:
return VersionResponse(
framework_version=_framework_version,
service_name=self._service_name,
service_version=self._current_service_version(),
dataset_version=self.service.get_dataset_version(dataset),
)

async def _authorize_websocket(self, websocket: WebSocket) -> str | None:
Expand Down
17 changes: 15 additions & 2 deletions src/benchmark_service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any, Self
from pathlib import Path
from typing import Any, ClassVar, Self

from benchmark_service.auth import (
LEGACY_TENANT_SENTINEL,
check_benchmark_service_auth,
load_allowlist,
resolve_caller_tenant,
)
from benchmark_service.dataset_versioning import DatasetVersionEntry, load_verified_dataset_versions
from benchmark_service.sandbox import Sandbox
from benchmark_service.schemas import (
EvaluateResponseRequest,
Expand All @@ -36,10 +38,20 @@ class BenchmarkService(ABC):

datasets: dict[str, dict[str, Any]]

# When set, dataset versions are loaded and content-verified at startup and
# served by get_dataset_version(); a checksum mismatch aborts startup.
dataset_versions_file: ClassVar[Path | None] = None
dataset_versions: dict[str, DatasetVersionEntry]

@classmethod
async def create(cls) -> Self:
    """Build a service instance without running ``__init__`` and initialize it.

    Dataset versions are loaded (and checksum-verified) first when the class
    declares ``dataset_versions_file``; otherwise the mapping is left empty.
    The datasets themselves are then loaded through ``load_datasets()``.
    """
    service = cls.__new__(cls)
    if cls.dataset_versions_file is None:
        service.dataset_versions = {}
    else:
        service.dataset_versions = load_verified_dataset_versions(cls.dataset_versions_file)
    service.datasets = await service.load_datasets()
    return service

Expand Down Expand Up @@ -104,7 +116,8 @@ def get_service_version(self) -> str | None:

def get_dataset_version(self, dataset: str | None = None) -> str | None:
"""Return the version for `dataset`, if this benchmark tracks one."""
return None
entry = self.dataset_versions.get(dataset or "default")
return entry.version if entry is not None else None

def get_dataset(self, dataset: str | None = None) -> dict[str, Any]:
"""Get a specific dataset by name. Defaults to 'default'."""
Expand Down
88 changes: 88 additions & 0 deletions src/benchmark_service/dataset_versioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Hash-guarded dataset version tracking.

A dataset_versions.yaml next to the dataset files maps each dataset name to a
human-assigned semver and a sha256 of its content file. Version semantics:
major = scores not comparable, minor = additive, patch = non-scoring fixes.
"""

import hashlib
import sys
from pathlib import Path
from typing import Any, cast

import yaml
from pydantic import BaseModel


class DatasetVersionError(Exception):
    """Dataset content does not match its declared version entry.

    Also raised when the versions file itself is not a YAML mapping.
    """


class DatasetVersionEntry(BaseModel):
    """One dataset's record in dataset_versions.yaml."""

    # Dataset content file, resolved relative to the versions file's directory.
    file: str
    # Human-assigned semver for the dataset content.
    version: str
    # Hex SHA-256 digest of the content file's bytes.
    sha256: str


def compute_checksum(path: Path) -> str:
    """Return the hex-encoded SHA-256 digest of the file at `path`.

    The file is hashed in fixed-size chunks so that large dataset files are
    never held in memory in full.

    Raises:
        OSError: If `path` cannot be opened or read.
    """
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        # 1 MiB reads keep memory usage flat regardless of dataset size.
        while chunk := stream.read(1 << 20):
            digest.update(chunk)
    return digest.hexdigest()


def _load_versions_mapping(versions_file: Path) -> dict[str, Any]:
    """Load the versions file as a mapping of dataset name to raw entry.

    Raises:
        DatasetVersionError: If the file is not valid YAML, is empty, or its
            top-level value is not a mapping.
        OSError: If the file cannot be read.
    """
    # Decode explicitly as UTF-8 instead of relying on the locale default, so
    # the same file parses identically on every platform.
    text = versions_file.read_text(encoding="utf-8")
    try:
        data = yaml.safe_load(text)
    except yaml.YAMLError as exc:
        # Surface syntax errors as the module's own error type so callers that
        # already handle DatasetVersionError report them cleanly.
        raise DatasetVersionError(f"{versions_file} is not valid YAML: {exc}") from exc
    if not isinstance(data, dict):
        raise DatasetVersionError(
            f"{versions_file} must be a YAML mapping of dataset name to entry, got {type(data).__name__}"
        )
    return cast("dict[str, Any]", data)


def load_dataset_versions(versions_file: Path) -> dict[str, DatasetVersionEntry]:
    """Parse `versions_file` into validated entries keyed by dataset name.

    No checksum verification happens here; each raw entry is only validated
    against the DatasetVersionEntry model.
    """
    raw_entries = _load_versions_mapping(versions_file)
    parsed: dict[str, DatasetVersionEntry] = {}
    for dataset_name, raw_entry in raw_entries.items():
        parsed[dataset_name] = DatasetVersionEntry.model_validate(raw_entry)
    return parsed


def load_verified_dataset_versions(versions_file: Path) -> dict[str, DatasetVersionEntry]:
    """Load entries and verify every dataset file matches its declared checksum.

    Dataset files are resolved relative to the directory containing
    `versions_file`. Every entry is checked before failing, so a single error
    reports all problems at once.

    Raises:
        DatasetVersionError: If any dataset file is missing or unreadable, or
            its content does not match the declared sha256. Content that does
            not match its declared version must never be served.
    """
    entries = load_dataset_versions(versions_file)
    data_dir = versions_file.parent
    mismatches: list[str] = []
    for name, entry in entries.items():
        try:
            actual = compute_checksum(data_dir / entry.file)
        except OSError as exc:
            # A missing or unreadable file cannot match its declared version;
            # report it with the other problems instead of leaking a raw OSError.
            mismatches.append(f"{name} ({entry.file}): declared {entry.sha256}, file unreadable ({exc})")
            continue
        if actual != entry.sha256:
            mismatches.append(f"{name} ({entry.file}): declared {entry.sha256}, actual {actual}")
    if mismatches:
        raise DatasetVersionError(
            "dataset content does not match dataset_versions.yaml — bump the version, then run "
            "`python -m benchmark_service.dataset_versioning update <file>`:\n " + "\n ".join(mismatches)
        )
    return entries


def main(argv: list[str]) -> int:
    """Run the `check` or `update` command against a dataset_versions.yaml.

    Args:
        argv: Command-line arguments without the program name, i.e.
            ``[command, versions_file]``.

    Returns:
        0 on success, 1 when verification or the update fails, 2 on a usage
        error.
    """
    if len(argv) != 2 or argv[0] not in ("check", "update"):
        print("usage: python -m benchmark_service.dataset_versioning {check|update} <dataset_versions.yaml>")
        return 2
    command, versions_file = argv[0], Path(argv[1])
    if command == "check":
        try:
            load_verified_dataset_versions(versions_file)
        except (DatasetVersionError, OSError) as exc:
            # OSError covers a missing/unreadable versions or dataset file, so
            # the command exits non-zero with a message instead of a traceback.
            print(exc)
            return 1
        print("dataset checksums OK")
        return 0
    try:
        raw = _load_versions_mapping(versions_file)
        # Recompute every checksum before writing anything, so a failure
        # part-way through never leaves a half-updated file behind.
        for name, entry in raw.items():
            if not isinstance(entry, dict) or "file" not in entry:
                raise DatasetVersionError(f"{versions_file}: entry {name!r} must be a mapping with a 'file' key")
            entry["sha256"] = compute_checksum(versions_file.parent / entry["file"])
    except (DatasetVersionError, OSError) as exc:
        print(exc)
        return 1
    # NOTE: re-serializing the loaded mapping does not preserve YAML comments.
    versions_file.write_text(yaml.safe_dump(raw, sort_keys=False), encoding="utf-8")
    print(f"updated {versions_file}")
    return 0


# Script entry point: pass the CLI arguments (minus the program name) to main()
# and use its return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
1 change: 1 addition & 0 deletions src/benchmark_service/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class VersionResponse(BaseModel):
framework_version: str
service_name: str | None = None
service_version: str | None = None
dataset_version: str | None = None


class StreamMessageChunk(BaseModel):
Expand Down
9 changes: 9 additions & 0 deletions templates/README.md.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,12 @@ make dev # run local server
make test # run tests
make help # list all commands
```

## Releasing

Versions come from git tags (hatch-vcs). To enable the release flow in this repo:

- Add a **`GH_PAT`** repository secret — a personal access token with permission to push tags (`contents: write`). `auto-tag-release` uses it to push the version tag on merge to `main`; the workflow fails without it.
- Every PR title must include **`#patch`**, **`#minor`**, or **`#major`** — `check-pr-title` enforces this and the tag is bumped accordingly on merge.

Until the first tag exists, the package builds at version `0.0.0`.
10 changes: 8 additions & 2 deletions templates/pyproject.toml.jinja
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "{{ benchmark_name }}-benchmark-service"
version = "0.1.0"
dynamic = ["version"]
readme = "README.md"
requires-python = ">=3.12"

Expand All @@ -17,9 +17,15 @@ dev = [
]

[build-system]
requires = ["hatchling"]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[tool.hatch.version]
source = "vcs"
# A freshly scaffolded project has no git tags (or no repo yet); fall back so it
# still builds. Real versions come from tags once the repo is set up.
fallback-version = "0.0.0"

[tool.hatch.build.targets.wheel]
packages = ["src/{{ benchmark_package }}"]

Expand Down
1 change: 1 addition & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def test_sandbox_config_rejects_unknown_provider(monkeypatch: pytest.MonkeyPatch
"framework_version": "0.7.4",
"service_name": "legal-research-benchmark-service",
"service_version": "1.2.3",
"dataset_version": "3.0.0",
},
),
(
Expand Down
84 changes: 84 additions & 0 deletions tests/test_dataset_versioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Tests for hash-guarded dataset version tracking."""

from pathlib import Path

import pytest
import yaml

from benchmark_service.dataset_versioning import (
DatasetVersionError,
compute_checksum,
load_dataset_versions,
load_verified_dataset_versions,
main,
)

from tests.conftest import StubBenchmark


def _write_fixture(tmp_path: Path, content: bytes = b'{"tests": []}') -> Path:
    """Create a dataset file plus a matching dataset_versions.yaml in `tmp_path`.

    Returns the path of the versions file, whose single "validation" entry
    carries the real checksum of the dataset file just written.
    """
    dataset_path = tmp_path / "validation.json"
    dataset_path.write_bytes(content)

    entry = {
        "file": "validation.json",
        "version": "1.0.0",
        "sha256": compute_checksum(dataset_path),
    }
    versions_path = tmp_path / "dataset_versions.yaml"
    versions_path.write_text(yaml.safe_dump({"validation": entry}))
    return versions_path


def test_load_verified_returns_entries_when_content_matches(tmp_path: Path) -> None:
    """Verification succeeds and yields the declared version for untouched content."""
    verified = load_verified_dataset_versions(_write_fixture(tmp_path))
    assert verified["validation"].version == "1.0.0"


def test_load_verified_raises_on_content_mismatch(tmp_path: Path) -> None:
    """Changing the dataset after its checksum was recorded must fail verification."""
    versions_path = _write_fixture(tmp_path)
    dataset_path = tmp_path / "validation.json"
    dataset_path.write_bytes(b'{"tests": [1]}')
    with pytest.raises(DatasetVersionError, match="validation"):
        load_verified_dataset_versions(versions_path)


def test_empty_versions_file_raises_clear_error(tmp_path: Path) -> None:
    """A versions file holding only a comment is rejected with the mapping error."""
    empty_versions = tmp_path / "dataset_versions.yaml"
    empty_versions.write_text("# no entries yet\n")
    with pytest.raises(DatasetVersionError, match="must be a YAML mapping"):
        load_dataset_versions(empty_versions)


def test_check_command_fails_on_mismatch_and_update_repairs(tmp_path: Path) -> None:
    """`check` exits 1 on drifted content; `update` rewrites checksums so `check` passes."""
    target = str(_write_fixture(tmp_path))
    dataset_path = tmp_path / "validation.json"
    dataset_path.write_bytes(b'{"tests": [1]}')

    # Order matters: fail first, repair, then confirm the repair.
    assert main(["check", target]) == 1
    assert main(["update", target]) == 0
    assert main(["check", target]) == 0


async def test_service_startup_verifies_and_serves_dataset_versions(tmp_path: Path) -> None:
    """create() loads the verified versions and serves them via get_dataset_version()."""
    fixture_path = _write_fixture(tmp_path)

    class VersionedBenchmark(StubBenchmark):
        dataset_versions_file = fixture_path

    benchmark = await VersionedBenchmark.create()
    assert benchmark.get_dataset_version("validation") == "1.0.0"
    assert benchmark.get_dataset_version("unknown") is None


async def test_service_startup_fails_on_checksum_mismatch(tmp_path: Path) -> None:
    """create() must raise when a dataset file no longer matches its checksum."""
    fixture_path = _write_fixture(tmp_path)
    dataset_path = tmp_path / "validation.json"
    dataset_path.write_bytes(b"tampered")

    class VersionedBenchmark(StubBenchmark):
        dataset_versions_file = fixture_path

    with pytest.raises(DatasetVersionError):
        await VersionedBenchmark.create()
6 changes: 3 additions & 3 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,17 @@ def test_generates_project_structure() -> None:
assert (output_dir / ".github" / "workflows").is_dir()


def test_generated_project_excludes_framework_versioning_workflows(tmp_path: Path) -> None:
def test_generated_project_includes_release_workflows_but_not_cli_integration(tmp_path: Path) -> None:
output_dir = tmp_path / "swebench-benchmark-service"
generate_project("swebench", output_dir)

workflows_dir = output_dir / ".github" / "workflows"
assert (workflows_dir / "test.yaml").exists()
assert (workflows_dir / "style.yaml").exists()
assert (workflows_dir / "typecheck.yaml").exists()
assert (workflows_dir / "auto-tag-release.yaml").exists()
assert (workflows_dir / "check-pr-title.yaml").exists()
assert not (workflows_dir / "cli-integration.yaml").exists()
assert not (workflows_dir / "auto-tag-release.yaml").exists()
assert not (workflows_dir / "check-pr-title.yaml").exists()


def test_generated_benchmark_service_implements_task_listing(tmp_path: Path) -> None:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any
from unittest.mock import patch

Expand All @@ -18,6 +19,8 @@
StreamChunk,
)

from tests.test_dataset_versioning import _write_fixture # pyright: ignore[reportPrivateUsage]


def test_version_is_importable_and_well_formed() -> None:
assert isinstance(benchmark_service.__version__, str)
Expand Down Expand Up @@ -100,3 +103,26 @@ def test_version_endpoint_prefers_service_version_hook() -> None:

assert response.status_code == 200
assert response.json()["service_version"] == "service-hook-1.2.3"


def test_version_endpoint_reports_dataset_version_key() -> None:
    """The /version response body includes a `dataset_version` key."""
    with TestClient(BenchmarkServiceApp(_FakeService)) as client:
        response = client.get("/version", params={"dataset": "default"})

    payload = response.json()
    assert response.status_code == 200
    assert "dataset_version" in payload


def test_version_endpoint_reports_tracked_dataset_version(tmp_path: Path) -> None:
    """/version returns the declared version for a dataset listed in the versions file."""
    fixture_path = _write_fixture(tmp_path)

    class _DatasetVersionedService(_FakeService):
        dataset_versions_file = fixture_path

    with TestClient(BenchmarkServiceApp(_DatasetVersionedService)) as client:
        response = client.get("/version", params={"dataset": "validation"})

    assert response.status_code == 200
    assert response.json()["dataset_version"] == "1.0.0"
Loading