Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion docs/features/entity_counting.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,25 @@ results = await orchestrator.synthesize(
)
```

**2. Dedicated LLM Client (`counting_llm_client`)**
**2. Sharded / Parallel Counting**

To handle complex counting scenarios (e.g. searching for distinct variations of an entity with disjoint rules), you can pass a `list[str]` of multiple specialized prompts to `custom_counting_context`.

The system automatically executes these prompts concurrently, one per "shard", reaches consensus on each shard independently, and then merges the per-shard results and removes duplicate entries. Splitting the work this way keeps each prompt small and focused, which avoids the loss of LLM accuracy that a single massive, monolithic prompt can cause; because the shards run concurrently, it does so without increasing wall-clock time.

```python
results = await orchestrator.synthesize(
...,
count_entities=True,
custom_counting_context=[
"Count only Domestic invoices and describe their destinations.",
"Count only International invoices and describe their destinations.",
"Count any remaining invoices that are neither Domestic nor International."
]
)
```

**3. Dedicated LLM Client (`counting_llm_client`)**

You can use a different model for counting (e.g., a faster/cheaper one) than for the main extraction.

Expand Down
2 changes: 1 addition & 1 deletion docs/workflow_orchestrator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ Once the orchestrator is configured, you can start processing documents using on
* ``input_strings`` (``List[str]``): A list of strings, where each string is a document to be processed.
* ``db_session_for_hydration`` (``Optional[Session]``): An optional SQLAlchemy session. If provided, the hydrator will use it to resolve relationships. If not, a temporary in-memory session is created.
* ``count_entities`` (``bool``, default ``False``): If True, performs an initial pass to count entities before extraction.
* ``custom_counting_context`` (``str``, optional): Custom instructions or context specifically for the entity counting phase.
* ``custom_counting_context`` (``Union[str, List[str]]``, optional): Custom instructions or context specifically for the entity counting phase. If a list of strings is provided, sharded parallel counting is enabled: each string is treated as a separate shard, the shards are executed concurrently, and their results are merged and deduplicated.
* ``extraction_example_json`` (``str``, optional): A JSON string that provides a few-shot example to the LLM, guiding it to produce a better-structured output. If not provided, the orchestrator will attempt to auto-generate one.
* ``extraction_example_object`` (``Optional[Union[SQLModel, List[SQLModel]]]``, optional): An existing SQLModel object or a list of them to be used as the few-shot example. This is an alternative to providing the example as a raw JSON string.
* ``custom_extraction_process`` (``str``, optional): Custom, step-by-step instructions for the LLM on how to perform the extraction.
Expand Down
123 changes: 104 additions & 19 deletions src/extrai/core/batch/batch_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ async def process_batch(
) -> BatchProcessResult:
status = await self.status_checker.get_status(root_batch_id, db_session)
context = db_session.get(BatchJobContext, root_batch_id)
if context is None:
return BatchProcessResult(
status=BatchJobStatus.FAILED, message="Batch context not found"
)

# 1. Already Completed
if status == BatchJobStatus.COMPLETED and context.results:
Expand Down Expand Up @@ -144,6 +148,12 @@ async def _process_counting_completion(
if not lines:
raise ValueError("Empty results content")

import re

revisions = []
shard_buckets = {} # {shard_idx: [revisions]}
is_sharded = False

# Parse each line as a revision
for raw_content in lines:
try:
Expand All @@ -152,7 +162,15 @@ async def _process_counting_completion(
else:
wrapper = raw_content

custom_id = wrapper.get("custom_id", "")
shard_match = re.search(r"_shard_(\d+)$", custom_id)
shard_idx = None
if shard_match:
is_sharded = True
shard_idx = int(shard_match.group(1))

# Check if it's wrapped in OpenAI batch response format
parsed_json = None
if "response" in wrapper and "body" in wrapper.get("response", {}):
body = wrapper["response"]["body"]

Expand All @@ -168,32 +186,37 @@ async def _process_counting_completion(
if "choices" in body and body["choices"]:
content = body["choices"][0]["message"]["content"]
parsed_json = json.loads(content)
revisions.append(parsed_json)
else:
# Maybe it's directly the JSON string or dict
if isinstance(wrapper, str):
revisions.append(json.loads(wrapper))
parsed_json = json.loads(wrapper)
elif isinstance(wrapper, dict):
revisions.append(wrapper)
parsed_json = wrapper

if parsed_json:
if is_sharded and shard_idx is not None:
if shard_idx not in shard_buckets:
shard_buckets[shard_idx] = []
shard_buckets[shard_idx].append(parsed_json)
else:
revisions.append(parsed_json)

except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse counting result as JSON: {e}")
continue

self.logger.debug(f"Parsed {len(revisions)} counting revisions")
self.logger.debug(
f"Parsed {len(revisions)} non-sharded counting revisions and {len(shard_buckets)} shards"
)

# Recreate original prompts to use for consensus fallback if needed
from extrai.core.prompts.counting import (
generate_entity_counting_system_prompt,
generate_entity_counting_user_prompt,
)
from extrai.utils.serialization_utils import resolve_step_param

schema_json = self.model_registry.get_schema_for_models(target_model_names)
system_prompt = generate_entity_counting_system_prompt(
target_model_names,
schema_json,
context.config.custom_counting_context,
)
user_prompt = generate_entity_counting_user_prompt(context.input_strings)
target_json_schema = (
self.entity_counter.get_counting_model(
Expand All @@ -203,15 +226,79 @@ async def _process_counting_completion(
else None
)

# Achieve consensus
consensus_result = (
await self.entity_counter.counting_consensus.achieve_consensus(
revisions=revisions,
system_prompt=system_prompt,
user_prompt=user_prompt,
target_json_schema=target_json_schema,
consensus_result = []

if is_sharded:
resolved_context = resolve_step_param(
context.config.custom_counting_context,
context.config.current_model_index
if context.config.hierarchical
else 0,
len(self.model_registry.models)
if context.config.hierarchical
else 1,
)
if not isinstance(resolved_context, list):
resolved_context = [resolved_context] # fallback

for shard_idx, shard_revisions in shard_buckets.items():
shard_ctx = (
resolved_context[shard_idx]
if shard_idx < len(resolved_context)
else resolved_context[-1]
)
system_prompt = generate_entity_counting_system_prompt(
target_model_names,
schema_json,
shard_ctx,
)
shard_consensus = (
await self.entity_counter.counting_consensus.achieve_consensus(
revisions=shard_revisions,
system_prompt=system_prompt,
user_prompt=user_prompt,
target_json_schema=target_json_schema,
)
)
consensus_result.extend(shard_consensus)

# Deduplicate
seen = set()
deduped = []
for e in consensus_result:
e_str = json.dumps(e, sort_keys=True)
if e_str not in seen:
seen.add(e_str)
deduped.append(e)
consensus_result = deduped

else:
resolved_context = resolve_step_param(
context.config.custom_counting_context,
context.config.current_model_index
if context.config.hierarchical
else 0,
len(self.model_registry.models)
if context.config.hierarchical
else 1,
)
if isinstance(resolved_context, list):
resolved_context = resolved_context[
0
] # Should not happen if it wasn't sharded
system_prompt = generate_entity_counting_system_prompt(
target_model_names,
schema_json,
resolved_context,
)
consensus_result = (
await self.entity_counter.counting_consensus.achieve_consensus(
revisions=revisions,
system_prompt=system_prompt,
user_prompt=user_prompt,
target_json_schema=target_json_schema,
)
)
)

# Filter out any hallucinated models not in target_model_names
entity_descriptions = [
Expand Down Expand Up @@ -372,7 +459,6 @@ async def _handle_batch_retry(
return BatchProcessResult(
status=BatchJobStatus.FAILED,
message="All revisions failed validation, cannot retry.",
errors=validation_errors,
)

# Store valid partial results
Expand All @@ -395,7 +481,6 @@ async def _handle_batch_retry(
return BatchProcessResult(
status=BatchJobStatus.SUBMITTED,
message=f"Partial success. Retrying {num_to_retry} failed revisions.",
errors=validation_errors,
)

async def _finalize_completion(
Expand Down
99 changes: 70 additions & 29 deletions src/extrai/core/batch/batch_submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,21 +269,48 @@ async def _submit_counting_phase(
total_steps,
)

system_prompt, user_prompt = self.entity_counter.prepare_counting_prompts(
input_strings,
model_names,
resolved_context,
previous_entities=previous_entities if previous_entities else None,
examples=examples,
)

client = self.entity_counter.llm_client
requests = self._create_batch_requests(
system_prompt,
user_prompt,
num_revisions=self.config.num_counting_revisions,
override_client=client,
)
requests = []

if isinstance(resolved_context, list):
self.logger.info(
f"Preparing sharded counting batch with {len(resolved_context)} shards"
)
for shard_idx, shard_context in enumerate(resolved_context):
system_prompt, user_prompt = (
self.entity_counter.prepare_counting_prompts(
input_strings,
model_names,
shard_context,
previous_entities=previous_entities
if previous_entities
else None,
examples=examples,
)
)
shard_requests = self._create_batch_requests(
system_prompt,
user_prompt,
num_revisions=self.config.num_counting_revisions,
override_client=client,
)
for req in shard_requests:
req["custom_id"] = f"{req['custom_id']}_shard_{shard_idx}"
requests.extend(shard_requests)
else:
system_prompt, user_prompt = self.entity_counter.prepare_counting_prompts(
input_strings,
model_names,
resolved_context,
previous_entities=previous_entities if previous_entities else None,
examples=examples,
)
requests = self._create_batch_requests(
system_prompt,
user_prompt,
num_revisions=self.config.num_counting_revisions,
override_client=client,
)

response_model = None
if self.config.use_structured_output:
Expand Down Expand Up @@ -346,27 +373,41 @@ async def _submit_extraction_phase(
len(self.model_registry.models) if context.config.hierarchical else 1
)

from typing import cast

request = self.request_factory.prepare_request(
input_strings=context.input_strings,
config=self.config,
extraction_example_json=context.config.extraction_example_json,
custom_extraction_process=resolve_step_param(
context.config.custom_extraction_process,
step_index,
total_steps,
custom_extraction_process=cast(
str,
resolve_step_param(
context.config.custom_extraction_process,
step_index,
total_steps,
),
),
custom_extraction_guidelines=resolve_step_param(
context.config.custom_extraction_guidelines,
step_index,
total_steps,
custom_extraction_guidelines=cast(
str,
resolve_step_param(
context.config.custom_extraction_guidelines,
step_index,
total_steps,
),
),
custom_final_checklist=resolve_step_param(
context.config.custom_final_checklist,
step_index,
total_steps,
custom_final_checklist=cast(
str,
resolve_step_param(
context.config.custom_final_checklist,
step_index,
total_steps,
),
),
custom_context=resolve_step_param(
context.config.custom_context, step_index, total_steps
custom_context=cast(
str,
resolve_step_param(
context.config.custom_context, step_index, total_steps
),
),
expected_entity_descriptions=context.config.expected_entity_descriptions,
previous_entities=previous_entities if previous_entities else None,
Expand Down Expand Up @@ -410,7 +451,7 @@ def _create_batch_requests(
self,
system_prompt: str,
user_prompt: str,
json_schema: str | None = None,
json_schema: dict[str, Any] | str | None = None,
num_revisions: int | None = None,
override_client: BaseLLMClient | None = None,
) -> list[dict]:
Expand Down
2 changes: 1 addition & 1 deletion src/extrai/core/counting_consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import logging
from typing import Any

from .extraction_config import ExtractionConfig
from ..utils.alignment_utils import align_entity_arrays, calculate_similarity
from .extraction_config import ExtractionConfig


class CountingConsensus:
Expand Down
Loading
Loading