# MAIT-205: Replace ROAT Filter Function with OWUI Workspace Tool (#48)
Base branch: `main`
```diff
@@ -111,8 +111,22 @@ of the `.env` file. The default ones are:
 If you did not change the `ENABLE_OPENAI_API` you will also have LLM provider
 pre-configured with the values you have in the `.env` including the default chat model

-The filter function that's responsible for the RAG service communication will also be
-automatically provisioned and enabled globally. You can change these settings at the Admin panel
+The filter function that was previously responsible for RAG service communication has been replaced
+by the **ROAT Knowledge Base Search** workspace tool (`tools/roat_retrieval.py`). Unlike the filter
+(which queried ROAT on every user message), the tool lets the LLM decide when retrieval is needed,
+improving the natural conversation experience.
+
+### Activating the ROAT Tool
+
+After first boot, manually register the tool:
```
**Reviewer comment (Member):**

> Moving forward, we're going to remove the code that automatically enables the Filter function and replace it with code that automatically enables the Tool instead. So these instructions on manual activation should be unnecessary.
```diff
+1. **Import the tool**: Admin Panel → Workspace → Tools → click `+` → paste the contents of `tools/roat_retrieval.py` → Save.
+2. **Configure Valves**: set `rag_service_url` to your ROAT query endpoint (e.g. `http://api:8000/query`) and `rag_service_api_key` if required.
+3. **Enable on a model**: Admin Panel → Models → select a model → check the **ROAT Knowledge Base Search** checkbox under Tools → Save.
+4. **Set function calling mode**: on the same model page, open **Advanced Parameters** → set **Function Calling** to `Native`.
+5. **Per-chat activation**: when starting a new chat, click `+` at the bottom and enable the tool checkbox.
+
+> **Note**: the legacy filter (`functions/function.py`) is shipped disabled. Do not re-enable it while the tool is active — doing so will result in double context injection.

 ## Connectors configuration
```
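Once `rag_service_url` is set, the tool POSTs a small JSON payload to the ROAT query endpoint. The sketch below mirrors the request-building logic from `_call_rag_service` in this PR, but without issuing a network call; the URL, API key, and query string are placeholder values, not configuration from this repo.

```python
# Builds the (url, payload, headers) triple the ROAT tool sends.
# Mirrors _call_rag_service from tools/roat_retrieval.py; values are examples.
def build_roat_request(url: str, api_key: str, query: str, top_k: int = 20) -> tuple:
    payload = {"query": query, "top_k": top_k, "metadata_filters": {}}
    headers = {"Content-Type": "application/json"}
    if api_key:  # Authorization header is only added when a key is configured
        headers["Authorization"] = f"Bearer {api_key}"
    return url.strip(), payload, headers

url, payload, headers = build_roat_request("http://api:8000/query", "", "vacation policy", 5)
print(payload)  # → {'query': 'vacation policy', 'top_k': 5, 'metadata_filters': {}}
```

In the real tool this triple is passed to `requests.post(url, json=payload, headers=headers, timeout=...)`.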
```diff
@@ -8,18 +8,20 @@
 requirements: requests
 """

+import logging
+import os
+import re
+from collections.abc import Awaitable, Callable
+
 import requests
-import logging
-from typing import List, Optional, Callable, Awaitable
 from pydantic import BaseModel

 log = logging.getLogger(__name__)
 log.setLevel(logging.INFO)

-# Read configuration from environment variables
-DEFAULT_ENABLED = os.getenv("ENABLE_CUSTOM_RAG_SERVICE", "true").lower() == "true"
+# Disabled by default — retrieval has moved to the ROAT Knowledge Base Search workspace tool (tools/roat_retrieval.py)
+DEFAULT_ENABLED = os.getenv("ENABLE_CUSTOM_RAG_SERVICE", "false").lower() == "true"
```
**Reviewer comment (Member):**

> We can keep other improvements made to this function code, but let's please revert the change that changes the DEFAULT_ENABLED. The state of the filter function is to be controlled via the OWUI interface, not via the Valves' defaults.
```diff
 DEFAULT_RAG_URL = os.getenv("CUSTOM_RAG_SERVICE_URL", "")
 DEFAULT_API_KEY = os.getenv("CUSTOM_RAG_SERVICE_API_KEY", "")
 DEFAULT_TIMEOUT = int(os.getenv("CUSTOM_RAG_SERVICE_TIMEOUT", "30"))

@@ -29,7 +31,7 @@ class Filter:
     class Valves(BaseModel):
         """Configuration valves for the filter"""

-        pipelines: List[str] = ["*"]  # Apply to all pipelines
+        pipelines: list[str] = ["*"]  # Apply to all pipelines
         priority: int = 0

         # Custom RAG Service Configuration (defaults from environment variables)

@@ -39,7 +41,7 @@ class Valves(BaseModel):
         rag_service_timeout: int = DEFAULT_TIMEOUT
         top_k: int = 5

-        # Context injection settings
+        # Context injection settings
         inject_context: bool = True
         context_template: str = """Based on the following retrieved context, please answer the user's question.

@@ -57,7 +59,7 @@ def __init__(self):
         self.valves = self.Valves()

     async def on_startup(self):
-        log.info(f"Pipeline loaded")
+        log.info("Pipeline loaded")
         log.info(f"Enabled: {self.valves.enabled}")
         log.info(f"URL: {self.valves.rag_service_url}")

@@ -122,10 +124,7 @@ def parse_raw_chunk(self, raw_text: str) -> dict:
         # Try to parse "Score: X.XX | Text: content" format
         match = re.match(r"Score:\s*([\d.]+)\s*\|\s*Text:\s*(.*)", raw_text, re.DOTALL)
         if match:
-            return {
-                "score": float(match.group(1)),
-                "text": match.group(2).strip()
-            }
+            return {"score": float(match.group(1)), "text": match.group(2).strip()}
         # If format doesn't match, return the whole text
         return {"score": 0.0, "text": raw_text.strip()}

@@ -134,12 +133,7 @@ def parse_raw_chunk(self, raw_text: str) -> dict:
         return {"score": 0.0, "text": raw_text.strip()}

     def get_filename_from_extras(self, extras: dict) -> str:
-        return (
-            extras.get("key")
-            or extras.get("filename")
-            or extras.get("name")
-            or None
-        )
+        return extras.get("key") or extras.get("filename") or extras.get("name") or None

     def format_context_and_sources(self, rag_result: dict, query: str) -> tuple:
         references = rag_result.get("references", [])

@@ -173,12 +167,7 @@ def format_context_and_sources(self, rag_result: dict, query: str) -> tuple:
         filename = self.get_filename_from_extras(extras)

-        # Extract source information from different possible locations
-        source_name = (
-            ref.get("title")
-            or ref.get("source_name")
-            or filename
-            or f"Source {i+1}"
-        )
+        source_name = ref.get("title") or ref.get("source_name") or filename or f"Source {i + 1}"

         # Strip internal storage/ingestion fields that are not meaningful to the LLM
         # (e.g. checksums, version numbers, and low-level format hints)

@@ -227,19 +216,16 @@ def format_context_and_sources(self, rag_result: dict, query: str) -> tuple:
         return "", []

         # Use template to format final context
-        formatted_context = self.valves.context_template.format(
-            context=context,
-            query=query
-        )
+        formatted_context = self.valves.context_template.format(context=context, query=query)
         log.info(f"Formatted context with {len(sources)} sources, length: {len(formatted_context)} chars")

         return formatted_context, sources

     async def inlet(
         self,
         body: dict,
-        __user__: Optional[dict] = None,
-        __event_emitter__: Optional[Callable[[dict], Awaitable[None]]] = None,
+        __user__: dict | None = None,
+        __event_emitter__: Callable[[dict], Awaitable[None]] | None = None,
     ) -> dict:
         """
         Inlet filter: Process the request before it goes to the LLM

@@ -298,7 +284,7 @@ async def inlet(
         context, sources = self.format_context_and_sources(rag_result, query)

         if context:
-            log.info(f"Injecting context into messages")
+            log.info("Injecting context into messages")

             # Inject context as a system message
             context_msg = {"role": "system", "content": context}

@@ -319,8 +305,8 @@ async def inlet(
     async def outlet(
         self,
         body: dict,
-        __user__: Optional[dict] = None,
-        __event_emitter__: Optional[Callable[[dict], Awaitable[None]]] = None,
+        __user__: dict | None = None,
+        __event_emitter__: Callable[[dict], Awaitable[None]] | None = None,
     ) -> dict:
         """
         Outlet filter: Process the response after the LLM generates it

@@ -333,4 +319,4 @@ async def outlet(
         Returns:
             dict: Response body (unmodified)
         """
-        return body
+        return body
```
```diff
@@ -0,0 +1,203 @@
+"""
+title: ROAT Knowledge Base Search
```
**Reviewer comment (Member):**

> Please rename this to just "Knowledge Base Search".
```diff
+author: WikiTeq
+date: 2025-05-01
+version: 1.0
+license: MIT
+description: Searches the RAG-of-All-Trades knowledge base and returns relevant context for the user's query.
+requirements: requests
+"""
+
+import asyncio
+import logging
+import re
+from collections.abc import Awaitable, Callable
+from typing import Optional
+
+import requests
+from pydantic import BaseModel, Field
+
+log = logging.getLogger(__name__)
+log.setLevel(logging.INFO)
+
+
+def _parse_raw_chunk(raw_text: str) -> dict:
+    match = re.match(r"Score:\s*([\d.]+)\s*\|\s*Text:\s*(.*)", raw_text, re.DOTALL)
+    if match:
+        return {"score": float(match.group(1)), "text": match.group(2).strip()}
+    return {"score": 0.0, "text": raw_text.strip()}
```
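The chunk parser `_parse_raw_chunk` handles two shapes of raw result: a `"Score: X.XX | Text: ..."` line, and a bare text fallback. Reproducing it standalone (same regex as the diff) to show both paths:

```python
import re

# Same parsing logic as _parse_raw_chunk in tools/roat_retrieval.py:
# scored chunks are split into score + text; anything else falls back
# to score 0.0 with the raw text stripped but otherwise preserved.
def parse_raw_chunk(raw_text: str) -> dict:
    match = re.match(r"Score:\s*([\d.]+)\s*\|\s*Text:\s*(.*)", raw_text, re.DOTALL)
    if match:
        return {"score": float(match.group(1)), "text": match.group(2).strip()}
    return {"score": 0.0, "text": raw_text.strip()}

print(parse_raw_chunk("Score: 0.87 | Text: Refund policy: 30 days."))
# → {'score': 0.87, 'text': 'Refund policy: 30 days.'}
print(parse_raw_chunk("plain chunk"))
# → {'score': 0.0, 'text': 'plain chunk'}
```

The `re.DOTALL` flag matters: it lets the `Text:` capture span multi-line chunk bodies.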
```diff
+
+
+def _get_filename_from_extras(extras: dict) -> Optional[str]:
+    return extras.get("key") or extras.get("filename") or extras.get("name") or None
+
+
+def _call_rag_service(url: str, api_key: str, timeout: int, top_k: int, query: str) -> dict:
+    payload = {"query": query, "top_k": top_k, "metadata_filters": {}}
+    headers = {"Content-Type": "application/json"}
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
+    url = url.strip()
+    log.info("Calling ROAT: query_length=%d", len(query))
+    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
+    response.raise_for_status()
+    return response.json()
+
+
+def _format_context_and_sources(rag_result: dict) -> tuple[str, list]:
+    references = rag_result.get("references", []) or []
+    raw_chunks = rag_result.get("raw") or []
+
+    if not references and not raw_chunks:
+        return "", []
+
+    context_parts = []
+    sources = []
+    _internal_fields = {"key", "format", "version", "checksum"}
+
+    for i in range(max(len(references), len(raw_chunks))):
+        ref = references[i] if i < len(references) else {}
+        extras = ref.get("extras") or {}
+        score = ref.get("score", 0.0)
+
+        text = ref.get("text", "")
+        if not text and i < len(raw_chunks):
+            parsed = _parse_raw_chunk(raw_chunks[i])
+            text = parsed["text"]
+            if score == 0.0:
+                score = parsed["score"]
+
+        if not text:
+            continue
+
+        filename = _get_filename_from_extras(extras)
+        source_name = (
+            ref.get("title")
+            or ref.get("source_name")
+            or filename
+            or f"Source {i + 1}"
+        )
+
+        metadata_fields = {k: v for k, v in extras.items() if k not in _internal_fields}
+        metadata_fields["url"] = ref.get("url") or extras.get("url")
+        metadata_md = "\n".join(
+            f"- *{k}*: {v}" for k, v in metadata_fields.items() if v is not None
+        )
+        metadata_section = f"## Metadata\n\n{metadata_md}" if metadata_md else ""
+
+        context_parts.append(f"[Source: {source_name}]\n\n{metadata_section}\n\n{text}\n")
+
+        source_obj = {
+            "source": {"name": source_name},
+            "document": [text[:1000] if len(text) > 1000 else text],
```
**Reviewer comment (Member):**

> While we're at it, can we please make this configurable via Valve, making it unlimited by default?
```diff
+            "metadata": [
+                {
+                    "source": source_name,
+                    "file": filename,
+                    "relevance_score": score,
```
**Reviewer comment (Member):**

> Not something new (it was like that in the Filter function), but I noticed that the `relevance_score` is not displayed by OWUI when viewing the citation source. Do you have any ideas why?
```diff
+                    "type": extras.get("format", "document"),
+                    "storage": extras.get("source"),
+                    "key": extras.get("key"),
+                    "checksum": extras.get("checksum"),
+                    "version": extras.get("version"),
+                    "format": extras.get("format"),
+                }
+            ],
+        }
+        url = ref.get("url") or extras.get("url")
+        if url:
+            source_obj["source"]["url"] = url
+
+        sources.append(source_obj)
+
+    return "\n".join(context_parts), sources
+
+
+class Tools:
+    class Valves(BaseModel):
+        rag_service_url: str = Field(
+            default="",
+            description="Full URL to the ROAT query endpoint, e.g. http://api:8000/query.",
+        )
+        rag_service_api_key: str = Field(
+            default="",
+            description="Bearer token for the ROAT API (leave blank if not required).",
+        )
+        rag_service_timeout: int = Field(
+            default=30,
+            description="Request timeout in seconds.",
+        )
+        top_k: int = Field(
+            default=20,
+            description="Number of top results to retrieve from the knowledge base.",
+        )
+
+    def __init__(self):
+        self.valves = self.Valves()
+
+    async def search_knowledge_base(
+        self,
+        query: str,
+        __event_emitter__: Callable[[dict], Awaitable[None]] | None = None,
+    ) -> str:
+        """
+        Search the mAItion knowledge base for information relevant to the user's question.
+
+        Use this tool when:
+        - The user asks a factual question that likely requires domain-specific knowledge
+        - The user asks about documented processes, policies, data, or internal content
+        - The user's question cannot be confidently answered from general training knowledge alone
```
**Reviewer comment (Member):**

> In fact, we want this tool used for most of the queries, and we want to disincentivize the LLM from using general knowledge. Thus, I propose to drop this line and instead replace it with something that would instruct the model to use this tool more frequently. We're of course just guessing for now until we have some evaluations in place.
```diff
+        - The user explicitly asks to search the knowledge base or documentation
+
+        Do NOT use this tool for:
+        - Casual greetings or small talk
+        - Simple arithmetic or general world knowledge
+        - Follow-up questions within an ongoing conversation where context was already retrieved
```
**Reviewer comment (Member):**

> I don't think this specific instruction works, as follow-up questions may require queries into the ROAT. Let's try to reformulate this.
```diff
+
+        Args:
+            query: A concise search query derived from the user's question
+
+        Returns:
+            Retrieved context passages from the knowledge base, or a message that nothing was found.
```
**Reviewer comment (Member):**

> Please reword into
```diff
+        """
+
+        async def emit(description: str, done: bool = False) -> None:
+            if __event_emitter__:
+                await __event_emitter__(
+                    {"type": "status", "data": {"description": description, "done": done}}
+                )
+
+        if not self.valves.rag_service_url:
+            await emit("Knowledge base URL is not configured in Tool Valves.", done=True)
+            return "Error: rag_service_url is not configured in Tool Valves."
+
+        await emit("Searching knowledge base…")
+
+        try:
+            log.info("ROAT url=%r top_k=%r timeout=%r", self.valves.rag_service_url, self.valves.top_k, self.valves.rag_service_timeout)
+            rag_result = await asyncio.to_thread(
+                _call_rag_service,
+                self.valves.rag_service_url,
+                self.valves.rag_service_api_key,
+                self.valves.rag_service_timeout,
+                self.valves.top_k,
+                query,
+            )
```
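`asyncio.to_thread` here keeps the blocking `requests.post` inside `_call_rag_service` off the event loop by running it in a worker thread. A minimal illustration of the pattern — the blocking function below is a stand-in, since the real call needs a live ROAT endpoint:

```python
import asyncio
import time

# Stand-in for the blocking _call_rag_service: asyncio.to_thread runs it
# in a worker thread so the event loop stays free for other tasks.
def blocking_fetch(query: str, top_k: int) -> dict:
    time.sleep(0.05)  # simulates network latency of requests.post
    return {"query": query, "top_k": top_k, "references": []}

async def main() -> dict:
    # Positional arguments are forwarded to the callable, exactly as
    # search_knowledge_base forwards the valve values and query.
    return await asyncio.to_thread(blocking_fetch, "vacation policy", 5)

result = asyncio.run(main())
print(result["top_k"])  # → 5
```

Without this wrapper, a slow ROAT response would stall every other coroutine in the OWUI worker for the duration of the request.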
*coderabbitai[bot] marked this conversation as resolved.*
```diff
+        except Exception as e:
+            log.error("ROAT request failed: %s", e, exc_info=True)
+            await emit("Failed to reach the knowledge base.", done=True)
+            return "Error: could not reach the knowledge base. Check the server logs for details."
+
+        context, sources = _format_context_and_sources(rag_result)
+
+        if not context:
+            await emit("No relevant information found.", done=True)
+            return "No relevant information was found in the knowledge base for this query."
+
+        if __event_emitter__:
+            for src in sources:
+                await __event_emitter__({"type": "source", "data": src})
+
+        await emit(f"Found {len(sources)} relevant source(s).", done=True)
+        log.info("Returning context with %d sources (%d chars)", len(sources), len(context))
+        return context
```
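The `status` and `source` events the tool emits above go through OWUI's `__event_emitter__` callback. A small sketch that captures both envelope shapes with a stub emitter (the event dicts mirror the ones in the tool; the recorded list is just for demonstration):

```python
import asyncio

# Stub emitter that records the events the tool would send to the UI.
events: list[dict] = []

async def event_emitter(event: dict) -> None:
    events.append(event)

async def demo() -> None:
    # Envelope the tool uses for progress updates...
    await event_emitter(
        {"type": "status", "data": {"description": "Searching knowledge base…", "done": False}}
    )
    # ...and for citations attached to the answer.
    await event_emitter(
        {"type": "source", "data": {"source": {"name": "Source 1"}, "document": ["chunk"]}}
    )

asyncio.run(demo())
print([e["type"] for e in events])  # → ['status', 'source']
```

In the real tool the emitter is injected by OWUI, so `emit()` guards on it being `None` when no UI is attached.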
**Reviewer comment (Member):**

> I think this fits more to be included in a changelog (which we don't yet have) rather than in a README. Let's instead simply ensure both the Filter function and the Tool are documented.