Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions encord/orm/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,3 +600,59 @@ class ProjectUserResponse(BaseDTO):

user_email: str
user_role: ProjectUserRole


# --- Bulk classifications endpoint DTOs ---


class BulkClassificationsPayload(BaseDTO):
    """Request payload for ``POST /projects/{uuid}/label-rows/classifications``.

    Names the label rows to look up and the label branch to read them from.
    """

    # Label row UUIDs, already converted to strings by the caller. Required.
    label_uuids: List[str]
    # Label branch to read from; falls back to "main" when not supplied.
    branch_name: str = "main"


class ClassificationAnswerObject(BaseDTO):
    """Nested answer within a :class:`ClassificationAnswerEntry`.

    All five fields are required: none of them declares a default.
    """

    name: str
    value: str
    # Deliberately untyped because the payload is polymorphic.
    answers: Any  # list[AnswersAnswer] | str | float
    # presumably the ontology feature this answer belongs to -- confirm against the API schema
    feature_hash: str
    manual_annotation: bool


class ClassificationAnswerEntry(BaseDTO):
    """Single classification answer returned by the bulk classifications endpoint.

    Matches the ``ClassificationAnswer`` shape produced by the legacy
    ``initialise_labels`` SDK path.

    Only ``classification_hash`` and ``classifications`` are required; every
    other field defaults to ``None`` and may therefore be absent.
    """

    classification_hash: str
    feature_hash: Optional[str] = None
    classifications: List[ClassificationAnswerObject]
    # NOTE(review): this field name shadows the ``range`` builtin inside the class
    # body. Presumably the name is dictated by the wire format -- do not rename
    # without confirming. Each item is a pair of ints (assumed to be frame bounds).
    range: Optional[List[Tuple[int, int]]] = None
    # Audit fields are kept as raw strings here rather than parsed datetimes.
    created_at: Optional[str] = None
    created_by: Optional[str] = None
    last_edited_at: Optional[str] = None
    last_edited_by: Optional[str] = None
    manual_annotation: Optional[bool] = None


class LabelClassificationsEntry(BaseDTO):
    """One label row's classification data in the bulk response.

    ``label_uuid`` and ``data_uuid`` are required; all other fields have defaults.
    """

    label_uuid: UUID
    data_uuid: UUID
    data_title: Optional[str] = None
    branch_name: str = "main"
    # Answers keyed by a string (the classification hash, per the SDK method docs).
    # NOTE(review): ``{}`` is a mutable default. It is only safe if ``BaseDTO``
    # copies field defaults per instance (as pydantic models do) -- confirm.
    classification_answers: Dict[str, ClassificationAnswerEntry] = {}
    # Export metadata; both default to ``None``.
    export_uuid: Optional[UUID] = None
    exported_at: Optional[datetime.datetime] = None


class BulkClassificationsResponse(BaseDTO):
    """Response from ``POST /projects/{uuid}/label-rows/classifications``.

    A thin envelope around the list of per-label-row entries.
    """

    # Required: the classification data, one entry per label row.
    labels: List[LabelClassificationsEntry]
