diff --git a/docs/features/entity_counting.md b/docs/features/entity_counting.md index 0de1fa3..0afe240 100644 --- a/docs/features/entity_counting.md +++ b/docs/features/entity_counting.md @@ -40,7 +40,25 @@ results = await orchestrator.synthesize( ) ``` -**2. Dedicated LLM Client (`counting_llm_client`)** +**2. Sharded / Parallel Counting** + +To handle complex counting scenarios (e.g. searching for distinct variations of an entity with disjoint rules), you can pass a `list[str]` of multiple specialized prompts to `custom_counting_context`. + +The system automatically executes these prompts concurrently (in "shards"), achieves consensus on each shard individually, and then merges and deduplicates the final results. This parallelization prevents massive monolithic prompts from degrading LLM performance without increasing wall-clock time. + +```python +results = await orchestrator.synthesize( + ..., + count_entities=True, + custom_counting_context=[ + "Count only Domestic invoices and describe their destinations.", + "Count only International invoices and describe their destinations.", + "Count any catch-all invoices that don't fit the above rules." + ] +) +``` + +**3. Dedicated LLM Client (`counting_llm_client`)** You can use a different model for counting (e.g., a faster/cheaper one) than for the main extraction. diff --git a/docs/workflow_orchestrator.rst b/docs/workflow_orchestrator.rst index c39b027..686fb68 100644 --- a/docs/workflow_orchestrator.rst +++ b/docs/workflow_orchestrator.rst @@ -171,7 +171,7 @@ Once the orchestrator is configured, you can start processing documents using on * ``input_strings`` (``List[str]``): A list of strings, where each string is a document to be processed. * ``db_session_for_hydration`` (``Optional[Session]``): An optional SQLAlchemy session. If provided, the hydrator will use it to resolve relationships. If not, a temporary in-memory session is created. * ``count_entities`` (``bool``, default ``False``): If True, performs an initial pass to count entities before extraction. - * ``custom_counting_context`` (``str``, optional): Custom instructions or context specifically for the entity counting phase. + * ``custom_counting_context`` (``Union[str, List[str]]``, optional): Custom instructions or context specifically for the entity counting phase. If a list of strings is provided, sharded parallel counting is enabled, where each string acts as a separate shard that is executed concurrently and the results are merged. * ``extraction_example_json`` (``str``, optional): A JSON string that provides a few-shot example to the LLM, guiding it to produce a better-structured output. If not provided, the orchestrator will attempt to auto-generate one. * ``extraction_example_object`` (``Optional[Union[SQLModel, List[SQLModel]]]``, optional): An existing SQLModel object or a list of them to be used as the few-shot example. This is an alternative to providing the example as a raw JSON string. * ``custom_extraction_process`` (``str``, optional): Custom, step-by-step instructions for the LLM on how to perform the extraction. diff --git a/src/extrai/core/batch/batch_processor.py b/src/extrai/core/batch/batch_processor.py index 7a8f1ad..7413a1c 100644 --- a/src/extrai/core/batch/batch_processor.py +++ b/src/extrai/core/batch/batch_processor.py @@ -76,6 +76,10 @@ async def process_batch( ) -> BatchProcessResult: status = await self.status_checker.get_status(root_batch_id, db_session) context = db_session.get(BatchJobContext, root_batch_id) + if context is None: + return BatchProcessResult( + status=BatchJobStatus.FAILED, message="Batch context not found" + ) # 1. Already Completed if status == BatchJobStatus.COMPLETED and context.results: @@ -144,6 +148,12 @@ async def _process_counting_completion( if not lines: raise ValueError("Empty results content") + import re + + revisions = [] + shard_buckets = {} # {shard_idx: [revisions]} + is_sharded = False + # Parse each line as a revision for raw_content in lines: try: @@ -152,7 +162,15 @@ async def _process_counting_completion( else: wrapper = raw_content + custom_id = wrapper.get("custom_id", "") + shard_match = re.search(r"_shard_(\d+)$", custom_id) + shard_idx = None + if shard_match: + is_sharded = True + shard_idx = int(shard_match.group(1)) + # Check if it's wrapped in OpenAI batch response format + parsed_json = None if "response" in wrapper and "body" in wrapper.get("response", {}): body = wrapper["response"]["body"] @@ -168,32 +186,37 @@ async def _process_counting_completion( if "choices" in body and body["choices"]: content = body["choices"][0]["message"]["content"] parsed_json = json.loads(content) - revisions.append(parsed_json) else: # Maybe it's directly the JSON string or dict if isinstance(wrapper, str): - revisions.append(json.loads(wrapper)) + parsed_json = json.loads(wrapper) elif isinstance(wrapper, dict): - revisions.append(wrapper) + parsed_json = wrapper + + if parsed_json: + if is_sharded and shard_idx is not None: + if shard_idx not in shard_buckets: + shard_buckets[shard_idx] = [] + shard_buckets[shard_idx].append(parsed_json) + else: + revisions.append(parsed_json) except json.JSONDecodeError as e: self.logger.error(f"Failed to parse counting result as JSON: {e}") continue - self.logger.debug(f"Parsed {len(revisions)} counting revisions") + self.logger.debug( + f"Parsed {len(revisions)} non-sharded counting revisions and {len(shard_buckets)} shards" + ) # Recreate original prompts to use for consensus fallback if needed from extrai.core.prompts.counting import ( generate_entity_counting_system_prompt, generate_entity_counting_user_prompt, ) + from extrai.utils.serialization_utils import resolve_step_param schema_json = self.model_registry.get_schema_for_models(target_model_names) - system_prompt = generate_entity_counting_system_prompt( - target_model_names, - schema_json, - context.config.custom_counting_context, - ) user_prompt = generate_entity_counting_user_prompt(context.input_strings) target_json_schema = ( self.entity_counter.get_counting_model( @@ -203,15 +226,79 @@ async def _process_counting_completion( else None ) - # Achieve consensus - consensus_result = ( - await self.entity_counter.counting_consensus.achieve_consensus( - revisions=revisions, - system_prompt=system_prompt, - user_prompt=user_prompt, - target_json_schema=target_json_schema, + consensus_result = [] + + if is_sharded: + resolved_context = resolve_step_param( + context.config.custom_counting_context, + context.config.current_model_index + if context.config.hierarchical + else 0, + len(self.model_registry.models) + if context.config.hierarchical + else 1, + ) + if not isinstance(resolved_context, list): + resolved_context = [resolved_context] # fallback + + for shard_idx, shard_revisions in shard_buckets.items(): + shard_ctx = ( + resolved_context[shard_idx] + if shard_idx < len(resolved_context) + else resolved_context[-1] + ) + system_prompt = generate_entity_counting_system_prompt( + target_model_names, + schema_json, + shard_ctx, + ) + shard_consensus = ( + await self.entity_counter.counting_consensus.achieve_consensus( + revisions=shard_revisions, + system_prompt=system_prompt, + user_prompt=user_prompt, + target_json_schema=target_json_schema, + ) + ) + consensus_result.extend(shard_consensus) + + # Deduplicate + seen = set() + deduped = [] + for e in consensus_result: + e_str = json.dumps(e, sort_keys=True) + if e_str not in seen: + seen.add(e_str) + deduped.append(e) + consensus_result = deduped + + else: + resolved_context = resolve_step_param( + context.config.custom_counting_context, + context.config.current_model_index + if context.config.hierarchical + else 0, + len(self.model_registry.models) + if context.config.hierarchical + else 1, + ) + if isinstance(resolved_context, list): + resolved_context = resolved_context[ + 0 + ] # Should not happen if it wasn't sharded + system_prompt = generate_entity_counting_system_prompt( + target_model_names, + schema_json, + resolved_context, + ) + consensus_result = ( + await self.entity_counter.counting_consensus.achieve_consensus( + revisions=revisions, + system_prompt=system_prompt, + user_prompt=user_prompt, + target_json_schema=target_json_schema, + ) ) - ) # Filter out any hallucinated models not in target_model_names entity_descriptions = [ @@ -372,7 +459,6 @@ async def _handle_batch_retry( return BatchProcessResult( status=BatchJobStatus.FAILED, message="All revisions failed validation, cannot retry.", - errors=validation_errors, ) # Store valid partial results @@ -395,7 +481,6 @@ async def _handle_batch_retry( return BatchProcessResult( status=BatchJobStatus.SUBMITTED, message=f"Partial success. Retrying {num_to_retry} failed revisions.", - errors=validation_errors, ) async def _finalize_completion( diff --git a/src/extrai/core/batch/batch_submitter.py b/src/extrai/core/batch/batch_submitter.py index 7449764..508ab00 100644 --- a/src/extrai/core/batch/batch_submitter.py +++ b/src/extrai/core/batch/batch_submitter.py @@ -269,21 +269,48 @@ async def _submit_counting_phase( total_steps, ) - system_prompt, user_prompt = self.entity_counter.prepare_counting_prompts( - input_strings, - model_names, - resolved_context, - previous_entities=previous_entities if previous_entities else None, - examples=examples, - ) - client = self.entity_counter.llm_client - requests = self._create_batch_requests( - system_prompt, - user_prompt, - num_revisions=self.config.num_counting_revisions, - override_client=client, - ) + requests = [] + + if isinstance(resolved_context, list): + self.logger.info( + f"Preparing sharded counting batch with {len(resolved_context)} shards" + ) + for shard_idx, shard_context in enumerate(resolved_context): + system_prompt, user_prompt = ( + self.entity_counter.prepare_counting_prompts( + input_strings, + model_names, + shard_context, + previous_entities=previous_entities + if previous_entities + else None, + examples=examples, + ) + ) + shard_requests = self._create_batch_requests( + system_prompt, + user_prompt, + num_revisions=self.config.num_counting_revisions, + override_client=client, + ) + for req in shard_requests: + req["custom_id"] = f"{req['custom_id']}_shard_{shard_idx}" + requests.extend(shard_requests) + else: + system_prompt, user_prompt = self.entity_counter.prepare_counting_prompts( + input_strings, + model_names, + resolved_context, + previous_entities=previous_entities if previous_entities else None, + examples=examples, + ) + requests = self._create_batch_requests( + system_prompt, + user_prompt, + num_revisions=self.config.num_counting_revisions, + override_client=client, + ) response_model = None if self.config.use_structured_output: @@ -346,27 +373,41 @@ async def _submit_extraction_phase( len(self.model_registry.models) if context.config.hierarchical else 1 ) + from typing import cast + request = self.request_factory.prepare_request( input_strings=context.input_strings, config=self.config, extraction_example_json=context.config.extraction_example_json, - custom_extraction_process=resolve_step_param( - context.config.custom_extraction_process, - step_index, - total_steps, + custom_extraction_process=cast( + str, + resolve_step_param( + context.config.custom_extraction_process, + step_index, + total_steps, + ), ), - custom_extraction_guidelines=resolve_step_param( - context.config.custom_extraction_guidelines, - step_index, - total_steps, + custom_extraction_guidelines=cast( + str, + resolve_step_param( + context.config.custom_extraction_guidelines, + step_index, + total_steps, + ), ), - custom_final_checklist=resolve_step_param( - context.config.custom_final_checklist, - step_index, - total_steps, + custom_final_checklist=cast( + str, + resolve_step_param( + context.config.custom_final_checklist, + step_index, + total_steps, + ), ), - custom_context=resolve_step_param( - context.config.custom_context, step_index, total_steps + custom_context=cast( + str, + resolve_step_param( + context.config.custom_context, step_index, total_steps + ), ), expected_entity_descriptions=context.config.expected_entity_descriptions, previous_entities=previous_entities if previous_entities else None, @@ -410,7 +451,7 @@ def _create_batch_requests( self, system_prompt: str, user_prompt: str, - json_schema: str | None = None, + json_schema: dict[str, Any] | str | None = None, num_revisions: int | None = None, override_client: BaseLLMClient | None = None, ) -> list[dict]: diff --git a/src/extrai/core/counting_consensus.py b/src/extrai/core/counting_consensus.py index 54a87c9..2e65682 100644 --- a/src/extrai/core/counting_consensus.py +++ b/src/extrai/core/counting_consensus.py @@ -2,8 +2,8 @@ import logging from typing import Any -from .extraction_config import ExtractionConfig from ..utils.alignment_utils import align_entity_arrays, calculate_similarity +from .extraction_config import ExtractionConfig class CountingConsensus: diff --git a/src/extrai/core/entity_counter.py b/src/extrai/core/entity_counter.py index 041413f..94f89bd 100644 --- a/src/extrai/core/entity_counter.py +++ b/src/extrai/core/entity_counter.py @@ -88,21 +88,13 @@ async def count_entities( self, input_strings: list[str], model_names: list[str], - custom_counting_context: str = "", + custom_counting_context: str | list[str] = "", previous_entities: list[dict[str, Any]] | None = None, examples: str = "", ) -> list[dict[str, Any]]: """Performs entity counting for specified models using consensus.""" self.logger.info(f"Counting entities for: {model_names}") - system_prompt, user_prompt = self.prepare_counting_prompts( - input_strings, - model_names, - custom_counting_context, - previous_entities, - examples, - ) - target_json_schema = ( EntityCountResult.model_json_schema() if self.config.use_structured_output @@ -113,8 +105,14 @@ async def count_entities( if isinstance(client, list): client = client[0] - try: - # Execute multiple revisions natively + async def _count_shard(shard_context: str) -> list[dict[str, Any]]: + system_prompt, user_prompt = self.prepare_counting_prompts( + input_strings, + model_names, + shard_context, + previous_entities, + examples, + ) revisions = await client.generate_and_validate_raw_json_output( system_prompt=system_prompt, user_prompt=user_prompt, @@ -144,6 +142,33 @@ async def count_entities( return consensus_result + import asyncio + import json + + def _deduplicate(entities: list[dict[str, Any]]) -> list[dict[str, Any]]: + seen = set() + deduped = [] + for e in entities: + e_str = json.dumps(e, sort_keys=True) + if e_str not in seen: + seen.add(e_str) + deduped.append(e) + return deduped + + try: + if isinstance(custom_counting_context, list): + self.logger.info( + f"Running sharded counting with {len(custom_counting_context)} shards" + ) + tasks = [_count_shard(ctx) for ctx in custom_counting_context] + shard_results = await asyncio.gather(*tasks) + merged_results = [] + for res in shard_results: + merged_results.extend(res) + return _deduplicate(merged_results) + else: + return await _count_shard(custom_counting_context) + except Exception as e: self.logger.error(f"Entity counting failed: {e}") return [] diff --git a/src/extrai/core/schema_inspector.py b/src/extrai/core/schema_inspector.py index bc3c8c1..17ec1a3 100644 --- a/src/extrai/core/schema_inspector.py +++ b/src/extrai/core/schema_inspector.py @@ -547,6 +547,7 @@ def discover_sqlmodels_from_root( all_discovered_models=all_discovered_models, # type: ignore[arg-type] recursion_guard=set(), ) + all_discovered_models = self._topological_sort_models(all_discovered_models) except Exception as e: self.logger.error( f"Error during SQLModel discovery starting from {root_sqlmodel_class.__name__}: {e}" @@ -554,3 +555,61 @@ def discover_sqlmodels_from_root( return [] return all_discovered_models + + def _topological_sort_models( + self, models: list[type[SQLModel]] + ) -> list[type[SQLModel]]: + """ + Sorts the discovered models topologically based on their foreign key dependencies. + Parents must appear before children. + """ + graph = {m: set() for m in models} + model_by_table = {} + for m in models: + try: + table_name = m.__tablename__ + except AttributeError: + try: + table_name = inspect(m).selectable.name + except Exception: + table_name = m.__name__.lower() + model_by_table[table_name] = m + + for m in models: + try: + insp = inspect(m) + for col_attr in insp.column_attrs: + col = col_attr.expression + for fk in col.foreign_keys: + target_table = fk.column.table.name + if ( + target_table in model_by_table + and model_by_table[target_table] != m + ): + graph[m].add(model_by_table[target_table]) + except Exception: + pass + + sorted_models = [] + visited = set() + temp_visited = set() + + def visit(m): + if m in temp_visited: + return + if m not in visited: + temp_visited.add(m) + for dep in graph[m]: + visit(dep) + temp_visited.remove(m) + visited.add(m) + sorted_models.append(m) + + for m in models: + visit(m) + + for m in models: + if m not in sorted_models: + sorted_models.append(m) + + return sorted_models diff --git a/src/extrai/core/workflow_orchestrator.py b/src/extrai/core/workflow_orchestrator.py index e64d593..863f637 100644 --- a/src/extrai/core/workflow_orchestrator.py +++ b/src/extrai/core/workflow_orchestrator.py @@ -189,6 +189,7 @@ async def synthesize_batch( ... wait_for_completion: If True, waits for the batch job (and any hierarchical steps) to complete. poll_interval: Interval in seconds to poll for status if wait_for_completion is True. + custom_counting_context (str | list[str], optional): Custom context for entity counting. Passing a list enables sharded parallel counting. Defaults to "". Returns: root_batch_id (str) if wait_for_completion is False. diff --git a/src/extrai/llm_providers/base_google_client.py b/src/extrai/llm_providers/base_google_client.py index cb06016..77f2bff 100644 --- a/src/extrai/llm_providers/base_google_client.py +++ b/src/extrai/llm_providers/base_google_client.py @@ -4,6 +4,7 @@ from typing import Any from extrai.core.base_llm_client import ProviderBatchStatus + from .generic_openai_client import GenericOpenAIClient try: diff --git a/src/extrai/llm_providers/vertex_ai_client.py b/src/extrai/llm_providers/vertex_ai_client.py index 1831669..c5fb9d2 100644 --- a/src/extrai/llm_providers/vertex_ai_client.py +++ b/src/extrai/llm_providers/vertex_ai_client.py @@ -140,6 +140,7 @@ def create_inline_batch_job( import json import tempfile import uuid + from google.cloud import storage job_id = str(uuid.uuid4()) @@ -231,9 +232,10 @@ async def retrieve_batch_results(self, batch_id: str) -> str: job = await self.retrieve_batch_job(batch_id) if hasattr(job, "dest") and hasattr(job.dest, "gcs_uri") and job.dest.gcs_uri: - from google.cloud import storage import json + from google.cloud import storage + if hasattr(self, "_credentials") and self._credentials: storage_client = storage.Client( credentials=self._credentials, project=self._project_id diff --git a/src/extrai/utils/serialization_utils.py b/src/extrai/utils/serialization_utils.py index ec246d5..7974a04 100644 --- a/src/extrai/utils/serialization_utils.py +++ b/src/extrai/utils/serialization_utils.py @@ -71,19 +71,19 @@ def make_json_serializable(obj: Any) -> Any: def resolve_step_param( - param: str | list[str], step_index: int = 0, total_steps: int = 1 -) -> str: + param: Any, step_index: int = 0, total_steps: int = 1 +) -> str | list[str]: """ Resolves a parameter that can be a single string or a list of strings to the specific string for the current step. Args: - param: The parameter value (str or list[str]) + param: The parameter value (str, list[str], or list[str | list[str]]) step_index: The current step index (0-based) total_steps: The total number of steps in the process Returns: - The string value for the current step. + The string or list of strings value for the current step. Raises: ValueError: If list length does not match requirements. @@ -97,6 +97,10 @@ def resolve_step_param( if not param: return "" + if total_steps == 1: + # If there is only 1 step, a list of strings represents shards for this single step. + return param + if len(param) == 1: return param[0] diff --git a/tests/core/batch_pipeline/test_batch_counting.py b/tests/core/batch_pipeline/test_batch_counting.py index c85d62c..115a7f0 100644 --- a/tests/core/batch_pipeline/test_batch_counting.py +++ b/tests/core/batch_pipeline/test_batch_counting.py @@ -108,6 +108,50 @@ async def test_submit_batch_counting(self): config = added_context.config self.assertTrue(config.count_entities) + async def test_submit_batch_counting_sharded(self): + # Setup mocks + self.pipeline.entity_counter.prepare_counting_prompts.side_effect = [ + ("sys1", "user1"), + ("sys2", "user2"), + ] + self.pipeline.context_preparer.prepare_example = AsyncMock(return_value="") + + mock_batch_job = MagicMock() + mock_batch_job.id = "counting_batch_id_sharded" + + # Mock the entity_counter's client for counting phase + self.pipeline.entity_counter.llm_client.create_batch_job = AsyncMock( + return_value=mock_batch_job + ) + + # When _create_batch_requests is called it uses prepare_request + self.pipeline.entity_counter.llm_client.prepare_request = MagicMock( + return_value={"model": "test-model", "messages": []} + ) + + # Test submit + root_id = await self.pipeline.submit_batch( + self.mock_session, + ["doc"], + count_entities=True, + custom_counting_context=["shard1", "shard2"], + ) + + # Verify + self.assertIsInstance(root_id, str) + + # Check that create_batch_job was called once with requests from both shards + self.pipeline.entity_counter.llm_client.create_batch_job.assert_called_once() + call_args = self.pipeline.entity_counter.llm_client.create_batch_job.call_args[ + 0 + ] + requests = call_args[0] + + # Assuming 1 revision per shard (from mock_config), we should have 2 total requests + self.assertEqual(len(requests), 2) + self.assertIn("_shard_0", requests[0]["custom_id"]) + self.assertIn("_shard_1", requests[1]["custom_id"]) + async def test_process_batch_counting_transition(self): # Mock Context context = BatchJobContext( @@ -175,6 +219,84 @@ async def test_process_batch_counting_transition(self): [{"model": "RootModel", "description": "desc1"}], ) + async def test_process_batch_counting_transition_sharded(self): + # Mock Context + context = BatchJobContext( + root_batch_id="root_1", + current_batch_id="counting_batch_id", + status=BatchJobStatus.COUNTING_SUBMITTED, + input_strings=["doc"], + config=BatchJobConfig( + count_entities=True, + custom_extraction_process="proc", + custom_counting_context=["shard1", "shard2"], + ), + ) + self.mock_session.get.return_value = context + + self.pipeline.entity_counter.llm_client.get_batch_status = AsyncMock( + return_value=ProviderBatchStatus.COMPLETED + ) + + # We need the retrieve_batch_results to return the raw text that the client processes, + # but actually BatchProcessor calls retrieve_batch_results and then does logic. + # Let's mock extract_content_from_batch_response instead + self.pipeline.entity_counter.llm_client.retrieve_batch_results = AsyncMock( + return_value=[ + json.dumps({"custom_id": "123_shard_0"}), + json.dumps({"custom_id": "456_shard_1"}), + ] + ) + + def mock_extract(line): + if "shard_0" in line: + return '{"RootModel": ["desc1"]}' + return '{"RootModel": ["desc2"]}' + + self.pipeline.entity_counter.llm_client.extract_content_from_batch_response = ( + mock_extract + ) + + mock_extraction_job = MagicMock() + mock_extraction_job.id = "extraction_batch_id" + mock_client_instance = self.pipeline.client_rotator.get_next_client.return_value + mock_client_instance.create_batch_job = AsyncMock( + return_value=mock_extraction_job + ) + + self.pipeline.prompt_builder.build_prompts.return_value = ("sys", "user") + + # Mock counting_consensus + # For shards, achieve_consensus is called per shard + self.pipeline.entity_counter.counting_consensus.achieve_consensus = AsyncMock( + side_effect=[ + [{"model": "RootModel", "description": "desc1"}], + [{"model": "RootModel", "description": "desc2"}], + ] + ) + + result = await self.pipeline.process_batch("root_1", self.mock_session) + + # Verify transition + self.assertEqual(result.status, BatchJobStatus.SUBMITTED) + + # Verify config updated with combined deduplicated descriptions + config = context.config + self.assertIsNotNone(config.expected_entity_descriptions) + self.assertEqual(len(config.expected_entity_descriptions), 2) + self.assertIn( + {"model": "RootModel", "description": "desc1"}, + config.expected_entity_descriptions, + ) + self.assertIn( + {"model": "RootModel", "description": "desc2"}, + config.expected_entity_descriptions, + ) + self.assertEqual( + self.pipeline.entity_counter.counting_consensus.achieve_consensus.call_count, + 2, + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/core/entity_counter/test_entity_counter.py b/tests/core/entity_counter/test_entity_counter.py index cacc201..4c54954 100644 --- a/tests/core/entity_counter/test_entity_counter.py +++ b/tests/core/entity_counter/test_entity_counter.py @@ -58,6 +58,67 @@ async def test_count_entities_success(self, mock_user_prompt, mock_system_prompt self.mock_client.generate_and_validate_raw_json_output.assert_called_once() self.counter.counting_consensus.achieve_consensus.assert_called_once() + @patch("extrai.core.entity_counter.generate_entity_counting_system_prompt") + @patch("extrai.core.entity_counter.generate_entity_counting_user_prompt") + async def test_count_entities_sharded(self, mock_user_prompt, mock_system_prompt): + # Setup mocks + self.mock_model_registry.get_schema_for_models.return_value = ( + '{"type": "object"}' + ) + + # We need the client to return different things for different calls, or just return the same and check deduplication. + mock_result_1 = [ + { + "counted_entities": [ + {"model": "ModelA", "temp_id": "1", "description": "desc1"} + ] + } + ] + mock_result_2 = [ + { + "counted_entities": [ + {"model": "ModelA", "temp_id": "2", "description": "desc2"} + ] + } + ] + + # We'll just have generate_and_validate_raw_json_output return the same list of mock results. + # But wait, it's called twice. We can use a side_effect. + self.mock_client.generate_and_validate_raw_json_output = AsyncMock( + side_effect=[mock_result_1, mock_result_2] + ) + + expected_consensus_1 = [ + {"model": "ModelA", "temp_id": "1", "description": "desc1"} + ] + expected_consensus_2 = [ + {"model": "ModelA", "temp_id": "2", "description": "desc2"}, + {"model": "ModelA", "temp_id": "1", "description": "desc1"}, + ] + + self.counter.counting_consensus.achieve_consensus = AsyncMock( + side_effect=[expected_consensus_1, expected_consensus_2] + ) + + counts = await self.counter.count_entities( + ["doc"], ["ModelA"], custom_counting_context=["shard1", "shard2"] + ) + + # order might matter based on how gather returns, but we can check elements + self.assertEqual(len(counts), 2) + self.assertIn( + {"model": "ModelA", "temp_id": "1", "description": "desc1"}, counts + ) + self.assertIn( + {"model": "ModelA", "temp_id": "2", "description": "desc2"}, counts + ) + self.assertEqual( + self.mock_client.generate_and_validate_raw_json_output.call_count, 2 + ) + self.assertEqual( + self.counter.counting_consensus.achieve_consensus.call_count, 2 + ) + async def test_count_entities_llm_failure(self): self.mock_model_registry.get_schema_for_models.return_value = "{}" self.mock_client.generate_and_validate_raw_json_output = AsyncMock(