diff --git a/.hydra_config/config.yaml b/.hydra_config/config.yaml index e7e30cb9..36c61c74 100644 --- a/.hydra_config/config.yaml +++ b/.hydra_config/config.yaml @@ -55,6 +55,13 @@ reranker: top_k: ${oc.decode:${oc.env:RERANKER_TOP_K, 10}} # Number of documents to return after reranking. Upgrade for better results if your llm has a wider context window. base_url: ${oc.env:RERANKER_BASE_URL, http://reranker:${oc.env:RERANKER_PORT, 7997}} +file_reducer: + max_group_tokens: ${oc.decode:${oc.env:FILE_REDUCER_MAX_GROUP_TOKENS, 4096}} + min_group_tokens: ${oc.decode:${oc.env:FILE_REDUCER_MIN_GROUP_TOKENS, 2048}} + target_size_tokens: ${oc.decode:${oc.env:FILE_REDUCER_TARGET_SIZE_TOKENS, 1024}} + max_rounds: ${oc.decode:${oc.env:FILE_REDUCER_MAX_ROUNDS, 3}} + min_shrink_ratio: ${oc.decode:${oc.env:FILE_REDUCER_MIN_SHRINK_RATIO, 0.1}} + map_reduce: # Number of documents to process in the initial mapping phase initial_batch_size: ${oc.decode:${oc.env:MAP_REDUCE_INITIAL_BATCH_SIZE, 10}} @@ -91,6 +98,7 @@ prompts: chunk_contextualizer: chunk_contextualizer_tmpl.txt image_describer: image_captioning_tmpl.txt spoken_style_answer: spoken_style_answer_tmpl.txt + file_reducer: file_reducer_tmpl.txt # query templates for different retriever types hyde: hyde.txt diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..c52b5186 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,292 @@ +# OpenRAG Agent Guide + +## Build, Lint, and Test Commands + +### Dependencies +```bash +# Install dependencies (uv package manager) +uv sync + +# Install dev dependencies +uv sync --group dev + +# Install lint dependencies +uv sync --group lint +``` + +### Development Server +```bash +# GPU deployment +docker compose up -d + +# CPU deployment +docker compose --profile cpu up -d + +# Rebuild and run +docker compose up --build -d +``` + +### Testing +```bash +# Run all unit tests +uv run pytest + +# Run a single test file +uv run pytest openrag/components/indexer/chunker/test_chunking.py + +# Run tests matching a pattern +uv run pytest -k "test_chunk" + +# Run with verbose output +uv run pytest -v + +# Run integration tests (requires running server) +uv run pytest -m integration + +# Run tests with coverage +uv run pytest --cov=openrag +``` + +### Linting and Formatting +```bash +# Check code style +uv run ruff check openrag/ tests/ + +# Auto-fix linting issues +uv run ruff check --fix openrag/ tests/ + +# Format code +uv run ruff format openrag/ tests/ + +# Check formatting without modifying +uv run ruff format --check openrag/ tests/ +``` + +### CI/CD +```bash +# Run API integration tests locally with act +act -j api-tests -W .github/workflows/api_tests.yml --bind +``` + +## Code Style Guidelines + +### Imports +- Use **absolute imports** from the `openrag/` directory (Python path root) +- Group imports: standard library → third-party → first-party (`openrag.*`) +- Use `from openrag.X import Y` not relative imports across packages +- Isort configuration: `known-first-party = ["openrag"]` + +```python +# Correct +from components.ray_utils import call_ray_actor_with_timeout +from utils.logger import get_logger +from config import load_config + +# Avoid +from ..ray_utils import ... # Only use within same package +``` + +### Formatting +- **Line length**: 120 characters (configured in `pyproject.toml`) +- **Target Python**: 3.12+ +- Use **double quotes** for strings +- Use **4 spaces** for indentation (no tabs) +- Follow Black-compatible formatting (Ruff format) + +### Type Hints +- Use **type hints** for function parameters and return values +- Use `|` for union types (Python 3.10+ syntax) +- Use `Optional[T]` or `T | None` for optional values +- Use `list[T]`, `dict[str, Any]` for collections + +```python +def process_file(file_id: str, partition: str | None = None) -> dict[str, Any]: + """Process a file and return metadata.""" + ... +``` + +### Naming Conventions +- **Functions/variables**: `snake_case` +- **Classes**: `PascalCase` +- **Constants**: `UPPER_CASE` +- **Private members**: `_leading_underscore` +- **Ray Actors**: `PascalCase` (e.g., `Indexer`, `TaskStateManager`) +- **Test functions**: `test_` + +### Error Handling +- Use **custom exceptions** from `openrag/utils/exceptions/` +- All exceptions inherit from `OpenRAGError` +- Include `code`, `message`, and optional `status_code` +- Use specific exception types: `VDBError`, `EmbeddingError` + +```python +from utils.exceptions import OpenRAGError, VDBError + +# Raise error with code and message +raise VDBError(message="Failed to connect", code="VDB_001", status_code=503) + +# Custom exception with extra context +raise OpenRAGError( + message="File not found", + code="FILE_NOT_FOUND", + status_code=404, + file_id=file_id +) +``` + +### Logging +- Use **Loguru** with structured logging via `get_logger()` +- Include contextual data using `.bind()` +- Never log secrets or sensitive data + +```python +from utils.logger import get_logger + +logger = get_logger() + +# Log with context +logger.bind(file_id=file_id, partition=partition).info("Processing file") + +# Error logging with exception +logger.bind(error=str(e)).error("Failed to process document") +``` + +### Async/Await +- Use `async def` for I/O operations (database, HTTP, Ray) +- Always `await` async calls +- Use `asyncio.gather()` for concurrent independent operations +- Use `call_ray_actor_with_timeout()` for Ray actor calls + +```python +from components.ray_utils import call_ray_actor_with_timeout + +# Concurrent operations +results = await asyncio.gather( + task1(), + task2(), + task3() +) + +# Ray actor with timeout +result = await call_ray_actor_with_timeout( + future=indexer.process.remote(data), + timeout=30, + task_description="Processing document" +) +``` + +### Ray Actors +- Ray Actors are initialized in `openrag/api.py` +- Access actors via `ray.get_actor(name, namespace="openrag")` +- All actor methods called with `.remote()` + +```python +import ray + +# Get actor reference +vectordb = ray.get_actor("Vectordb", namespace="openrag") +indexer = ray.get_actor("Indexer", namespace="openrag") + +# Call methods +await vectordb.async_search.remote(query=query, partition=partition) +``` + +### Configuration +- Configuration via **Hydra** with YAML files in `.hydra_config/` +- Access config via `load_config()` from `config.py` +- Environment variables override config values + +```python +from config import load_config + +config = load_config() +chunk_size = config.chunker.size +``` + +### API Patterns +- FastAPI routers in `openrag/routers/` +- Use dependency injection for shared resources +- Return `JSONResponse` for custom error responses +- Use Pydantic models for request/response validation + +```python +from fastapi import APIRouter, Depends +from pydantic import BaseModel + +router = APIRouter() + +class DocumentRequest(BaseModel): + text: str + partition: str | None = None + +@router.post("/documents") +async def create_document(req: DocumentRequest, user: User = Depends(get_current_user)): + ... +``` + +### Testing Guidelines +- Unit tests: `openrag/components/**/test_*.py` (pytest) +- Integration tests: `tests/api_tests/*.py` +- Use pytest fixtures from `conftest.py` +- Mark tests: `@pytest.mark.integration` or `@pytest.mark.unit` + +```python +import pytest + +@pytest.mark.unit +def test_chunking(): + assert result == expected + +@pytest.mark.integration +async def test_api_endpoint(): + response = await client.post("/v1/chat/completions", json={...}) + assert response.status_code == 200 +``` + +### Documentation +- Docstrings: **Google style** or **reStructuredText** +- Include type hints in docstrings if not obvious +- Document complex algorithms and business logic + +```python +def process_chunk(chunk: Chunk) -> Embedding: + """Process a document chunk and generate embedding. + + Args: + chunk: The chunk to process + + Returns: + Generated embedding vector + + Raises: + EmbeddingError: If embedding generation fails + """ + ... +``` + +## Key Files and Directories + +``` +openrag/ +├── api.py # FastAPI app entry point, Ray initialization +├── routers/ # API route handlers +├── components/ # Core components (Indexer, Vectordb, Pipeline) +│ ├── indexer/ # Document ingestion, chunking, embedding +│ ├── pipeline.py # RAG pipeline orchestration +│ └── websearch/ # Web search integration +├── utils/ # Shared utilities +│ ├── exceptions/ # Custom exception classes +│ ├── logger.py # Logging configuration +│ └── config.py # Configuration loading +├── models/ # Pydantic models +└── prompts/ # LLM prompt templates +``` + +## Important Notes + +- **Never commit secrets** - use `.env` files (not in repo) +- **Ray namespace** is always `"openrag"` for all actors +- **Milvus** is the vector database with hybrid search (dense + BM25) +- **Authentication** uses token-based auth with RBAC +- **Partition-based** multi-tenant document organization +- **OpenAI-compatible** API format for chat completions diff --git a/docs/content/docs/documentation/API.mdx b/docs/content/docs/documentation/API.mdx index 191f0266..c3550a96 100644 --- a/docs/content/docs/documentation/API.mdx +++ b/docs/content/docs/documentation/API.mdx @@ -409,6 +409,7 @@ OpenAI-compatible text completion endpoint. | `websearch` | `bool` | `false` | Augments the RAG context with live web search results. When used with a partition (`openrag-{partition}`), document and web results are combined. When used without a partition (direct LLM mode), web results are the sole context. Requires `WEBSEARCH_API_TOKEN` to be configured. See [web search configuration](/openrag/documentation/env_vars/#web-search-configuration). | | `spoken_style_answer` | `bool` | `false` | Generates a succinct spoken-style conversational answer based on the retrieved documents. | | `use_map_reduce` | `bool` | `false` | Uses a map-reduce strategy to aggregate information from multiple documents. See [map-reduce configuration](/openrag/documentation/env_vars/#map--reduce-configuration). | +| `attachments` | `list[{id: string}]` | `null` | Pins specific files by ID for retrieval, bypassing semantic search entirely. Each file's chunks are compressed by the file reducer before being sent to the LLM. See [file reducer configuration](/openrag/documentation/env_vars/#file-reducer-configuration). | | `llm_override` | `object` | `null` | Routes the request to a different LLM endpoint while still using OpenRAG's RAG pipeline (retrieval, reranking, prompt construction). Accepts: `base_url` (string), `api_key` (string), `model` (string). Any field not provided falls back to the default OpenRAG LLM configuration. | Examples: diff --git a/docs/content/docs/documentation/env_vars.md b/docs/content/docs/documentation/env_vars.md index 32534c3b..01c454cd 100644 --- a/docs/content/docs/documentation/env_vars.md +++ b/docs/content/docs/documentation/env_vars.md @@ -257,6 +257,7 @@ The RAG pipeline comes with preconfigured prompts **`./prompts/example1`**. Here | `image_captioning_tmpl.txt` | Template for generating image descriptions using the VLM | | `hyde.txt` | Hypothetical Document Embeddings (HyDE) query expansion template | | `multi_query_pmpt_tmpl.txt` | Template for generating multiple query variations | +| `file_reducer_tmpl.txt` | System prompt for the file reducer's chunk compression LLM calls | To customize prompt: 1. **Duplicate the example folder**: Copy the `example1` folder from `./prompts/` @@ -455,6 +456,21 @@ curl -X 'POST' 'http://localhost:8080/v1/chat/completions' \ ``` ::: +### File Reducer Configuration + +The file reducer compresses a file's chunks down to a size that fits within the LLM context window. It works iteratively: chunks are grouped, each group is summarized by the LLM, and the process repeats until the total content fits. Two safety mechanisms prevent it from running indefinitely: + +- **`max_rounds`** — hard cap on the number of compression iterations. +- **`min_shrink_ratio`** — if a round shrinks the content by less than this fraction, the LLM is not compressing meaningfully and the loop stops early. + +| Variable | Type | Default | Description | +|----------|------|---------|-------------| +| `FILE_REDUCER_TARGET_SIZE_TOKENS` | `int` | 1024 | Token budget for the final output. Compression rounds continue until the total content fits within this limit | +| `FILE_REDUCER_MAX_GROUP_TOKENS` | `int` | 4096 | Maximum tokens per group fed to the LLM in a single summarization call | +| `FILE_REDUCER_MIN_GROUP_TOKENS` | `int` | 2048 | Groups smaller than this threshold are passed through without calling the LLM | +| `FILE_REDUCER_MAX_ROUNDS` | `int` | 3 | Maximum number of compression rounds before stopping regardless of output size | +| `FILE_REDUCER_MIN_SHRINK_RATIO` | `float` | 0.1 | Minimum fraction of tokens that must be removed in a round to continue iterating (e.g. `0.1` = at least 10% reduction required) | + ### FastAPI & Access Control :::info By default, our API (FastAPI) uses **`uvicorn`** for deployment. One can opt in to use `Ray Serve` for scalability (see the [ray serve configuration](/openrag/documentation/env_vars/#ray-serve-configuration)) diff --git a/openrag/components/file_summarizer.py b/openrag/components/file_summarizer.py new file mode 100644 index 00000000..37325ca7 --- /dev/null +++ b/openrag/components/file_summarizer.py @@ -0,0 +1,147 @@ +"""FileReducer — iterative map-then-merge summarization.""" + +from components.prompts.prompts import FILE_REDUCER_PROMPT +from components.utils import get_llm_semaphore +from langchain_core.documents.base import Document +from langchain_openai import ChatOpenAI +from tqdm.asyncio import tqdm +from utils.logger import get_logger + +logger = get_logger() + +_IRRELEVANT = "IRRELEVANT" + + +class FileReducer: + """Summarizes a file's chunks by repeatedly grouping and summarizing + until the result fits within `max_tokens`.""" + + def __init__(self, config): + self._llm = ChatOpenAI( + base_url=config.llm.get("base_url"), + api_key=config.llm.get("api_key"), + model=config.llm.get("model"), + temperature=config.llm.get("temperature", 0.3), + timeout=config.llm.get("timeout", 60), + ) + self._max_group_tokens: int = config.file_reducer.get("max_group_tokens", 4096) + self._min_group_tokens: int = config.file_reducer.get("min_group_tokens", 2048) + self._max_rounds: int = config.file_reducer.get("max_rounds", 3) + self._min_shrink_ratio: float = config.file_reducer.get("min_shrink_ratio", 0.1) + self._target_size_tokens: int = config.file_reducer.get("target_size_tokens", 1024) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _estimate_tokens(text: str) -> int: + """Fast ~4 chars-per-token estimate.""" + return len(text) // 4 + + def _fits(self, texts: list[str]) -> bool: + """True when the joined texts are already within the output budget.""" + return self._estimate_tokens("\n\n".join(texts)) <= self._target_size_tokens + + def _group(self, texts: list[str]) -> list[list[str]]: + """Bin texts into groups that each stay under `_max_group_tokens`.""" + groups: list[list[str]] = [] + current: list[str] = [] + current_tokens = 0 + for text in texts: + tokens = self._estimate_tokens(text) + if current and current_tokens + tokens > self._max_group_tokens: + groups.append(current) + current = [text] + current_tokens = tokens + else: + current.append(text) + current_tokens += tokens + + if current: + groups.append(current) + + return groups + + async def _summarize(self, query: str, texts: list[str]) -> str: + """Summarize a group of texts; skip the LLM if the group is already small.""" + + async with get_llm_semaphore(): + try: + joined = "\n\n".join(texts) + if self._estimate_tokens(joined) <= self._min_group_tokens: + return joined + + response = await self._llm.ainvoke( + [ + {"role": "system", "content": FILE_REDUCER_PROMPT}, + {"role": "user", "content": f"user query: {query}\n\ncontent to compress:\n{joined}"}, + ] + ) + return response.content + except Exception as e: + logger.error("Error during summarization", error=str(e)) + return "\n\n".join(texts) # fall back to original to avoid None in texts + + # ------------------------------------------------------------------ + # Main entry point + # ------------------------------------------------------------------ + + async def run(self, query: str, chunks: list[Document]) -> Document: + """Summarize *chunks* by grouping and merging until the result fits.""" + + # Normalise to plain strings, preserve first chunk's metadata + first_metadata = chunks[0].metadata if isinstance(chunks[0], Document) else {} + filename = first_metadata.get("filename") + log = logger.bind(filename=filename) + + texts: list[str] = [c.page_content if isinstance(c, Document) else c for c in chunks] + tag = f"[{filename}] " if filename else "" + rounds = 0 + + while not self._fits(texts): + if rounds >= self._max_rounds: + log.warning("FileReducer hit max_rounds cap — stopping early", rounds=rounds) + break + + tokens_before = self._estimate_tokens("\n\n".join(texts)) + groups = self._group(texts) + texts = list( + await tqdm.gather( + *[self._summarize(query, g) for g in groups], + desc=f"{tag}merge (round {rounds + 1})", + ) + ) + + # Filter chunks the LLM deemed irrelevant (keep at least one to avoid empty output) + relevant = [t for t in texts if t.strip() != _IRRELEVANT] + if relevant: + texts = relevant + + tokens_after = self._estimate_tokens("\n\n".join(texts)) + shrink = (tokens_before - tokens_after) / max(tokens_before, 1) + + rounds += 1 + log.debug("Merge round complete", round=rounds, shrink_pct=round(shrink * 100, 1)) + + if shrink < self._min_shrink_ratio: + log.warning( + "FileReducer not converging (shrink below threshold) — stopping early", + rounds=rounds, + shrink_pct=round(shrink * 100, 1), + ) + break + + content = texts[0] if len(texts) == 1 else "\n\n".join(texts) + metadata = { + **first_metadata, + "_summarized": True, + "_original_chunk_count": len(chunks), + "_rounds": rounds, + } + log.debug("FileReducer done", estimated_tokens=self._estimate_tokens(content), rounds=rounds) + return Document(page_content=f"{filename}\n\n{content}", metadata=metadata) + + async def reduce_all(self, query: str, docs_l: list[Document]) -> list[Document]: + tasks = [self.run(query, chunks) for chunks in docs_l] + return await tqdm.gather(*tasks, desc="Reducing files") diff --git a/openrag/components/indexer/vectordb/test_file_attachments.py b/openrag/components/indexer/vectordb/test_file_attachments.py new file mode 100644 index 00000000..4cc25cbc --- /dev/null +++ b/openrag/components/indexer/vectordb/test_file_attachments.py @@ -0,0 +1,121 @@ +"""Tests for file attachment retrieval logic.""" + +import pytest + + +class TestAttachmentFiltering: + """Test attachment filtering logic in pipeline.""" + + def test_extract_file_ids_from_attachments(self): + """Test extracting file IDs from attachments list.""" + from models.openai import Attachment + + # Valid attachments only - empty/missing ids are filtered before validation in pipeline + attachments_raw = [ + {"id": "file-123"}, + {"id": "file-456"}, + {"id": "file-789", "type": "file"}, + ] + + # Validate and extract file_ids (like pipeline does) + attachments = [Attachment.model_validate(att) for att in attachments_raw if isinstance(att, dict)] + file_ids = [att.id for att in attachments if att.id] + + assert len(file_ids) == 3 + assert file_ids == ["file-123", "file-456", "file-789"] + + def test_extract_file_ids_empty_list(self): + """Test extracting file IDs from empty attachments list.""" + attachments_raw = [] + + if attachments_raw: + from models.openai import Attachment + + attachments = [Attachment.model_validate(att) for att in attachments_raw if isinstance(att, dict)] + file_ids = [att.id for att in attachments if att.id] + else: + file_ids = [] + + assert file_ids == [] + + def test_extract_file_ids_none(self): + """Test extracting file IDs when attachments is None.""" + attachments_raw = None + + if attachments_raw: + from models.openai import Attachment + + attachments = [Attachment.model_validate(att) for att in attachments_raw if isinstance(att, dict)] + file_ids = [att.id for att in attachments if att.id] + else: + file_ids = [] + + assert file_ids == [] + + +class TestFilterExpression: + """Test filter expression building for file queries.""" + + def test_filter_expression_with_specific_partitions(self): + """Test filter expression for specific partition list.""" + partition = ["partition1", "partition2"] + file_id = "file-123" + + # Build filter expression like _retrieve_file_chunks does + expr_parts = [] + if partition != ["all"]: + expr_parts.append(f"partition in {partition}") + expr_parts.append(f'file_id == "{file_id}"') + filter_expr = " and ".join(expr_parts) if expr_parts else "" + + # Check that partition and file_id are in the expression + assert "partition in" in filter_expr + assert "partition1" in filter_expr + assert "partition2" in filter_expr + assert 'file_id == "file-123"' in filter_expr + assert " and " in filter_expr + + def test_filter_expression_with_all_partitions(self): + """Test filter expression for ['all'] partitions.""" + partition = ["all"] + file_id = "file-123" + + # Build filter expression like _retrieve_file_chunks does + expr_parts = [] + if partition != ["all"]: + expr_parts.append(f"partition in {partition}") + expr_parts.append(f'file_id == "{file_id}"') + filter_expr = " and ".join(expr_parts) if expr_parts else "" + + assert "partition in" not in filter_expr + assert 'file_id == "file-123"' in filter_expr + assert " and " in filter_expr + + def test_filter_expression_with_all_partitions(self): + """Test filter expression for ['all'] partitions.""" + partition = ["all"] + file_id = "file-123" + + # Build filter expression like _retrieve_file_chunks does + expr_parts = [] + if partition != ["all"]: + expr_parts.append(f"partition in {partition}") + expr_parts.append(f'file_id == "{file_id}"') + filter_expr = " and ".join(expr_parts) if expr_parts else "" + + assert "partition in" not in filter_expr + assert 'file_id == "file-123"' in filter_expr + + def test_extract_file_ids_none(self): + """Test extracting file IDs when attachments is None.""" + attachments_raw = None + + if attachments_raw: + from models.openai import Attachment + + attachments = [Attachment.model_validate(att) for att in attachments_raw if isinstance(att, dict)] + file_ids = [att.id for att in attachments if att.id] + else: + file_ids = [] + + assert file_ids == [] diff --git a/openrag/components/indexer/vectordb/vectordb.py b/openrag/components/indexer/vectordb/vectordb.py index 580fa5ff..7ce32643 100644 --- a/openrag/components/indexer/vectordb/vectordb.py +++ b/openrag/components/indexer/vectordb/vectordb.py @@ -101,6 +101,12 @@ async def list_all_chunk(self, partition: str, include_embedding: bool = True) - async def get_file_chunks(self, file_id: str, partition: str, include_id: bool = False, limit: int = 2000): pass + @abstractmethod + async def get_chunks_by_file_ids( + self, file_ids: list[str], partition: list[str] | None, include_id: bool = True + ) -> list[list[Document]]: + pass + @abstractmethod async def get_chunk_by_id(self, chunk_id: str): pass @@ -722,6 +728,144 @@ async def get_file_chunks(self, file_id: str, partition: str, include_id: bool = file_id=file_id, ) + async def _retrieve_file_chunks( + self, file_id: str, partition: list[str] | None, include_id: bool = True + ) -> list[Document]: + """Helper to retrieve chunks for a single file_id across one or more partitions.""" + if not partition: + self.logger.warning("No partition provided for file_id retrieval", file_id=file_id) + return [] + + log = self.logger.bind(file_id=file_id, partition=partition) + + if partition != ["all"]: + file_found = False + + # Check if file exists in any of the specified partitions + for partition_name in partition: + if self.file_exists(file_id=file_id, partition=partition_name): + file_found = True + break + + if not file_found: + log.warning("File not found in specified partitions", file_id=file_id) + return [] + + # Build filter expression like async_search does + expr_parts = [] + if partition != ["all"]: + expr_parts.append(f"partition in {partition}") + + # Always filter by file_id + expr_parts.append(f'file_id == "{file_id}"') + + # Join all parts with " and " only if there are multiple conditions + filter_expr = " and ".join(expr_parts) if expr_parts else "" + + try: + excluded_keys = ["text", "vector", "_id"] if not include_id else ["text", "vector"] + + results = [] + iterator = self._client.query_iterator( + collection_name=self.collection_name, + filter=filter_expr, + limit=2000, + batch_size=min(2000, 16000), + output_fields=["*"], + ) + try: + while True: + batch = iterator.next() + if not batch: + break + results.extend(batch) + finally: + iterator.close() + + docs = [ + Document( + page_content=res["text"], + metadata={key: value for key, value in res.items() if key not in excluded_keys}, + ) + for res in results + ] + log.debug(f"Retrieved {len(results)} chunks for file_id", count=len(results)) + return docs + + except MilvusException as e: + log.exception(f"Couldn't get file chunks for file_id {file_id}", error=str(e)) + raise VDBSearchError( + f"Couldn't get file chunks for file_id {file_id}: {e!s}", + collection_name=self.collection_name, + partition=str(partition), + file_id=file_id, + ) + except VDBError: + raise + except Exception as e: + log.exception("Unexpected error while getting file chunks", error=str(e)) + raise VDBSearchError( + f"Unexpected error while getting file chunks {file_id}: {e!s}", + collection_name=self.collection_name, + partition=str(partition), + file_id=file_id, + ) + + async def get_chunks_by_file_ids( + self, file_ids: list[str], partition: list[str] | None, include_id: bool = True + ) -> list[list[Document]]: + """Retrieve chunks for given file_ids in parallel, grouped and ordered by file_id. + + Args: + file_ids: List of file IDs to retrieve chunks for + partition: Partition(s) to search in - can be ["all"] for admin or list of partition names + include_id: Whether to include file_id in chunk metadata + + Returns: + List of chunk lists, one per file_id, maintaining input order. + Empty lists are excluded. Non-existent file_ids are silently ignored. + + Raises: + VDBError: If vector database operation fails catastrophically + """ + log = self.logger.bind(file_ids_count=len(file_ids), partition=partition) + + if not file_ids: + log.debug("No file_ids provided, returning empty list") + return [] + + # Handle partition validation + if partition and len(partition) > 1: + log.debug(f"Searching across {len(partition)} partitions", partitions=partition) + + # Parallel retrieval: create tasks for all file_ids + tasks = [ + self._retrieve_file_chunks(file_id=file_id, partition=partition, include_id=include_id) + for file_id in file_ids + ] + + # Execute all retrievals concurrently + try: + results = await asyncio.gather(*tasks) + except MilvusException as e: + log.error("Milvus error during parallel file retrieval", error=str(e)) + raise VDBSearchError( + message="Failed to retrieve chunks for file_ids", + code="VDB_FILE_RETRIEVE_ERROR", + status_code=503, + collection_name=self.collection_name, + ) from e + + chunks_by_file = [] + for file_id, chunks in zip(file_ids, results): + if chunks: + chunks_by_file.append(chunks) + log.debug(f"Retrieved {len(chunks)} chunks for file_id", file_id=file_id) + else: + log.warning("No chunks found for file_id", file_id=file_id) + + return chunks_by_file + async def get_chunk_by_id(self, chunk_id: str): """ Retrieve a chunk by its ID. diff --git a/openrag/components/pipeline.py b/openrag/components/pipeline.py index fffd433c..42b3e90d 100644 --- a/openrag/components/pipeline.py +++ b/openrag/components/pipeline.py @@ -15,9 +15,11 @@ from config import load_config from langchain_core.documents.base import Document from langchain_openai import ChatOpenAI +from models.openai import Attachment from pydantic import BaseModel, Field from utils.logger import get_logger +from .file_summarizer import FileReducer from .llm import LLM from .map_reduce import RAGMapReduce from .reranker import Reranker @@ -137,6 +139,9 @@ def __init__(self) -> None: # map reduce self.map_reduce: RAGMapReduce = RAGMapReduce(config=config) + # file reducer + self.file_reducer = FileReducer(config) + # Web search self.web_search_service = WebSearchFactory.create_service(config) if self.web_search_service.provider: @@ -187,12 +192,19 @@ async def _prepare_for_chat_completion(self, partition: list[str] | None, payloa messages = payload["messages"] messages = messages[-self.chat_history_depth :] # limit history depth - # 1. get the query - queries: SearchQueries = await self.generate_query(messages) - logger.debug("Prepared query for chat completion", queries=str(queries)) - metadata = payload.get("metadata") or {} + # Extract and validate attachments from metadata + attachments_raw = metadata.get("attachments") + file_ids: list[str] = [] + if attachments_raw: + try: + attachments = [Attachment.model_validate(att) for att in attachments_raw if isinstance(att, dict)] + file_ids = [att.id for att in attachments if att.id] + except Exception as e: + logger.warning("Failed to validate attachments", error=str(e)) + file_ids = [] + use_map_reduce = metadata.get("use_map_reduce", False) spoken_style_answer = metadata.get("spoken_style_answer", False) use_websearch = metadata.get("websearch", False) @@ -204,73 +216,109 @@ async def _prepare_for_chat_completion(self, partition: list[str] | None, payloa spoken_style_answer=spoken_style_answer, use_websearch=use_websearch, workspace=workspace, + file_ids_count=len(file_ids), ) - # 2. get docs and/or web results concurrently - top_k = config.map_reduce["max_total_documents"] if use_map_reduce else None - if workspace: - vectordb = ray.get_actor("Vectordb", namespace="openrag") - ws = await call_ray_actor_with_timeout( - vectordb.get_workspace.remote(workspace), - timeout=VECTORDB_TIMEOUT, - task_description=f"get_workspace({workspace})", - ) - if not ws or ("all" not in partition and ws["partition_name"] not in partition): - logger.warning( - "Workspace not found in partition(s) — ignoring workspace filter", - workspace=workspace, - partition=partition, - ) - workspace = None + # FILE_ID RETRIEVAL MODE (skip semantic search) + if file_ids: + log = logger.bind(file_ids=file_ids, mode="file_based_retrieval") + log.info("File-based retrieval mode enabled") - filter_params = {"workspace_id": workspace} if workspace else None - - if partition is not None and use_websearch: - # Run one retrieval and one web search per sub-query, all concurrently (Option C). - # Web results from different sub-queries are deduplicated by URL, preserving order. - rag_tasks = [ - self.retriever_pipeline.retrieve_docs( - partition=partition, query=q, top_k=top_k, filter_params=filter_params + # Retrieve chunks directly by file_id (parallel retrieval) + vectordb = ray.get_actor("Vectordb", namespace="openrag") + try: + docs_by_file: list[list[Document]] = await call_ray_actor_with_timeout( + vectordb.get_chunks_by_file_ids.remote(file_ids=file_ids, partition=partition), + timeout=VECTORDB_TIMEOUT, + task_description=f"get_chunks_by_file_ids({len(file_ids)} files)", ) - for q in queries.query_list - ] - web_tasks = [self.web_search_service.search(q) for q in queries.query_list] - all_results = await asyncio.gather(*rag_tasks, *web_tasks) - n = len(queries.query_list) - raw_doc_lists = list(all_results[:n]) - raw_web_lists = list(all_results[n:]) - docs = self.retriever_pipeline.reranker.rrf_reranking(doc_lists=raw_doc_lists) - if top_k is not None: - docs = docs[:top_k] - # Deduplicate web results by URL, preserving first-seen order - seen_urls: set[str] = set() + log.debug(f"Retrieved {sum(len(d) for d in docs_by_file)} chunks from {len(file_ids)} files") + except TimeoutError as e: + # Timeout handling - log and return empty docs + log.error("Timeout retrieving chunks for file_ids", timeout=VECTORDB_TIMEOUT, error=str(e)) + docs_by_file = [] + + # Create dummy queries for logging consistency + queries = SearchQueries(query_list=[messages[-1]["content"]]) web_results = [] - for result in (r for web_list in raw_web_lists for r in web_list): - if result.url not in seen_urls: - seen_urls.add(result.url) - web_results.append(result) - elif partition is not None: - docs = await self.retriever_pipeline.get_relevant_docs( - partition=partition, search_queries=queries, top_k=top_k, filter_params=filter_params - ) - web_results = [] - else: - # Web-only mode (partition is None): no RAG retrieval. - # Run one web search per sub-query concurrently and deduplicate by URL. - raw_web_lists = await asyncio.gather(*[self.web_search_service.search(q) for q in queries.query_list]) - seen_urls = set() - web_results = [] - for result in (r for web_list in raw_web_lists for r in web_list): - if result.url not in seen_urls: - seen_urls.add(result.url) - web_results.append(result) - docs = [] - # Web-only with no results: fall back to plain direct LLM mode - if not docs and not web_results and partition is None: - return payload, [], [] + # Apply file reduction per file, then flatten + if docs_by_file: + docs = await self.file_reducer.reduce_all(query=queries.query_list[0], docs_l=docs_by_file) + else: + docs = [] - if use_map_reduce and docs: + # NORMAL SEMANTIC SEARCH MODE + else: + # 1. get the query + queries: SearchQueries = await self.generate_query(messages) + logger.debug("Prepared query for chat completion", queries=str(queries)) + + # 2. get docs and/or web results concurrently + top_k = config.map_reduce["max_total_documents"] if use_map_reduce else None + if workspace: + vectordb = ray.get_actor("Vectordb", namespace="openrag") + ws = await call_ray_actor_with_timeout( + vectordb.get_workspace.remote(workspace), + timeout=VECTORDB_TIMEOUT, + task_description=f"get_workspace({workspace})", + ) + if not ws or ("all" not in partition and ws["partition_name"] not in partition): + logger.warning( + "Workspace not found in partition(s) — ignoring workspace filter", + workspace=workspace, + partition=partition, + ) + workspace = None + + filter_params = {"workspace_id": workspace} if workspace else None + + if partition is not None and use_websearch: + # Run one retrieval and one web search per sub-query, all concurrently (Option C). + # Web results from different sub-queries are deduplicated by URL, preserving order. + rag_tasks = [ + self.retriever_pipeline.retrieve_docs( + partition=partition, query=q, top_k=top_k, filter_params=filter_params + ) + for q in queries.query_list + ] + web_tasks = [self.web_search_service.search(q) for q in queries.query_list] + all_results = await asyncio.gather(*rag_tasks, *web_tasks) + n = len(queries.query_list) + raw_doc_lists = list(all_results[:n]) + raw_web_lists = list(all_results[n:]) + docs = self.retriever_pipeline.reranker.rrf_reranking(doc_lists=raw_doc_lists) + if top_k is not None: + docs = docs[:top_k] + # Deduplicate web results by URL, preserving first-seen order + seen_urls: set[str] = set() + web_results = [] + for result in (r for web_list in raw_web_lists for r in web_list): + if result.url not in seen_urls: + seen_urls.add(result.url) + web_results.append(result) + elif partition is not None: + docs = await self.retriever_pipeline.get_relevant_docs( + partition=partition, search_queries=queries, top_k=top_k, filter_params=filter_params + ) + web_results = [] + else: + # Web-only mode (partition is None): no RAG retrieval. + # Run one web search per sub-query concurrently and deduplicate by URL. + raw_web_lists = await asyncio.gather(*[self.web_search_service.search(q) for q in queries.query_list]) + seen_urls = set() + web_results = [] + for result in (r for web_list in raw_web_lists for r in web_list): + if result.url not in seen_urls: + seen_urls.add(result.url) + web_results.append(result) + docs = [] + + # Web-only with no results: fall back to plain direct LLM mode + if not docs and not web_results and partition is None: + return payload, [], [] + + if not file_ids and use_map_reduce and docs: docs = await self.map_reduce.map(query=" ".join(queries.query_list), chunks=docs) # 3. Format web results first to know actual token usage, then allocate remaining budget to RAG diff --git a/openrag/components/prompts/prompts.py b/openrag/components/prompts/prompts.py index e7cf0ec6..855a840d 100644 --- a/openrag/components/prompts/prompts.py +++ b/openrag/components/prompts/prompts.py @@ -39,3 +39,6 @@ def load_prompt( # Short answer prompt SPOKEN_STYLE_ANSWER_PROMPT = load_prompt("spoken_style_answer") + +# File reducer prompt +FILE_REDUCER_PROMPT = load_prompt("file_reducer") diff --git a/openrag/models/openai.py b/openrag/models/openai.py index 323e44d6..063d3e2a 100644 --- a/openrag/models/openai.py +++ b/openrag/models/openai.py @@ -7,7 +7,12 @@ default_max_tokens = int(config.llm_context.get("max_output_tokens", 1024)) -# Classes pour la compatibilité OpenAI +class Attachment(BaseModel): + """Represents a file attachment for RAG retrieval.""" + + id: str = Field(..., min_length=1, description="File ID") + + class OpenAIMessage(BaseModel): """Modèle représentant un message dans l'API OpenAI.""" @@ -31,8 +36,9 @@ class OpenAIChatCompletionRequest(BaseModel): "spoken_style_answer": False, "websearch": False, "llm_override": None, + "attachments": None, }, - description="Extra custom parameters. Supports 'llm_override' object with optional 'base_url', 'api_key', and 'model' to override the downstream LLM endpoint.", + description="Extra custom parameters. Supports 'attachments' for file-based retrieval with automatic file reduction, 'use_map_reduce' for semantic search summarization.", ) diff --git a/openrag/models/test_openai.py b/openrag/models/test_openai.py new file mode 100644 index 00000000..383a5e88 --- /dev/null +++ b/openrag/models/test_openai.py @@ -0,0 +1,89 @@ +"""Tests for OpenAI-compatible models.""" + +import pytest +from pydantic import ValidationError + +from models.openai import Attachment, MetadataDict + + +class TestAttachment: + """Test Attachment model validation.""" + + def test_attachment_with_required_id(self): + """Test attachment with only required id field.""" + attachment = Attachment(id="file-123") + assert attachment.id == "file-123" + assert attachment.type is None + assert attachment.priority is None + + def test_attachment_with_all_fields(self): + """Test attachment with all fields.""" + attachment = Attachment(id="file-123", type="file", priority=1) + assert attachment.id == "file-123" + assert attachment.type == "file" + assert attachment.priority == 1 + + def test_attachment_empty_id_raises_error(self): + """Test that empty id raises validation error.""" + with pytest.raises(ValidationError) as exc_info: + Attachment(id="") + error_str = str(exc_info.value).lower() + assert "min_length" in error_str or "at least 1 character" in error_str or "string_too_short" in error_str + + def test_attachment_missing_id_raises_error(self): + """Test that missing id raises validation error.""" + with pytest.raises(ValidationError): + Attachment() # type: ignore + + def test_attachment_invalid_priority(self): + """Test that negative priority raises validation error.""" + with pytest.raises(ValidationError): + Attachment(id="file-123", priority=-1) + + def test_attachment_invalid_type(self): + """Test that invalid type raises validation error.""" + with pytest.raises(ValidationError): + Attachment(id="file-123", type="invalid") # type: ignore + + def test_attachment_extra_fields_ignored(self): + """Test that extra fields are ignored (forward compatibility).""" + attachment = Attachment(id="file-123", extra_field="should_be_ignored") # type: ignore + assert attachment.id == "file-123" + # Extra fields should not be accessible + assert not hasattr(attachment, "extra_field") + + +class TestMetadataDict: + """Test MetadataDict TypedDict usage.""" + + def test_metadata_dict_empty(self): + """Test empty metadata dict.""" + metadata: MetadataDict = {} + assert metadata == {} + + def test_metadata_dict_with_attachments(self): + """Test metadata dict with attachments.""" + metadata: MetadataDict = {"attachments": [{"id": "file-123"}, {"id": "file-456"}]} + assert len(metadata["attachments"]) == 2 + + def test_metadata_dict_with_all_fields(self): + """Test metadata dict with all known fields.""" + metadata: MetadataDict = { + "use_map_reduce": True, + "spoken_style_answer": False, + "websearch": True, + "llm_override": {"model": "custom-model"}, + "attachments": [{"id": "file-123"}], + } + assert metadata["use_map_reduce"] is True + assert metadata["websearch"] is True + assert metadata["attachments"] is not None + + def test_metadata_dict_with_unknown_field(self): + """Test that unknown fields are allowed (total=False).""" + metadata: MetadataDict = { + "use_map_reduce": True, + "unknown_field": "value", # type: ignore + } + assert metadata["use_map_reduce"] is True + assert metadata.get("unknown_field") == "value" diff --git a/prompts/example1/file_reducer_tmpl.txt b/prompts/example1/file_reducer_tmpl.txt new file mode 100644 index 00000000..e22b9f1d --- /dev/null +++ b/prompts/example1/file_reducer_tmpl.txt @@ -0,0 +1,14 @@ +You are an AI assistant specialized in aggressive yet lossless compression of text relative to a user query. + +Your task: +1. Identify every fact, figure, date, name, and decision in the text that is relevant to the query +2. Discard all filler, repetition, preamble, and tangential content +3. Rewrite the retained information as dense, standalone sentences — no prose padding + +Target: reduce the text to roughly 60% of its original length while retaining 100% of query-relevant information. + +Rules: +- Keep proper nouns, numbers, dates, and technical terms verbatim +- Merge redundant statements into one +- Preserve logical order so the output stays coherent +- If the text contains no relevant information, reply exactly: "IRRELEVANT" diff --git a/tests/api_tests/test_openai_compat.py b/tests/api_tests/test_openai_compat.py index b0446da3..22cc735b 100644 --- a/tests/api_tests/test_openai_compat.py +++ b/tests/api_tests/test_openai_compat.py @@ -471,3 +471,159 @@ def test_user_models_list_only_shows_accessible( # Should NOT see partition2 assert f"openrag-{partition2}" not in model_ids + + +class TestFileAttachments: + """Test file attachments feature in chat completions. + + These tests verify that the attachments parameter in metadata + correctly triggers file-based retrieval instead of semantic search. + """ + + def test_chat_with_empty_attachments(self, api_client): + """Test chat with empty attachments list - should work normally.""" + response = api_client.post( + "/v1/chat/completions", + json={ + "model": "openrag-all", + "messages": [{"role": "user", "content": "Hello"}], + "metadata": {"attachments": []}, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + + def test_chat_with_valid_attachments_format(self, api_client): + """Test chat with valid attachments format - returns 200 even if files don't exist.""" + response = api_client.post( + "/v1/chat/completions", + json={ + "model": "openrag-all", + "messages": [{"role": "user", "content": "Tell me about this file"}], + "metadata": { + "attachments": [ + {"id": "036e0ba3-201c-4411-84f9-5b0a3b6974b7"}, + {"id": "file-123"}, + ] + }, + }, + ) + # Returns 200 - empty results for non-existent files are handled gracefully + assert response.status_code == 200 + data = response.json() + assert "choices" in data + + def test_chat_with_attachments_missing_id(self, api_client): + """Test chat with attachments missing id field - invalid attachments are skipped.""" + response = api_client.post( + "/v1/chat/completions", + json={ + "model": "openrag-all", + "messages": [{"role": "user", "content": "Hello"}], + "metadata": { + "attachments": [ + {"id": "file-123"}, + {"type": "file"}, # Missing id + {"id": "file-456"}, + ] + }, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + + def test_chat_with_attachments_empty_id(self, api_client): + """Test chat with attachments with empty id - empty ids are skipped.""" + response = api_client.post( + "/v1/chat/completions", + json={ + "model": "openrag-all", + "messages": [{"role": "user", "content": "Hello"}], + "metadata": { + "attachments": [ + {"id": "file-123"}, + {"id": ""}, # Empty id + {"id": "file-456"}, + ] + }, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + + def test_chat_with_attachments_extra_fields(self, api_client): + """Test chat with attachments containing extra fields - extra fields are ignored.""" + response = api_client.post( + "/v1/chat/completions", + json={ + "model": "openrag-all", + "messages": [{"role": "user", "content": "Hello"}], + "metadata": { + "attachments": [ + { + "id": "file-123", + "type": "file", + "priority": 1, + "custom_field": "ignored", + } + ] + }, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + + def test_chat_with_null_attachments(self, api_client): + """Test chat with null attachments - should work normally.""" + response = api_client.post( + "/v1/chat/completions", + json={ + "model": "openrag-all", + "messages": [{"role": "user", "content": "Hello"}], + "metadata": {"attachments": None}, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + + def test_chat_with_single_attachment(self, api_client): + """Test chat with single attachment.""" + response = api_client.post( + "/v1/chat/completions", + json={ + "model": "openrag-all", + "messages": [{"role": "user", "content": "Tell me about this file"}], + "metadata": { + "attachments": [ + {"id": "single-file-id"}, + ] + }, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + + def test_chat_with_attachments_and_websearch(self, api_client): + """Test chat with both attachments and websearch enabled.""" + response = api_client.post( + "/v1/chat/completions", + json={ + "model": "openrag-all", + "messages": [{"role": "user", "content": "Tell me about this file"}], + "metadata": { + "attachments": [{"id": "file-123"}], + "websearch": True, + }, + }, + ) + # When attachments are provided, file-based retrieval takes precedence + # Web search may still run depending on implementation + assert response.status_code == 200 + data = response.json() + assert "choices" in data