33 changes: 33 additions & 0 deletions encord/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@
)
from encord.orm.project import (
AddProjectIssueTagsPayload,
BulkClassificationsPayload,
BulkClassificationsResponse,
CopyDatasetOptions,
CopyLabelsOptions,
LabelClassificationsEntry,
ProjectDataset,
ProjectDTO,
ProjectStatus,
Expand Down Expand Up @@ -287,6 +290,36 @@ def list_label_rows_v2(
]
return label_rows

def get_label_classifications(
    self,
    label_uuids: List[Union[str, UUID]],
    branch_name: str = "main",
    batch_size: int = 500,
) -> Iterable[LabelClassificationsEntry]:
    """Fast bulk fetch of classification answers only.

    The UUIDs are converted to strings, split into consecutive chunks of at
    most ``batch_size`` items, and each chunk is sent as one ``POST`` to
    ``projects/{project_hash}/label-rows/classifications``. This is a
    generator: no request is made until iteration starts, entries are yielded
    as each response arrives, and an empty ``label_uuids`` causes no request
    at all.

    Args:
        label_uuids: Label row UUIDs to fetch classifications for.
            Obtain these from :meth:`list_label_rows_v2` (``row.label_hash``).
        branch_name: Label branch to read from (default ``"main"``).
        batch_size: Number of labels per server request (max 1000).
            Must be at least 1.

    Returns:
        An iterable of :class:`~encord.orm.project.LabelClassificationsEntry`,
        one per label row. Each entry contains ``classification_answers``
        keyed by classification hash.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # Without this guard a zero step makes ``range`` fail with an unrelated
    # message, and a negative step makes it empty, so the caller would
    # silently get no results for a non-empty request.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    uuid_strs = [str(u) for u in label_uuids]
    for start in range(0, len(uuid_strs), batch_size):
        batch = uuid_strs[start : start + batch_size]
        response = self._api_client.post(
            f"projects/{self.project_hash}/label-rows/classifications",
            params=None,
            payload=BulkClassificationsPayload(label_uuids=batch, branch_name=branch_name),
            result_type=BulkClassificationsResponse,
        )
        yield from response.labels

def add_users(self, user_emails: List[str], user_role: ProjectUserRole) -> List[ProjectUser]:
"""Add users to the project.

Expand Down
76 changes: 76 additions & 0 deletions tests/test_label_classifications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import uuid
from unittest.mock import MagicMock, patch

from encord.http.v2.api_client import ApiClient
from encord.orm.project import (
BulkClassificationsPayload,
BulkClassificationsResponse,
LabelClassificationsEntry,
)
from encord.project import Project


def _make_entry(**overrides) -> LabelClassificationsEntry:
    """Build a ``LabelClassificationsEntry`` with random UUIDs unless overridden."""
    # Overrides are unpacked last so they win over the generated values.
    fields = {"label_uuid": uuid.uuid4(), "data_uuid": uuid.uuid4(), **overrides}
    return LabelClassificationsEntry(**fields)


def _make_response(*entries: LabelClassificationsEntry) -> BulkClassificationsResponse:
    """Wrap the given entries in a ``BulkClassificationsResponse``."""
    labels = [*entries]
    return BulkClassificationsResponse(labels=labels)


@patch.object(ApiClient, "post")
def test_single_batch(api_post: MagicMock, project: Project) -> None:
    """A request that fits in one batch produces exactly one POST."""
    expected_entry = _make_entry()
    api_post.return_value = _make_response(expected_entry)
    requested = [uuid.uuid4() for _ in range(3)]

    fetched = list(project.get_label_classifications(requested))

    api_post.assert_called_once()
    call = api_post.call_args
    sent = call.kwargs

    # The payload carries stringified UUIDs and the default branch.
    assert sent["payload"] == BulkClassificationsPayload(
        label_uuids=list(map(str, requested)),
        branch_name="main",
    )
    assert sent["result_type"] is BulkClassificationsResponse

    # The endpoint path is the first positional argument.
    expected_path = f"projects/{project.project_hash}/label-rows/classifications"
    assert expected_path in call.args[0]

    assert fetched == [expected_entry]


@patch.object(ApiClient, "post")
def test_multiple_batches(api_post: MagicMock, project: Project) -> None:
    """Five UUIDs with ``batch_size=3`` are sent as a batch of 3, then a batch of 2."""
    entry_a = _make_entry()
    entry_b = _make_entry()
    # One canned response per expected request, in call order.
    api_post.side_effect = [_make_response(entry_a), _make_response(entry_b)]
    requested = [uuid.uuid4() for _ in range(5)]

    fetched = list(project.get_label_classifications(requested, batch_size=3))

    assert api_post.call_count == 2

    calls = api_post.call_args_list
    head_payload = calls[0].kwargs["payload"]
    tail_payload = calls[1].kwargs["payload"]
    assert head_payload.label_uuids == [str(u) for u in requested[:3]]
    assert tail_payload.label_uuids == [str(u) for u in requested[3:]]

    # Results from both batches are concatenated in request order.
    assert fetched == [entry_a, entry_b]


@patch.object(ApiClient, "post")
def test_custom_branch_name(api_post: MagicMock, project: Project) -> None:
    """A caller-supplied branch name is forwarded in the request payload."""
    api_post.return_value = _make_response(_make_entry())

    # The method is a generator, so it has to be drained for the request to happen.
    for _ in project.get_label_classifications([uuid.uuid4()], branch_name="my-branch"):
        pass

    sent_payload = api_post.call_args.kwargs["payload"]
    assert sent_payload.branch_name == "my-branch"


@patch.object(ApiClient, "post")
def test_empty_input(api_post: MagicMock, project: Project) -> None:
    """An empty UUID list yields nothing and makes no request."""
    no_uuids: list = []

    fetched = list(project.get_label_classifications(no_uuids))

    api_post.assert_not_called()
    assert fetched == []
Loading