diff --git a/encord/orm/project.py b/encord/orm/project.py index 3707b325a..2f66a1bf9 100644 --- a/encord/orm/project.py +++ b/encord/orm/project.py @@ -600,3 +600,59 @@ class ProjectUserResponse(BaseDTO): user_email: str user_role: ProjectUserRole + + +# --- Bulk classifications endpoint DTOs --- + + +class BulkClassificationsPayload(BaseDTO): + """Request payload for ``POST /projects/{uuid}/label-rows/classifications``.""" + + label_uuids: List[str] + branch_name: str = "main" + + +class ClassificationAnswerObject(BaseDTO): + """Nested answer within a :class:`ClassificationAnswerEntry`.""" + + name: str + value: str + answers: Any # list[AnswersAnswer] | str | float + feature_hash: str + manual_annotation: bool + + +class ClassificationAnswerEntry(BaseDTO): + """Single classification answer returned by the bulk classifications endpoint. + + Matches the ``ClassificationAnswer`` shape produced by the legacy + ``initialise_labels`` SDK path. + """ + + classification_hash: str + feature_hash: Optional[str] = None + classifications: List[ClassificationAnswerObject] + range: Optional[List[Tuple[int, int]]] = None + created_at: Optional[str] = None + created_by: Optional[str] = None + last_edited_at: Optional[str] = None + last_edited_by: Optional[str] = None + manual_annotation: Optional[bool] = None + + +class LabelClassificationsEntry(BaseDTO): + """One label row's classification data in the bulk response.""" + + label_uuid: UUID + data_uuid: UUID + data_title: Optional[str] = None + branch_name: str = "main" + classification_answers: Dict[str, ClassificationAnswerEntry] = {} + export_uuid: Optional[UUID] = None + exported_at: Optional[datetime.datetime] = None + + +class BulkClassificationsResponse(BaseDTO): + """Response from ``POST /projects/{uuid}/label-rows/classifications``.""" + + labels: List[LabelClassificationsEntry] diff --git a/encord/project.py b/encord/project.py index 9c1e74e8a..2bf6e79a1 100644 --- a/encord/project.py +++ b/encord/project.py @@ -50,8 +50,11 @@ ) from encord.orm.project import ( AddProjectIssueTagsPayload, + BulkClassificationsPayload, + BulkClassificationsResponse, CopyDatasetOptions, CopyLabelsOptions, + LabelClassificationsEntry, ProjectDataset, ProjectDTO, ProjectStatus, @@ -287,6 +290,36 @@ def list_label_rows_v2( ] return label_rows + def get_label_classifications( + self, + label_uuids: List[Union[str, UUID]], + branch_name: str = "main", + batch_size: int = 500, + ) -> Iterable[LabelClassificationsEntry]: + """Fast bulk fetch of classification answers only. + + Args: + label_uuids: Label row UUIDs to fetch classifications for. + Obtain these from :meth:`list_label_rows_v2` (``row.label_hash``). + branch_name: Label branch to read from (default ``"main"``). + batch_size: Number of labels per server request (max 1000). + + Returns: + An iterable of :class:`~encord.orm.project.LabelClassificationsEntry`, + one per label row. Each entry contains ``classification_answers`` + keyed by classification hash. + """ + uuid_strs = [str(u) for u in label_uuids] + for i in range(0, len(uuid_strs), batch_size): + batch = uuid_strs[i : i + batch_size] + response = self._api_client.post( + f"projects/{self.project_hash}/label-rows/classifications", + params=None, + payload=BulkClassificationsPayload(label_uuids=batch, branch_name=branch_name), + result_type=BulkClassificationsResponse, + ) + yield from response.labels + def add_users(self, user_emails: List[str], user_role: ProjectUserRole) -> List[ProjectUser]: """Add users to the project. diff --git a/tests/test_label_classifications.py b/tests/test_label_classifications.py new file mode 100644 index 000000000..f0e006f8d --- /dev/null +++ b/tests/test_label_classifications.py @@ -0,0 +1,76 @@ +import uuid +from unittest.mock import MagicMock, patch + +from encord.http.v2.api_client import ApiClient +from encord.orm.project import ( + BulkClassificationsPayload, + BulkClassificationsResponse, + LabelClassificationsEntry, +) +from encord.project import Project + + +def _make_entry(**overrides) -> LabelClassificationsEntry: + defaults = dict(label_uuid=uuid.uuid4(), data_uuid=uuid.uuid4()) + defaults.update(overrides) + return LabelClassificationsEntry(**defaults) + + +def _make_response(*entries: LabelClassificationsEntry) -> BulkClassificationsResponse: + return BulkClassificationsResponse(labels=list(entries)) + + +@patch.object(ApiClient, "post") +def test_single_batch(api_post: MagicMock, project: Project) -> None: + entry = _make_entry() + api_post.return_value = _make_response(entry) + + label_uuids = [uuid.uuid4() for _ in range(3)] + result = list(project.get_label_classifications(label_uuids)) + + api_post.assert_called_once() + _, kwargs = api_post.call_args + assert kwargs["payload"] == BulkClassificationsPayload( + label_uuids=[str(u) for u in label_uuids], + branch_name="main", + ) + assert kwargs["result_type"] is BulkClassificationsResponse + assert f"projects/{project.project_hash}/label-rows/classifications" in api_post.call_args.args[0] + assert result == [entry] + + +@patch.object(ApiClient, "post") +def test_multiple_batches(api_post: MagicMock, project: Project) -> None: + entry_a = _make_entry() + entry_b = _make_entry() + api_post.side_effect = [_make_response(entry_a), _make_response(entry_b)] + + label_uuids = [uuid.uuid4() for _ in range(5)] + result = list(project.get_label_classifications(label_uuids, batch_size=3)) + + assert api_post.call_count == 2 + + first_payload = api_post.call_args_list[0][1]["payload"] + second_payload = api_post.call_args_list[1][1]["payload"] + assert first_payload.label_uuids == [str(u) for u in label_uuids[:3]] + assert second_payload.label_uuids == [str(u) for u in label_uuids[3:]] + + assert result == [entry_a, entry_b] + + +@patch.object(ApiClient, "post") +def test_custom_branch_name(api_post: MagicMock, project: Project) -> None: + api_post.return_value = _make_response(_make_entry()) + + list(project.get_label_classifications([uuid.uuid4()], branch_name="my-branch")) + + payload = api_post.call_args[1]["payload"] + assert payload.branch_name == "my-branch" + + +@patch.object(ApiClient, "post") +def test_empty_input(api_post: MagicMock, project: Project) -> None: + result = list(project.get_label_classifications([])) + + api_post.assert_not_called() + assert result == []