diff --git a/graphrag_sdk/src/graphrag_sdk/__init__.py b/graphrag_sdk/src/graphrag_sdk/__init__.py index a7c8ff3..81ed8f5 100644 --- a/graphrag_sdk/src/graphrag_sdk/__init__.py +++ b/graphrag_sdk/src/graphrag_sdk/__init__.py @@ -96,6 +96,11 @@ from graphrag_sdk.retrieval.reranking_strategies.base import RerankingStrategy from graphrag_sdk.retrieval.reranking_strategies.cosine import CosineReranker from graphrag_sdk.retrieval.strategies.base import RetrievalStrategy +from graphrag_sdk.retrieval.strategies.cypher_first import ( + CypherFirstAggregationStrategy, + DefaultPhraseExtractor, + PhraseExtractor, +) from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval # ── Storage ───────────────────────────────────────────────────── @@ -160,7 +165,10 @@ "SemanticResolution", # Retrieval "CosineReranker", + "CypherFirstAggregationStrategy", + "DefaultPhraseExtractor", "MultiPathRetrieval", + "PhraseExtractor", "RerankingStrategy", "RetrievalStrategy", # Storage diff --git a/graphrag_sdk/src/graphrag_sdk/api/main.py b/graphrag_sdk/src/graphrag_sdk/api/main.py index 13c4635..dc0c34d 100644 --- a/graphrag_sdk/src/graphrag_sdk/api/main.py +++ b/graphrag_sdk/src/graphrag_sdk/api/main.py @@ -92,6 +92,27 @@ "or is NOT true, preserve that meaning." ) +# Optional addendum appended to the system prompt only when the retriever +# produced an "Authoritative Graph Query Results" section (currently emitted +# by ``CypherFirstAggregationStrategy`` and by ``MultiPathRetrieval`` with +# ``enable_cypher=True``). Kept out of the base prompts so callers who never +# use cypher retrieval don't pay the prompt-token cost or risk subtle +# behavior changes on factoid questions. +_CYPHER_AUTH_RULE = ( + "\n8. When the context contains a section labeled " + "'Authoritative Graph Query Results', that section is computed " + "deterministically from the knowledge graph. 
For quantitative or " + "'which has the most/fewest/exactly N' questions, prefer those " + "numbers over prose passages, even if passages mention other " + "entities more frequently." +) + +# Marker the retriever uses in section content / metadata to signal that an +# authoritative cypher result is present. Kept as a module-level constant so +# tests can pin the contract. +_CYPHER_RESULTS_SECTION = "cypher_results" +_AUTH_RESULTS_HEADING_MARKER = "Authoritative Graph Query Results" + _RAG_PROMPT = "\n{context}\n\n\nQuestion: {question}\n\nAnswer:" # Matches a literal ```` closing tag (case-insensitive, whitespace @@ -105,6 +126,23 @@ def _neutralize_context_close_tag(text: str) -> str: return _CONTEXT_CLOSE_RE.sub("", text) +def _has_authoritative_cypher_results(retriever_result: RetrieverResult) -> bool: + """True when any retrieved item is an authoritative graph-query result. + + Used to decide whether to append the cypher-authority rule to the + default system prompt. We check both the section metadata (preferred, + cheap) and the heading marker in the content (defensive — covers + third-party strategies that don't tag the metadata). + """ + for item in retriever_result.items or []: + section = (item.metadata or {}).get("section") if item.metadata else None + if section == _CYPHER_RESULTS_SECTION: + return True + if isinstance(item.content, str) and _AUTH_RESULTS_HEADING_MARKER in item.content: + return True + return False + + _QUESTION_REWRITE_PROMPT = ( "Given the conversation history, rewrite the user's last question " "as a standalone question that includes all entity names, dates, " @@ -709,6 +747,12 @@ async def completion( context_str = "\n---\n".join(item.content for item in retriever_result.items) default_system_content = _RAG_SYSTEM_PROMPT + # Conditionally append the cypher-authority rule. Only fires when the + # retriever produced an authoritative-results section — keeps the base + # system prompt unchanged for callers who don't use cypher retrieval. 
+ if _has_authoritative_cypher_results(retriever_result): + default_system_content = default_system_content + _CYPHER_AUTH_RULE + # Step 4: Build messages — unified path for single-turn and multi-turn. # If history starts with role="system", honor it as-is (trust the # consumer). Otherwise inject the SDK's default instructions. diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/__init__.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/__init__.py index 8841387..4a5cc1c 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/__init__.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/__init__.py @@ -1,6 +1,17 @@ # GraphRAG SDK — Retrieval: Strategies from graphrag_sdk.retrieval.strategies.base import RetrievalStrategy +from graphrag_sdk.retrieval.strategies.cypher_first import ( + CypherFirstAggregationStrategy, + DefaultPhraseExtractor, + PhraseExtractor, +) from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval -__all__ = ["RetrievalStrategy", "MultiPathRetrieval"] +__all__ = [ + "CypherFirstAggregationStrategy", + "DefaultPhraseExtractor", + "MultiPathRetrieval", + "PhraseExtractor", + "RetrievalStrategy", +] diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_first.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_first.py new file mode 100644 index 0000000..7f10be3 --- /dev/null +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_first.py @@ -0,0 +1,1061 @@ +# GraphRAG SDK — Retrieval: Cypher-First Aggregation Strategy +# Routes aggregation/quantitative questions through a deterministic graph- +# query path that treats Cypher as the answer source (not just another +# retrieval signal). Free-text properties (roles, projects) that aren't +# captured as typed entities are recovered by parsing Person chunk text at +# retrieval time. +# +# Non-aggregation questions delegate to a fallback strategy (default: +# MultiPathRetrieval). 
The strategy is therefore safe as the top-level +# strategy on GraphRAG — RAG questions still get the existing pipeline. +# +# Background: traditional RAG retrieves prose evidence and lets the LLM +# synthesize. For "how many X" / "which X has the most Y" / "BOTH A and B" +# questions, prose evidence is the wrong shape — the answer wants exact +# counts or set operations. Cypher gives us those deterministically, but +# only when the underlying graph has the right structure. This strategy +# wraps Cypher generation in three mechanisms that make it reliable on +# noisy, real-world graphs: +# +# 1. Multi-candidate Cypher with row-count selection (M2): K parallel +# samples, execute all, pick the one with the most rows. Beats LLM +# stochasticity without serial retries. +# 2. Column-named markdown table formatting (M3): uses FalkorDB's +# ``result.header`` so the synthesizer sees ``acme_count=10`` not the +# lossy ``10 | 7 | True``. +# 3. Description + chunk-text fuzzy hybrid (M5): for "shared X" / "BOTH +# A and B" questions where X is a free-text property (role, project) +# not extracted as a typed node, parse Person chunks via regex and +# compute the set operation in Python with fuzzy token matching. +# +# Plus a deterministic numeric-math path (M6) — RETURN raw values, then +# average / sum / median in Python — and a negation-existential branch +# (M7) that treats an empty Cypher result as the definitive "No" when the +# question shape demands it. 
+ +from __future__ import annotations + +import asyncio +import logging +import re +import statistics +from abc import ABC, abstractmethod +from typing import Any + +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.models import ( + RawSearchResult, + RetrieverResult, + RetrieverResultItem, +) +from graphrag_sdk.core.providers import Embedder, LLMInterface +from graphrag_sdk.retrieval.strategies.base import RetrievalStrategy +from graphrag_sdk.retrieval.strategies.cypher_generation import ( + SCHEMA_PROMPT, + _sanitize_cypher, + extract_cypher, + validate_cypher, +) +from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval + +logger = logging.getLogger(__name__) + + +# ───────────────────────────────────────────────────────────────── +# Intent classification (M1) +# ───────────────────────────────────────────────────────────────── + +_AGG_INTENT_PATTERNS = [ + r"\bhow\s+many\b", r"\bhow\s+much\b", + r"\bwhich\s+\w+\b", + r"\bwhat\s+(?:is\s+the\s+)?(?:average|mean|median|total|sum|count|number)\b", + r"\baverage\b", r"\bmedian\b", r"\btotal\b", + r"\bcount\s+of\b", r"\bnumber\s+of\b", + r"\blist\s+(?:all|the|every)\b", + r"\blist\s+\w+\s+(?:that|with|where|who)\b", + r"\bare\s+there\s+any\b", r"\bis\s+there\s+any\b", + # `more X than` / `fewer X than` — allow up to 4 words between + r"\bmore(?:\s+\S+){1,4}\s+than\b", + r"\bfewer(?:\s+\S+){0,4}\s+than\b", + r"\bless(?:\s+\S+){0,4}\s+than\b", + r"\bboth\s+\w+(?:\s+\S+)*?\s+and\s+\w+\b", + r"\bbetween\s+\w+\s+and\s+\w+\b", + r"\bexactly\s+\d+\b", r"\bat\s+least\s+\d+\b", + r"\b(?:does|do|has|have|is|are)\s+(?:[A-Z]\w*\s+){1,4}(?:have|has|contain|own|run|host)\b", + r"\bwho\s+(?:works|live|is)\s+(?:at|in|on)\b", +] +_AGG_INTENT_RE = re.compile("|".join(_AGG_INTENT_PATTERNS), re.IGNORECASE) + +_NUMERIC_AGG_RE = re.compile( + r"\b(average|mean|median|total|sum)\b\s+(?:of\s+)?(?:the\s+)?(?:\w+\s+)?" 
+ r"\b(year|years|age|amount|price|cost|revenue|number|count|" + r"founding|founded|salary|salaries)\b", + re.IGNORECASE, +) + +_YES_NO_RE = re.compile( + r"^\s*(?:is|are|was|were|do|does|did|has|have|had|can|could|will|" + r"would|should|may|might|are\s+there|is\s+there)\b", + re.IGNORECASE, +) + +_WHICH_LIST_RE = re.compile( + r"^\s*(?:which|what|list|name|identify|enumerate)\b", + re.IGNORECASE, +) + + +def detect_aggregation_intent(question: str) -> str: + """Return ``"numeric_math"``, ``"aggregation"``, or ``"rag"``. + + ``"numeric_math"`` triggers the Python-arithmetic path; ``"aggregation"`` + triggers the Cypher-first table path; ``"rag"`` falls back to the + standard retrieval strategy. + """ + if _NUMERIC_AGG_RE.search(question): + return "numeric_math" + if _AGG_INTENT_RE.search(question): + return "aggregation" + return "rag" + + +def is_yes_no(question: str) -> bool: + return bool(_YES_NO_RE.match(question)) + + +def is_which_list(question: str) -> bool: + return bool(_WHICH_LIST_RE.match(question)) + + +def _is_negation_existential(question: str) -> bool: + """True for "are there any X without/no Y?" — empty Cypher = definitive No.""" + has_negation = bool( + re.search(r"\b(?:no|not|without|never|none)\b", question, re.IGNORECASE) + ) + has_existential = bool( + re.search(r"\b(?:any|are\s+there|is\s+there)\b", question, re.IGNORECASE) + ) + return has_negation and has_existential + + +# ───────────────────────────────────────────────────────────────── +# Free-text phrase extraction for description hybrid (M5) +# ───────────────────────────────────────────────────────────────── + +# Roles end in a closed set of professional suffixes — anchoring on those +# keeps noisy phrases ("the cross-region replication initiative") from +# getting captured. +_ROLE_RE = re.compile( + r"\b(?:as|is)\s+(?:an?\s+)?" 
+ r"((?:[a-z][\w\s]*?)?\b(?:engineer|scientist|manager|architect|researcher|" + r"developer|analyst|specialist|designer|consultant)s?)\b", + re.IGNORECASE, +) + +_PROJECT_RE = re.compile( + r"\bcontribut(?:es|ing)\s+to\s+" + r"(?:the\s+|a\s+|an\s+)?" + r"(.+?)" + r"(?=\s+and\s+(?:is|active)|\s+based\s+in|,|\.\s|\.$|$)", + re.IGNORECASE, +) + +_BOTH_AB_RE = re.compile( + r"\bboth\s+([\w\s]+?)\s+and\s+([\w\s'-]+?)(?:\?|\.|$|\bof\b|\bat\b|\bwith\b)", + re.IGNORECASE, +) + +_SAME_AS_RE = re.compile( + r"\b(?:same\s+\w+\s+as|share\s+(?:an?\s+)?\w+\s+with|" + r"share\s+(?:any\s+of\s+)?\w+\s+with)" + r"\s+(?:someone\s+(?:at|in|from)\s+)?([\w\s]+?)(?:\?|\.|$)", + re.IGNORECASE, +) + + +def _extract_roles(text: str) -> set[str]: + out: set[str] = set() + for m in _ROLE_RE.finditer(text or ""): + role = re.sub(r"\s+", " ", m.group(1).strip().lower()) + if 3 < len(role) < 60: + out.add(role) + return out + + +def _extract_projects(text: str) -> set[str]: + out: set[str] = set() + for m in _PROJECT_RE.finditer(text or ""): + proj = re.sub(r"\s+", " ", m.group(1).strip().lower()) + proj = re.sub(r"\s+(?:and|who|while)$", "", proj) + if 5 < len(proj) < 100: + out.add(proj) + return out + + +def _extract_phrases(text: str, kind: str) -> set[str]: + if kind == "role": + return _extract_roles(text) + if kind == "project": + return _extract_projects(text) + return set() + + +# ───────────────────────────────────────────────────────────────── +# Pluggable phrase extractor (R8) +# ───────────────────────────────────────────────────────────────── + + +class PhraseExtractor(ABC): + """Pluggable phrase extractor for the shared-property hybrid path. + + The default implementation targets the prose patterns the SDK's + ``GraphExtraction`` pipeline produces (``"works at X as a "``, + ``"contributes to "``). Domain-specific subclasses can + override ``extract`` to recognize medical / legal / e-commerce / + non-English vocabularies without forking the strategy. 
+ + Pass an instance via ``CypherFirstAggregationStrategy(..., phrase_extractor=MyExtractor())``. + """ + + @abstractmethod + def extract(self, text: str, kind: str) -> set[str]: + """Return phrases of ``kind`` found in ``text``. + + ``kind`` is typically ``"role"`` or ``"project"`` but extractors + may define additional kinds. Unknown kinds should return an empty + set rather than raise. + """ + + +class DefaultPhraseExtractor(PhraseExtractor): + """The default role/project extractor used by + ``CypherFirstAggregationStrategy``. Targets the prose patterns + produced by the SDK's default ``GraphExtraction`` pipeline. + + See module-level ``_ROLE_RE`` and ``_PROJECT_RE`` for the exact + regexes; the role suffix vocabulary is enumerated in the + ``CypherFirstAggregationStrategy`` docstring. + """ + + def extract(self, text: str, kind: str) -> set[str]: + return _extract_phrases(text, kind) + + +def _detect_property_kind(question: str) -> str | None: + q = question.lower() + if re.search(r"\b(?:role|roles|job|jobs|title|titles|position|positions)\b", q): + return "role" + if re.search( + r"\b(?:project|projects|initiative|initiatives|work\s+on|works\s+on|" + r"working\s+on|contribute|contributes|same\s+thing)\b", + q, + ): + return "project" + return None + + +def _fuzzy_intersect(a: set[str], b: set[str]) -> set[str]: + """Return phrases from ``a`` that fuzzy-match any phrase in ``b``. + + Fuzzy = case-insensitive substring in either direction OR ≥2 shared + content tokens. Catches cases where extraction paraphrased a project + name, e.g. "next-generation pipeline rewrite" vs "pipeline rewrite". 
+ """ + stop = {"the", "a", "an", "to", "of", "for", "in", "on", "and", "or"} + + def _tokens(s: str) -> set[str]: + return {t.lower() for t in re.findall(r"\w+", s) + if t.lower() not in stop and len(t) > 2} + + out: set[str] = set() + for x in a: + xt = _tokens(x) + if not xt: + continue + for y in b: + if x == y or x in y or y in x: + out.add(x) + break + yt = _tokens(y) + if len(xt & yt) >= 2: + out.add(x) + break + return out + + +# ───────────────────────────────────────────────────────────────── +# Cypher result → markdown table (M3) +# ───────────────────────────────────────────────────────────────── + +def format_result_as_markdown_table( + result: Any, + *, + cap: int = 100, +) -> tuple[str, list[dict[str, Any]], bool]: + """Render a FalkorDB result_set as a markdown table with column headers. + + Returns ``(table_md, parsed_rows, truncated)``. ``parsed_rows`` is a list + of dicts keyed by column name so callers can post-process structured + data without re-parsing the markdown. 
+ """ + if not getattr(result, "result_set", None): + return "(empty result)", [], False + + headers: list[str] = [] + if getattr(result, "header", None): + for h in result.header: + if isinstance(h, (list, tuple)) and len(h) >= 2: + headers.append(str(h[1])) + else: + headers.append(str(h)) + if not headers: + n_cols = len(result.result_set[0]) if result.result_set else 0 + headers = [f"col_{i}" for i in range(n_cols)] + + rows = result.result_set[:cap] + truncated = len(result.result_set) > cap + + parsed_rows: list[dict[str, Any]] = [] + for row in rows: + d: dict[str, Any] = {} + for i, val in enumerate(row): + d[headers[i] if i < len(headers) else f"col_{i}"] = val + parsed_rows.append(d) + + sep = " | " + lines = [sep.join(headers), sep.join(["---"] * len(headers))] + for row in rows: + cells: list[str] = [] + for v in row: + if v is None: + cells.append("(null)") + elif isinstance(v, list): + cells.append(", ".join(str(x) for x in v)) + else: + cells.append(str(v)) + lines.append(sep.join(cells)) + if truncated: + lines.append( + f"... (showing {len(rows)} of {len(result.result_set)} rows; " + "result truncated)" + ) + return "\n".join(lines), parsed_rows, truncated + + +# ───────────────────────────────────────────────────────────────── +# Schema prompt enrichment for aggregation generation +# ───────────────────────────────────────────────────────────────── + +_AGG_SCHEMA_SUFFIX = """ + +## Additional rules for AGGREGATION questions + +- For "BOTH X AND Y" / set-intersection over a shared property (e.g. "roles + held at both A and B"), use TWO separate matches against DIFFERENT + people, joined on the shared entity: + MATCH (p1:Person)-[:RELATES]-(o1:Organization), + (p1)-[:RELATES]-(prop:__Entity__), + (p2:Person)-[:RELATES]-(o2:Organization), + (p2)-[:RELATES]-(prop) + WHERE o1.name CONTAINS '...' AND o2.name CONTAINS '...' + RETURN DISTINCT prop.name AS shared_value + The shared `prop` variable IS the intersection. 
+- For "more X than Y" comparison questions, RETURN both counts AS named + columns (e.g. RETURN count(...) AS acme_count, count(...) AS initech_count) + so the answer is unambiguous. +- For "average / total of YEARS or NUMBERS", just RETURN the raw values + (e.g. d.name for Date entities). Arithmetic happens outside cypher. +- For "which X have/work/share Y" list questions, RETURN DISTINCT only the + X you're asked about — do not add extra columns. +- Always alias every RETURN column with a descriptive name. +- Prefer RETURNing one row per group in group-by patterns rather than + packing multiple counts into a single row. +""" + +_DESC_HINT_SUFFIX = ( + "\n\nIMPORTANT: Person entities have rich free-text in their " + "`description` property — phrases like 'works at X as a senior " + "engineer' or 'contributes to internal tooling for observability'. " + "If the question asks about ROLES, JOBS, PROJECTS, or other free-text " + "properties that are NOT first-class entities in the schema, prefer " + "filtering on `p.description CONTAINS '...'` over matching typed " + "structural edges.\n" +) + + +# ───────────────────────────────────────────────────────────────── +# Aggregation-mode answer-side directive +# (synthesized into the retrieved item content; the existing system +# prompt rule 8 in api/main.py already tells the LLM to trust the +# "Authoritative Graph Query Results" section.) +# ───────────────────────────────────────────────────────────────── + +_AUTH_HEADING = ( + "## Authoritative Graph Query Results " + "(deterministic; trust over passages on counts and aggregates)" +) + + +def _wrap_authoritative(body: str, *, source_note: str = "") -> str: + note = ( + f"\nSource: {source_note}." if source_note else + "\nSource: text-to-Cypher run against the knowledge graph." + ) + return f"{_AUTH_HEADING}{note}\n\n{body}" + + +# Canonical labels for the sub-path metadata key ``cypher_first_path``. 
# Canonical labels for the sub-path metadata key ``cypher_first_path``.
# Operators / metrics dashboards can group on these to see which path
# fires for each query.
PATH_NUMERIC_MATH = "numeric_math"
PATH_SHARED_PROPERTY_HYBRID = "shared_property_hybrid"
PATH_CYPHER_TABLE = "cypher_table"
PATH_NEGATION_EMPTY_NO = "negation_empty_no"
PATH_RAG_FALLBACK = "rag_fallback"
PATH_RAG_FALLBACK_NUMERIC_FAIL = "rag_fallback_numeric_fail"
PATH_RAG_FALLBACK_CYPHER_EMPTY = "rag_fallback_cypher_empty"


def _tag_path(result: RawSearchResult, path: str) -> RawSearchResult:
    """Attach the ``cypher_first_path`` label to a strategy result.

    Used both for results we construct ourselves and for results returned
    from the delegated RAG fallback strategy — operators get a uniform
    signal regardless of which branch handled the query.

    Note: ``strategy`` is set via ``setdefault`` so a fallback strategy's
    own tag survives, while ``cypher_first_path`` is always overwritten.
    A new ``RawSearchResult`` is returned; the input is not mutated.
    """
    meta = dict(result.metadata or {})
    meta.setdefault("strategy", "cypher_first")
    meta["cypher_first_path"] = path
    return RawSearchResult(records=result.records, metadata=meta)


# ─────────────────────────────────────────────────────────────────
# Sub-paths
#
# Each path is a small, focused class with a single ``maybe_handle()``
# method that either produces a final ``RawSearchResult`` or returns
# ``None`` to defer to the next path. The strategy's ``_execute()``
# dispatches by intent and iterates the relevant paths in order.
# Splitting this way makes the routing trivial to follow, each path
# trivially unit-testable in isolation, and adding new shapes (medical /
# legal / e-commerce) a matter of dropping in a new path class.
# ─────────────────────────────────────────────────────────────────


class _AggregationPath(ABC):
    """Base class for CypherFirstAggregationStrategy sub-paths.

    Holds a reference to the parent strategy so subclasses can reach the
    shared LLM / graph / fallback / k_candidates state without dragging
    around a long argument list.
    """

    def __init__(self, strategy: CypherFirstAggregationStrategy) -> None:
        # Private back-reference to the owning strategy; subclasses use
        # self._s._llm / self._s._graph / self._s._fallback / self._s._k.
        self._s = strategy

    @abstractmethod
    async def maybe_handle(
        self,
        query: str,
        ctx: Context,
    ) -> RawSearchResult | None:
        """Return a final retrieval result, or ``None`` to defer."""


class _RagDelegationPath(_AggregationPath):
    """Hands the query to the RAG fallback verbatim. Used for intent="rag"."""

    async def maybe_handle(
        self,
        query: str,
        ctx: Context,
    ) -> RawSearchResult | None:
        # NOTE(review): calls the fallback's private _execute directly —
        # presumably the RetrievalStrategy-internal entry point; confirm
        # it is the intended seam rather than a public search method.
        return _tag_path(
            await self._s._fallback._execute(query, ctx),
            PATH_RAG_FALLBACK,
        )


class _NumericMathPath(_AggregationPath):
    """For "average / total / sum of YEARS / NUMBERS" — extract values via
    Cypher, do the arithmetic in Python. Avoids LLM-arithmetic errors."""

    async def maybe_handle(
        self,
        query: str,
        ctx: Context,
    ) -> RawSearchResult | None:
        # Ask the LLM for a value-extraction cypher (raw rows, no SQL-side
        # aggregation) so the arithmetic below is fully deterministic.
        extraction_prompt = SCHEMA_PROMPT.format(
            question=(
                f"Generate a cypher that returns the RAW NUMERIC VALUES "
                f"needed to answer this question (one value per row). "
                f"Do NOT compute averages or sums in cypher; just return "
                f"the raw numbers. Use Date entities if the question is "
                f"about years.\n\n"
                f"Question: {query}"
            )
        )
        cypher: str | None = None
        values: list[float] = []
        try:
            resp = await self._s._llm.ainvoke(extraction_prompt)
            cypher = extract_cypher(resp.content)
            errors = validate_cypher(cypher) if cypher else ["empty"]
            if errors:
                logger.debug("Numeric-math cypher validation failed: %s", errors)
                cypher = None
            else:
                cypher = _sanitize_cypher(cypher)
                result = await self._s._graph.query_raw(cypher)
                for row in (result.result_set or []):
                    for cell in row:
                        # NOTE(review): _coerce_number is not defined in this
                        # view — presumably a module-level helper that turns
                        # ints/floats/numeric strings into float-or-None;
                        # confirm it exists elsewhere in this module.
                        v = _coerce_number(cell)
                        if v is not None:
                            values.append(v)
        except Exception as exc:
            # Broad catch is deliberate: any LLM/graph failure degrades to
            # the RAG fallback below rather than crashing retrieval.
            logger.debug("Numeric-math extraction failed: %s", exc)

        if not values:
            # Fall back to standard retrieval — the LLM may still be able
            # to extract the numbers from chunks.
            ctx.log("CypherFirst numeric_math: no values extracted, "
                    "falling back to RAG")
            return _tag_path(
                await self._s._fallback._execute(query, ctx),
                PATH_RAG_FALLBACK_NUMERIC_FAIL,
            )

        # Operation selection by keyword; "average" is the default so
        # "mean"-phrased questions land here too.
        q_lower = query.lower()
        if "median" in q_lower:
            ans = statistics.median(values)
            op = "median"
        elif "total" in q_lower or "sum" in q_lower:
            ans = float(sum(values))
            op = "sum"
        else:  # average / mean / default for "average-like" questions
            ans = sum(values) / len(values)
            op = "average"

        # Render integers without a decimal point, floats to one place.
        ans_str = f"{int(ans)}" if ans == int(ans) else f"{ans:.1f}"
        body = (
            f"computed {op} = {ans_str}\n"
            f"source_values ({len(values)} rows): "
            f"{', '.join(str(int(v)) if v == int(v) else f'{v:.2f}' for v in values)}\n"
            f"cypher: {cypher}"
        )
        return RawSearchResult(
            records=[{
                "section": "cypher_results",
                "content": _wrap_authoritative(
                    body,
                    source_note="numeric extraction + Python arithmetic",
                ),
            }],
            metadata={
                "strategy": "cypher_first",
                "cypher_first_path": PATH_NUMERIC_MATH,
                "op": op,
                "value": ans,
                "n_values": len(values),
                "cypher": cypher,
            },
        )


class _SharedPropertyHybridPath(_AggregationPath):
    """For "BOTH A and B" / "same X as Z" questions over free-text
    properties (role, project) that aren't first-class entities. Parses
    Person chunks via regex and computes the set operation in Python
    with fuzzy token matching."""

    async def maybe_handle(
        self,
        query: str,
        ctx: Context,
    ) -> RawSearchResult | None:
        # Only fire when the question names a recognizable free-text
        # property AND matches one of the two supported shapes.
        kind = _detect_property_kind(query)
        if kind is None:
            return None
        shape1 = _BOTH_AB_RE.search(query)
        shape2 = _SAME_AS_RE.search(query)
        if not (shape1 or shape2):
            return None

        # One batched query: every (Org, Person) pair plus the person's
        # description and all chunks mentioning them.
        batch_cypher = (
            "MATCH (o:Organization)<-[:RELATES]-(p:Person) "
            "OPTIONAL MATCH (p)-[:MENTIONED_IN]->(c:Chunk) "
            "RETURN o.name AS org, p.name AS person, "
            "       p.description AS desc, collect(DISTINCT c.text) AS chunks"
        )
        try:
            batch_res = await self._s._graph.query_raw(batch_cypher)
        except Exception as exc:
            logger.debug("Shared-property hybrid batch query failed: %s", exc)
            return None

        # Topology check: if the batched ``(Org)<-[:RELATES]-(Person)`` query
        # returns zero tuples, the graph doesn't match the assumptions M5
        # was tuned on (Person ↔ Organization edges + MENTIONED_IN chunks).
        # Surface this loudly once per call so operators using custom
        # extractors get a fast signal rather than silent wrong answers.
        if not (batch_res.result_set or []):
            logger.warning(
                "CypherFirst shared-property hybrid found zero "
                "(Organization)<-[:RELATES]-(Person) tuples; falling through. "
                "If your graph uses different edge shapes or doesn't extract "
                "Person/Organization labels, this hybrid will never fire — "
                "see the strategy docstring's 'Assumptions and known limits' "
                "section."
            )
            return None

        # Build org-name → set-of-phrases from each person's description
        # and mention chunks.
        org_phrase_map: dict[str, set[str]] = {}
        for row in (batch_res.result_set or []):
            org = row[0] or ""
            person = row[1] or ""
            desc = row[2] or ""
            # NOTE(review): assumes every element of ``chunks`` is a str —
            # a null c.text inside collect() would break the join; confirm
            # the graph never stores null chunk text.
            chunks = row[3] or []
            text_blob = desc + "\n" + "\n".join(chunks)
            phrases = org_phrase_map.setdefault(org, set())
            for sent in re.split(r"(?<=[.\n])\s+", text_blob):
                # Sentence-restrict to this person to avoid cross-paragraph
                # contamination from chunks that contain multiple people.
                # (First-name match only — short-form references still count.)
                if person and person.split()[0] not in sent:
                    continue
                phrases |= self._s._phrase_extractor.extract(sent, kind)

        def _gather(org_name: str) -> set[str]:
            # Union phrases across all orgs whose stored name contains the
            # question's (possibly partial) org name, case-insensitively.
            out: set[str] = set()
            for name, phrases in org_phrase_map.items():
                if org_name.lower() in (name or "").lower():
                    out |= phrases
            return out

        if shape1:
            # "BOTH A and B" — intersect the two orgs' phrase sets.
            org_a, org_b = (shape1.group(1).strip(), shape1.group(2).strip())
            if not org_a or not org_b:
                return None
            a_phrases = _gather(org_a)
            b_phrases = _gather(org_b)
            common = sorted(_fuzzy_intersect(a_phrases, b_phrases))
            if not common:
                return None
            ctx.log(f"CypherFirst hybrid shape1: {len(common)} shared {kind}s")
            body = (
                f"The {kind}s held by employees at both {org_a} and "
                f"{org_b}: " + ", ".join(common)
            )
            return RawSearchResult(
                records=[{
                    "section": "cypher_results",
                    "content": _wrap_authoritative(
                        body,
                        source_note=(
                            "Person chunks + description regex; "
                            "fuzzy-intersected by content tokens"
                        ),
                    ),
                }],
                metadata={
                    "strategy": "cypher_first",
                    "cypher_first_path": PATH_SHARED_PROPERTY_HYBRID,
                    "shape": "both_a_and_b",
                    "kind": kind,
                    "common": common,
                },
            )

        # shape2 — "same X as someone at TARGET": list every other org with
        # at least one fuzzy-shared phrase.
        target = shape2.group(1).strip().rstrip(",")
        target_phrases = _gather(target)
        if not target_phrases:
            return None
        sharing: list[str] = []
        for org_name, org_phrases in org_phrase_map.items():
            if not org_name:
                continue
            # Skip the target org itself (substring match either way).
            if (
                target.lower() in org_name.lower()
                or org_name.lower() in target.lower()
            ):
                continue
            if _fuzzy_intersect(org_phrases, target_phrases):
                sharing.append(org_name)
        sharing.sort()
        if not sharing:
            return None
        ctx.log(f"CypherFirst hybrid shape2: {len(sharing)} sharing orgs")
        body = (
            "The organizations that have at least one employee working on "
            f"the same {kind} as someone at {target}: " + ", ".join(sharing)
        )
        return RawSearchResult(
            records=[{
                "section": "cypher_results",
                "content": _wrap_authoritative(
                    body,
                    source_note=(
                        "Person chunks + description regex; fuzzy-matched "
                        "across orgs"
                    ),
                ),
            }],
            metadata={
                "strategy": "cypher_first",
                "cypher_first_path": PATH_SHARED_PROPERTY_HYBRID,
                "shape": "same_as",
                "kind": kind,
                "sharing": sharing,
            },
        )


class _MultiCandidateCypherPath(_AggregationPath):
    """Generates K parallel cypher candidates, executes them all, renders
    the highest-row-count result as a markdown table with column headers.

    Also handles the empty-result branches:
    - negation-existential ("are there any X without Y?") → return No
    - everything else → delegate to the RAG fallback
    Always returns a result (never ``None``) — this path is the last
    line of defense for aggregation intent.
    """

    async def maybe_handle(
        self,
        query: str,
        ctx: Context,
    ) -> RawSearchResult | None:
        # Pass 1: structural.
        candidates = await self._generate_k_candidates(query)
        cypher, table_md, parsed, truncated = await self._execute_and_pick(
            candidates,
        )

        # Pass 2 (description hint): if pass 1 was sparse for a "which X"
        # or "shared X" question, try again with the description hint
        # enabled. Cheap because cypher-gen runs in parallel.
        expects_many = is_which_list(query) or re.search(
            r"\bboth\b|\bshared\b|\bsame\b|\bcommon\b|\bin\s+common\b",
            query, re.IGNORECASE,
        )
        if expects_many and (parsed is None or len(parsed) < 3):
            more = await self._generate_k_candidates(query, with_desc_hint=True)
            # NOTE(review): set-union dedupe discards candidate order (unlike
            # the order-preserving dedupe in _generate_k_candidates); harmless
            # for row-count selection but makes tie-breaks nondeterministic.
            combined = list({*(candidates or []), *(more or [])})
            cypher2, table_md2, parsed2, truncated2 = await self._execute_and_pick(combined)
            if parsed2 and len(parsed2) > len(parsed or []):
                cypher, table_md, parsed, truncated = (
                    cypher2, table_md2, parsed2, truncated2,
                )

        rows = len(parsed) if parsed else 0
        # NOTE(review): logs len(candidates) even when pass 2's combined
        # candidate list was used — count may understate what actually ran.
        ctx.log(f"CypherFirst cypher_table: {rows} rows from "
                f"{len(candidates)} candidates")

        if cypher and parsed:
            directive = ""
            if is_which_list(query):
                directive = (
                    "\nNOTE: This is a 'which / list' question. Enumerate "
                    "EVERY DISTINCT VALUE from the first column in your "
                    "answer — do not summarize, truncate, or pick a "
                    "subset unless the question explicitly asked for the "
                    "top/most/fewest one."
                )
            body = table_md + directive
            return RawSearchResult(
                records=[{
                    "section": "cypher_results",
                    "content": _wrap_authoritative(body),
                }],
                metadata={
                    "strategy": "cypher_first",
                    "cypher_first_path": PATH_CYPHER_TABLE,
                    "cypher": cypher,
                    "cypher_rows": rows,
                    "cypher_truncated": truncated,
                },
            )

        # Cypher returned 0 / no candidate succeeded.
        if is_yes_no(query) and _is_negation_existential(query):
            ctx.log("CypherFirst empty-result branch: negation-existential = No")
            return RawSearchResult(
                records=[{
                    "section": "cypher_results",
                    "content": _wrap_authoritative(
                        "No matching items: the cypher query returned 0 "
                        "rows. For a negation-existential question of "
                        "this shape, that means no such items exist.",
                        source_note="Cypher returned 0 rows (definitive)",
                    ),
                }],
                metadata={
                    "strategy": "cypher_first",
                    "cypher_first_path": PATH_NEGATION_EMPTY_NO,
                    "cypher": cypher,
                },
            )

        # Vector fallback for everything else.
        ctx.log("CypherFirst cypher empty — falling back to RAG")
        return _tag_path(
            await self._s._fallback._execute(query, ctx),
            PATH_RAG_FALLBACK_CYPHER_EMPTY,
        )

    async def _generate_k_candidates(
        self,
        query: str,
        *,
        with_desc_hint: bool = False,
    ) -> list[str]:
        # K independent LLM samples against the same prompt; invalid or
        # failed generations collapse to None and are filtered below.
        prompt = SCHEMA_PROMPT.format(question=query) + _AGG_SCHEMA_SUFFIX
        if with_desc_hint:
            prompt += _DESC_HINT_SUFFIX

        async def _one() -> str | None:
            try:
                resp = await self._s._llm.ainvoke(prompt)
                cypher = extract_cypher(resp.content)
                if not cypher:
                    return None
                errors = validate_cypher(cypher)
                if errors:
                    return None
                return _sanitize_cypher(cypher)
            except Exception as exc:
                logger.debug("Candidate generation failed: %s", exc)
                return None

        results = await asyncio.gather(*[_one() for _ in range(self._s._k)])
        # Dedupe while preserving order.
        seen: set[str] = set()
        out: list[str] = []
        for c in results:
            if c and c not in seen:
                seen.add(c)
                out.append(c)
        return out

    async def _execute_and_pick(
        self,
        candidates: list[str],
    ) -> tuple[str | None, str, list[dict[str, Any]], bool]:
        """Run all candidates in parallel; pick the one with most rows.

        Returns ``(cypher, table_md, parsed_rows, truncated)``.
        Candidates that raise are silently dropped (``return_exceptions``);
        ties on row count break on column count, then on sort order.
        """
        if not candidates:
            return None, "(no candidate cypher)", [], False
        results = await asyncio.gather(
            *[self._s._graph.query_raw(c) for c in candidates],
            return_exceptions=True,
        )
        scored: list[tuple[int, int, str, Any]] = []
        for cypher, res in zip(candidates, results):
            if isinstance(res, BaseException):
                continue
            rows = len(res.result_set) if res.result_set else 0
            cols = len(res.result_set[0]) if (res.result_set and res.result_set[0]) else 0
            scored.append((rows, cols, cypher, res))
        if not scored:
            return None, "(no candidate executed successfully)", [], False
        # key= avoids comparing the (str, Any) tail when (rows, cols) tie.
        scored.sort(key=lambda t: (t[0], t[1]), reverse=True)
        _, _, best_cypher, best_result = scored[0]
        table_md, parsed, truncated = format_result_as_markdown_table(best_result)
        return best_cypher, table_md, parsed, truncated
+ + Assumptions and known limits + ---------------------------- + The shared-property hybrid (M5) was tuned on graphs produced by the + SDK's default ``GraphExtraction`` pipeline. It makes the following + assumptions; when they're violated, the hybrid silently returns + ``None`` and the strategy falls back to the multi-candidate Cypher + path (which still works, just without free-text recovery): + + - **Graph topology.** Organizations are connected to Persons via + ``[:RELATES]``, and Persons are connected to ``Chunk`` nodes via + ``[:MENTIONED_IN]``. This is the canonical shape the SDK builds. + Custom extractors that use different edge types or skip chunk + provenance will not benefit from M5 — the strategy logs a warning + and continues. + - **Prose shape.** Role and project values are extracted by regex + from Person descriptions and chunk text. The regexes target the + phrasing patterns ``"works at X as a "`` and ``"contributes + to "``. Domains whose prose departs from these patterns + (medical / legal / e-commerce / non-English) will not match — the + result is empty role/project sets, not wrong answers, but the + hybrid won't help. + - **Role vocabulary.** The role extractor anchors on the suffixes + ``engineer | scientist | manager | architect | researcher | + developer | analyst | specialist | designer | consultant``. Other + job titles ("director", "VP", "lead") won't match. + + Accuracy ceiling + ---------------- + The strategy faithfully returns what is in the graph. Duplicate + entities ("Wayne En" vs "Wayne Enterprises"), chunk-boundary-truncated + names, or non-deduplicated short-form references ("Carla" vs "Carla + Okafor") all flow through into the answer. Cypher counts will be + inflated; "which X" lists will contain duplicates. These are + extraction-quality issues — not strategy bugs — and should be + addressed in the ingestion pipeline (resolver, coref, dedup). + + Args: + graph_store: Required for Cypher execution. 
+ vector_store: Required for the RAG fallback. + embedder: Required for the RAG fallback. + llm: Required for Cypher generation + synthesis. + k_candidates: Number of parallel Cypher samples per aggregation + question. Default 3 — enough to surface alternate structural + interpretations without burning latency. + rag_fallback: Strategy used for non-aggregation intent. If + ``None``, a fresh ``MultiPathRetrieval`` is constructed + internally. + + Example:: + + strategy = CypherFirstAggregationStrategy( + graph_store=rag._graph_store, + vector_store=rag._vector_store, + embedder=embedder, + llm=llm, + ) + async with GraphRAG( + connection=conn, llm=llm, embedder=embedder, + embedding_dimension=256, + retrieval_strategy=strategy, + ) as rag: + ... + """ + + def __init__( + self, + graph_store: Any, + vector_store: Any, + embedder: Embedder, + llm: LLMInterface, + *, + k_candidates: int = 3, + rag_fallback: RetrievalStrategy | None = None, + phrase_extractor: PhraseExtractor | None = None, + ) -> None: + super().__init__(graph_store=graph_store, vector_store=vector_store) + self._embedder = embedder + self._llm = llm + self._k = max(1, k_candidates) + self._fallback = rag_fallback or MultiPathRetrieval( + graph_store=graph_store, + vector_store=vector_store, + embedder=embedder, + llm=llm, + ) + # Pluggable phrase extractor for the shared-property hybrid path. + # Override with a domain-specific subclass to recognize roles / + # projects beyond the default English-prose vocabulary. + self._phrase_extractor = phrase_extractor or DefaultPhraseExtractor() + # Sub-paths — each handles one shape of question. The order in + # which they're consulted is encoded in ``_execute`` below; the + # paths themselves don't know about each other. 
+ self._rag_path = _RagDelegationPath(self) + self._numeric_path = _NumericMathPath(self) + self._hybrid_path = _SharedPropertyHybridPath(self) + self._cypher_table_path = _MultiCandidateCypherPath(self) + + # -- Template Method hook ------------------------------------- + + async def _execute( + self, + query: str, + ctx: Context, + **kwargs: Any, + ) -> RawSearchResult: + intent = detect_aggregation_intent(query) + ctx.log(f"CypherFirst intent={intent}") + + if intent == "rag": + # Non-aggregation questions don't benefit from any of the + # cypher-first mechanics — hand straight to the fallback. + return await self._rag_path.maybe_handle(query, ctx) + + if intent == "numeric_math": + return await self._numeric_path.maybe_handle(query, ctx) + + # intent == "aggregation": + # Try the description+chunk hybrid for "shared X" shapes first — + # when it fires it generally produces better answers than the + # multi-candidate cypher because chunks preserve original corpus + # phrasing that extraction may have summarized away. + hybrid_result = await self._hybrid_path.maybe_handle(query, ctx) + if hybrid_result is not None: + return hybrid_result + + # The multi-candidate cypher path always returns a result; it + # internally handles its own empty / negation / fallback branches. + return await self._cypher_table_path.maybe_handle(query, ctx) + + # NOTE: the per-path logic (numeric math, shared-property hybrid, + # multi-candidate cypher, etc.) lives in the ``_AggregationPath`` + # subclasses above this strategy. Keeping each path in its own class + # keeps the dispatch above readable and makes it trivial to swap one + # implementation out (e.g., a medical-prose phrase extractor) without + # touching the strategy itself. 
+ + # -- Custom _format ------------------------------------------ + + def _format(self, raw: RawSearchResult) -> RetrieverResult: + """Render section records as markdown content items, preserving the + cypher metadata so callers / metrics can see which mode fired.""" + items: list[RetrieverResultItem] = [] + for record in raw.records: + content = record.get("content", "") if isinstance(record, dict) else str(record) + section = record.get("section", "") if isinstance(record, dict) else "" + if content: + items.append( + RetrieverResultItem( + content=content, + metadata={"section": section}, + ) + ) + return RetrieverResult(items=items, metadata=raw.metadata) + + + +def _coerce_number(cell: Any) -> float | None: + """Extract a single numeric value from a cypher result cell. + + Accepts ints/floats directly; for strings, pulls the first integer or + float substring (catches "1995" inside "1995 is the year ..."). + """ + if cell is None: + return None + if isinstance(cell, (int, float)): + return float(cell) + m = re.search(r"-?\d+(?:\.\d+)?", str(cell)) + return float(m.group()) if m else None diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py index 0dd3d28..5d4232b 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py @@ -41,6 +41,30 @@ re.IGNORECASE, ) +# Default row cap auto-injected when the LLM's query lacks a LIMIT. +# Pure aggregations (count/sum/avg over no group-by) return one row, so +# they skip injection entirely; group-by lists need enough rows that a 10-org +# breakdown isn't truncated. 
+_DEFAULT_ROW_LIMIT = 100 + +_AGG_FN_NAMES = ("count", "sum", "avg", "min", "max", "collect", + "stdev", "percentileCont", "percentileDisc") +_AGG_FN_RE = re.compile( + r"^\s*(?:" + "|".join(_AGG_FN_NAMES) + r")\s*\(", + re.IGNORECASE, +) + +# Detects FUNCTION-style calls under a dotted namespace, e.g. +# ``apoc.text.regexGroups(...)``, ``gds.shortest.path(...)``, ``db.idx.fulltext.queryNodes(...)``. +# FalkorDB does not implement APOC/GDS/db plugins, so any dotted-namespace +# function call silently returns 0 rows at execution. We reject these in the +# validator so the existing retry-with-feedback loop can correct the query. +# Note: ``\bCALL\b`` already catches procedure-style invocations; this regex +# specifically targets the function-style pattern that slips through. +_DOTTED_FN_RE = re.compile( + r"\b([a-zA-Z_]\w*)\.([a-zA-Z_][\w.]*)\s*\(", +) + # ── Schema prompt ──────────────────────────────────────────────── SCHEMA_PROMPT = """\ @@ -81,6 +105,8 @@ - To find connections: `MATCH (a)-[:RELATES]-(b)` with entity name filters - To count: `RETURN count(DISTINCT e)` or `RETURN count(r)` - To list all of a type: `MATCH (e:Technology) RETURN e.name, e.description LIMIT 25` +- For "BOTH X AND Y" questions, use two MATCH clauses sharing the same + variable to express set intersection — never UNION. ## Examples @@ -122,11 +148,23 @@ LIMIT 20 ``` -Question: "What organizations are related to the technology?" +Question: "Which city has the most employees mentioned?" +Note: a Person and a Location are typically NOT directly connected — they +share an organization. 
Use a 2-hop traversal through the intermediary so the +group-by works on the real graph topology: ```cypher -MATCH (o:Organization)-[r:RELATES]-(t:Technology) -RETURN o.name AS organization, t.name AS technology, r.rel_type AS relation, r.fact AS evidence -LIMIT 20 +MATCH (p:Person)-[:RELATES]-(o:Organization)-[:RELATES]-(l:Location) +RETURN l.name AS city, count(DISTINCT p) AS employee_count +ORDER BY employee_count DESC +LIMIT 5 +``` + +Question: "Who works at BOTH Acme and Globex?" +```cypher +MATCH (p:Person)-[:RELATES]-(o1:Organization), + (p)-[:RELATES]-(o2:Organization) +WHERE o1.name CONTAINS 'Acme' AND o2.name CONTAINS 'Globex' +RETURN DISTINCT p.name AS person ``` ## Your task @@ -159,6 +197,53 @@ def extract_cypher(text: str) -> str: # ── Cypher sanitization ───────────────────────────────────────── +def _split_top_level_commas(s: str) -> list[str]: + """Split on commas that aren't inside parentheses/brackets/braces.""" + out: list[str] = [] + buf: list[str] = [] + depth = 0 + for ch in s: + if ch in "([{": + depth += 1 + buf.append(ch) + elif ch in ")]}": + depth = max(0, depth - 1) + buf.append(ch) + elif ch == "," and depth == 0: + out.append("".join(buf).strip()) + buf = [] + else: + buf.append(ch) + if buf: + out.append("".join(buf).strip()) + return [p for p in out if p] + + +def _is_pure_aggregation(cypher: str) -> bool: + """True iff the FINAL RETURN clause projects only aggregate functions. + + Pure aggregations (e.g., ``RETURN count(p)``) always return exactly one + row, so auto-injecting LIMIT is a no-op. Group-by patterns + (``RETURN o.name, count(p)``) are NOT pure — at least one projection is + a non-aggregate dimension that the LIMIT would actually apply to. + """ + # Find the final RETURN body, stopping at ORDER BY / SKIP / LIMIT / end. 
+ matches = list(re.finditer( + r"\bRETURN\b\s+(.+?)(?=\bORDER\s+BY\b|\bSKIP\b|\bLIMIT\b|;|$)", + cypher, + re.IGNORECASE | re.DOTALL, + )) + if not matches: + return False + body = matches[-1].group(1).strip() + if not body: + return False + projections = _split_top_level_commas(body) + if not projections: + return False + return all(_AGG_FN_RE.match(p) for p in projections) + + def _sanitize_cypher(cypher: str) -> str: """Fix common LLM-generated Cypher issues before execution. @@ -174,9 +259,13 @@ def _sanitize_cypher(cypher: str) -> str: # Remove path variable assignments: "path = MATCH" -> "MATCH" cypher = re.sub(r"\bpath\s*=\s*", "", cypher, flags=re.IGNORECASE) - # Add LIMIT if missing (prevent runaway scans) + # Inject LIMIT only when the LLM didn't provide one AND the query isn't a + # single-row pure aggregation. Pure aggregations don't benefit from a cap; + # group-by lists do, but the previous default (25) was too small to fit a + # full 10-org breakdown. Skip on aggregations to avoid a misleading no-op. if not re.search(r"\bLIMIT\b", cypher, re.IGNORECASE): - cypher = cypher.rstrip().rstrip(";") + "\nLIMIT 25" + if not _is_pure_aggregation(cypher): + cypher = cypher.rstrip().rstrip(";") + f"\nLIMIT {_DEFAULT_ROW_LIMIT}" return cypher @@ -220,6 +309,18 @@ def validate_cypher(cypher: str) -> list[str]: if re.search(r"\bLOAD\s+CSV\b", cypher_norm, re.IGNORECASE): errors.append("LOAD CSV is not allowed in generated queries") + # Reject dotted-namespace function calls (apoc.*, gds.*, db.*). + # FalkorDB doesn't implement these plugins; the call silently returns 0 + # rows at execution. Surfacing it here lets the retry loop regenerate. + for ns, _ in _DOTTED_FN_RE.findall(cypher_norm): + errors.append( + f"Unsupported function namespace '{ns}.*' " + "(FalkorDB does not implement APOC/GDS/db plugin functions). " + "Use only built-in Cypher functions like count, sum, avg, " + "labels, toInteger, substring, etc." 
+ ) + break # one error is enough — the LLM only needs to fix the pattern + # No write operations if _WRITE_KEYWORDS.search(cypher_norm): errors.append("Write operation detected — query must be read-only") @@ -287,43 +388,52 @@ async def generate_cypher( return None -async def execute_cypher_retrieval( - graph_store: Any, - llm: Any, - question: str, - *, - max_retries: int = 3, -) -> tuple[list[str], dict[str, dict]]: - """Full text-to-cypher retrieval: generate -> validate -> execute -> parse. +# Matches a typed entity label inside a node pattern so we can widen it to +# ``__Entity__`` on a 0-row retry. Captures the prefix (open paren + optional +# variable + colon) so ``re.sub`` keeps the surrounding shape intact. +_TYPED_NODE_LABEL_RE = re.compile( + r"(\(\s*\w*\s*:)(" + "|".join(re.escape(l) for l in _ENTITY_LABELS) + r")\b" +) - Results are intended to go DIRECTLY to the final LLM context - (as a dedicated "Cypher Query Results" section), NOT through - the cosine reranker. - Returns: - fact_strings: Formatted rows from Cypher execution. - entities: Dict of entity_id -> {name, description}. +def _widen_typed_labels(cypher: str) -> str: + """Swap typed entity labels (``:Person`` etc.) inside node patterns to + ``:__Entity__``. Used when a typed-label query returned 0 rows because + the extractor labelled the entity differently than the LLM expected.""" + return _TYPED_NODE_LABEL_RE.sub(r"\1__Entity__", cypher) - On any failure, returns empty results (silent degradation). - """ - cypher = await generate_cypher(llm, question, max_retries=max_retries) - if not cypher: - return [], {} - try: - result = await graph_store.query_raw(cypher) - except Exception as exc: - logger.debug("Cypher execution failed: %s — query: %s", exc, cypher) - return [], {} +def _should_widen_labels(cypher: str) -> bool: + """Gate for the 0-row label-widen fallback. - if not result.result_set: - return [], {} + Skip widening when the typed label IS the filter — i.e. 
the RETURN + aggregates over a labeled variable AND the query has no ``WHERE … CONTAINS`` + name predicate. In that case the user is asking "how many Persons?" and + widening would change the semantics. Otherwise (typical case: typed label + present alongside a name predicate or non-aggregate RETURN), widen. + + Conservative: when in doubt we skip the fallback rather than risk a + semantically-different query. + """ + if not _TYPED_NODE_LABEL_RE.search(cypher): + return False # nothing to widen + has_contains_filter = bool( + re.search(r"\bWHERE\b.*\bCONTAINS\b", cypher, re.IGNORECASE | re.DOTALL) + ) + if has_contains_filter: + return True + # No name filter — check whether the RETURN is an aggregate over a labeled + # variable. If so, widening turns "count Persons" into "count Entities". + if _is_pure_aggregation(cypher): + return False + return True - # Parse results into readable fact lines and entity dict + +def _parse_cypher_result_set(result_set: Any) -> tuple[list[str], dict[str, dict]]: + """Turn a FalkorDB result_set into (fact_strings, entities).""" fact_strings: list[str] = [] entities: dict[str, dict] = {} - - for row in result.result_set: + for row in result_set: parts: list[str] = [] for val in row: if val is None: @@ -333,11 +443,7 @@ async def execute_cypher_retrieval( parts.append(s) if not parts: continue - - line = " | ".join(parts) - fact_strings.append(line) - - # Extract entity names (strings that look like names, not numbers/lists) + fact_strings.append(" | ".join(parts)) for val in row: if ( isinstance(val, str) @@ -349,11 +455,75 @@ async def execute_cypher_retrieval( eid = val.strip().lower().replace(" ", "_") if eid and eid not in entities: entities[eid] = {"name": val.strip(), "description": ""} + return fact_strings, entities + + +async def execute_cypher_retrieval( + graph_store: Any, + llm: Any, + question: str, + *, + max_retries: int = 3, +) -> tuple[list[str], dict[str, dict], dict[str, Any]]: + """Full text-to-cypher retrieval: 
generate -> validate -> execute -> parse. + + Results are intended to go DIRECTLY to the final LLM context + (as a dedicated "Cypher Query Results" section), NOT through + the cosine reranker. + + Returns: + fact_strings: Formatted rows from Cypher execution. + entities: Dict of entity_id -> {name, description}. + metadata: Dict capturing diagnostic signal — ``cypher`` (final query + executed), ``cypher_rows`` (row count), ``cypher_fallback`` + (``"label_widened"`` if the 0-row fallback fired, else ``None``). + + On any failure, returns ``([], {}, metadata)`` — never raises. + """ + metadata: dict[str, Any] = { + "cypher": None, + "cypher_rows": 0, + "cypher_fallback": None, + } + + cypher = await generate_cypher(llm, question, max_retries=max_retries) + if not cypher: + return [], {}, metadata + metadata["cypher"] = cypher + + try: + result = await graph_store.query_raw(cypher) + except Exception as exc: + logger.debug("Cypher execution failed: %s — query: %s", exc, cypher) + return [], {}, metadata + + # 0-row recovery: try once with typed labels widened to __Entity__. + # The LLM's structural reasoning (joins, filters, aggregations) is preserved + # — only the label predicate is relaxed. No second LLM call. 
+ if not result.result_set and _should_widen_labels(cypher): + widened = _widen_typed_labels(cypher) + if widened != cypher: + logger.debug("Cypher 0-row retry with widened labels: %s", widened[:120]) + try: + widened_result = await graph_store.query_raw(widened) + except Exception as exc: + logger.debug("Widened cypher execution failed: %s", exc) + else: + if widened_result.result_set: + result = widened_result + metadata["cypher"] = widened + metadata["cypher_fallback"] = "label_widened" + + if not result.result_set: + return [], {}, metadata + + fact_strings, entities = _parse_cypher_result_set(result.result_set) + metadata["cypher_rows"] = len(fact_strings) logger.debug( "Cypher retrieval: %d facts, %d entities from: %s", len(fact_strings), len(entities), - cypher[:120], + (metadata["cypher"] or "")[:120], ) - return fact_strings, entities + return fact_strings, entities, metadata diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py index d7bcbd9..cc6e0b7 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py @@ -194,6 +194,7 @@ async def _execute( query_vector = await self._embedder.aembed_query(query) # 3. 
RELATES vector search + Text-to-Cypher (parallel when enabled) + cypher_metadata: dict[str, Any] = {} if self._enable_cypher: results = await asyncio.gather( search_relates_edges(self._vector, query_vector, self._rel_top_k), @@ -205,11 +206,11 @@ async def _execute( fact_strings_scored, rel_entities = [], {} else: fact_strings_scored, rel_entities = results[0] - # Unpack Cypher results + # Unpack Cypher results (3-tuple: facts, entities, metadata) cypher_facts: list[str] = [] cypher_entities: dict[str, dict] = {} if not isinstance(results[1], BaseException): - cypher_facts, cypher_entities = results[1] + cypher_facts, cypher_entities, cypher_metadata = results[1] else: fact_strings_scored, rel_entities = await search_relates_edges( self._vector, query_vector, self._rel_top_k @@ -303,6 +304,7 @@ async def _execute( source_passages, q_type_hint, cypher_results=cypher_facts if cypher_facts else None, + cypher_metadata=cypher_metadata or None, ) def _format(self, raw: RawSearchResult) -> RetrieverResult: diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py index 6b2ad9a..d6742a2 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py @@ -13,6 +13,12 @@ logger = logging.getLogger(__name__) +# Maximum cypher result rows passed to the final LLM context. Was 20 — too low +# for group-by lists across the whole graph (e.g. a 10-org breakdown was being +# silently truncated). Pure aggregations return one row anyway, so a higher cap +# is essentially free. 
+_CYPHER_RESULT_CAP = 100 + def cosine_sim(a: list[float], b: list[float]) -> float: """Cosine similarity between two float vectors.""" @@ -145,11 +151,16 @@ def assemble_raw_result( source_passages: list[str], q_type_hint: str = "", cypher_results: list[str] | None = None, + cypher_metadata: dict[str, Any] | None = None, ) -> RawSearchResult: """Build structured RawSearchResult with section records. ``cypher_results`` are placed in their own section and are NOT subject to cosine reranking — they go directly to the final LLM. + + ``cypher_metadata`` (if non-empty) is merged into the result metadata + so callers and metrics can see whether the 0-row fallback fired, + whether truncation occurred, and the final cypher that ran. """ records: list[dict[str, Any]] = [] @@ -162,13 +173,28 @@ def assemble_raw_result( } ) - # Cypher Query Results (direct to LLM — not reranked) + # Authoritative Graph Query Results (direct to LLM — not reranked). + # Heading is deliberately worded to signal authority over prose passages on + # quantitative questions; pair with system-prompt rule 8 in api/main.py. + truncated = False if cypher_results: + shown = cypher_results[:_CYPHER_RESULT_CAP] + body_lines = [f"- {r}" for r in shown] + if len(cypher_results) > _CYPHER_RESULT_CAP: + truncated = True + body_lines.append( + f"- … (showing {len(shown)} of {len(cypher_results)} rows; " + "result was truncated)" + ) records.append( { "section": "cypher_results", - "content": "## Graph Query Results\n" - + "\n".join(f"- {r}" for r in cypher_results[:20]), + "content": ( + "## Authoritative Graph Query Results " + "(deterministic; trust over passages on counts and aggregates)\n" + "Source: text-to-Cypher run against the knowledge graph.\n" + + "\n".join(body_lines) + ), } ) @@ -218,7 +244,18 @@ def assemble_raw_result( } ) + metadata: dict[str, Any] = {"strategy": "multi_path"} + if cypher_metadata: + # Surface fallback firing rate, truncation, and the final cypher to + # callers / metrics. 
Don't overwrite if multiple keys conflict — the + # caller-supplied values win since they reflect the actual execution. + metadata.update( + {f"cypher_{k}" if not k.startswith("cypher") else k: v + for k, v in cypher_metadata.items()} + ) + metadata["cypher_truncated"] = truncated + return RawSearchResult( records=records, - metadata={"strategy": "multi_path"}, + metadata=metadata, ) diff --git a/graphrag_sdk/tests/test_cypher_first.py b/graphrag_sdk/tests/test_cypher_first.py new file mode 100644 index 0000000..412529f --- /dev/null +++ b/graphrag_sdk/tests/test_cypher_first.py @@ -0,0 +1,669 @@ +"""Tests for cypher_first strategy helpers. + +The strategy class itself needs a graph + LLM, so we exercise the +pure-Python pieces (intent classifier, regex extractors, fuzzy intersect, +table formatter, numeric coercion) without external dependencies. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from graphrag_sdk.core.models import RawSearchResult +from graphrag_sdk.retrieval.strategies.cypher_first import ( + PATH_CYPHER_TABLE, + PATH_NEGATION_EMPTY_NO, + PATH_NUMERIC_MATH, + PATH_RAG_FALLBACK, + PATH_RAG_FALLBACK_CYPHER_EMPTY, + PATH_RAG_FALLBACK_NUMERIC_FAIL, + PATH_SHARED_PROPERTY_HYBRID, + _coerce_number, + _detect_property_kind, + _extract_phrases, + _extract_projects, + _extract_roles, + _fuzzy_intersect, + _is_negation_existential, + _tag_path, + detect_aggregation_intent, + format_result_as_markdown_table, + is_which_list, + is_yes_no, +) + +# ── Intent classifier ──────────────────────────────────────────── + + +class TestDetectIntent: + def test_count_questions_are_aggregation(self): + for q in [ + "How many people work at Acme Corp?", + "How many distinct organizations are mentioned?", + "Count of employees by org?", + ]: + assert detect_aggregation_intent(q) == "aggregation", q + + def test_which_most_is_aggregation(self): + for q in [ + "Which city has the most employees mentioned?", + "Which 
organization has the fewest employees?", + "Which orgs share a project with Acme?", + ]: + assert detect_aggregation_intent(q) == "aggregation", q + + def test_more_than_multi_word(self): + # The intent regex must catch up to 4 words between `more` and `than`. + q = "Does Acme Corp have more employees mentioned than Initech Systems?" + assert detect_aggregation_intent(q) == "aggregation" + + def test_both_a_and_b_is_aggregation(self): + q = "Which job roles are held by employees at BOTH Acme Corp and Initech Systems?" + assert detect_aggregation_intent(q) == "aggregation" + + def test_existential_is_aggregation(self): + for q in [ + "Are there any organizations with no employees listed?", + "Is there any employee at Acme on observability tooling?", + ]: + assert detect_aggregation_intent(q) == "aggregation", q + + def test_average_year_is_numeric_math(self): + q = "What is the average year of founding across all 10 organizations?" + assert detect_aggregation_intent(q) == "numeric_math" + + def test_total_revenue_is_numeric_math(self): + assert detect_aggregation_intent("What is the total revenue?") == "numeric_math" + + def test_factoid_is_rag(self): + for q in [ + "Who is the lighthouse keeper?", + "Tell me about Acme Corp.", + "What did the professor discover?", + ]: + assert detect_aggregation_intent(q) == "rag", q + + +class TestShapeDetectors: + def test_yes_no_starts(self): + assert is_yes_no("Is there any X?") + assert is_yes_no("Does Acme have more employees than Initech?") + assert is_yes_no("Are there any orgs without employees?") + + def test_yes_no_negative(self): + assert not is_yes_no("Which X has the most Y?") + assert not is_yes_no("How many people?") + + def test_which_list_starts(self): + assert is_which_list("Which cities have offices?") + assert is_which_list("List the orgs with 5 employees.") + assert is_which_list("Name the people at Acme.") + + def test_which_list_negative(self): + assert not is_which_list("How many people work here?") + 
assert not is_which_list("Is there any employee at Acme?") + + def test_negation_existential(self): + assert _is_negation_existential( + "Are there any organizations with NO employees listed?" + ) + assert _is_negation_existential( + "Is there any company without an office?" + ) + assert not _is_negation_existential( + "Is there any employee at Acme who works on observability?" + ) + assert not _is_negation_existential("How many people work here?") + + +# ── Property kind + extractors (M5 hybrid) ─────────────────────── + + +class TestPropertyKind: + def test_role_keywords(self): + for q in [ + "What roles are at Acme?", + "Which jobs are common?", + "List the titles at Initech.", + ]: + assert _detect_property_kind(q) == "role", q + + def test_project_keywords(self): + for q in [ + "Who works on observability?", + "Which projects are shared?", + "What initiatives is Acme contributing to?", + ]: + assert _detect_property_kind(q) == "project", q + + def test_other_returns_none(self): + assert _detect_property_kind("Where is Acme based?") is None + assert _detect_property_kind("How many people work there?") is None + + +class TestRoleExtractor: + def test_basic_role_match(self): + desc = "Anna Reyes is a senior engineer at Acme Corp." + roles = _extract_roles(desc) + assert "senior engineer" in roles + + def test_multiword_role(self): + desc = "Uma Patel is a site reliability engineer at Initech Systems." + assert "site reliability engineer" in _extract_roles(desc) + + def test_applied_scientist(self): + desc = "Cyrus Doss is an applied scientist at Massive Dynamic." + assert "applied scientist" in _extract_roles(desc) + + def test_no_role_no_match(self): + # No professional-suffix word → no extraction. + assert _extract_roles("Alice is based in Boston.") == set() + + def test_short_phrase_filtered(self): + # 'engineer' alone (no qualifier) is too short — rejected by length gate. + # But still captured as "engineer" if length > 3 — which it is (8). 
+ # The filter is just defensive; this test asserts current behavior. + result = _extract_roles("Bob works as an engineer.") + # accept either outcome; just ensure no crash + assert isinstance(result, set) + + +class TestProjectExtractor: + def test_basic_project_match(self): + desc = ( + "Anna Reyes is a senior engineer at Acme Corp, " + "contributing to the cross-region replication initiative " + "and active in the cloud infrastructure community." + ) + projects = _extract_projects(desc) + assert any("cross-region replication initiative" in p for p in projects) + + def test_contributes_to_variant(self): + desc = "Mira Jansen contributes to a migration to managed services." + projects = _extract_projects(desc) + assert any("migration to managed services" in p for p in projects) + + def test_no_contributes_no_match(self): + assert _extract_projects("Pavel works at Wayne.") == set() + + +class TestExtractPhrases: + def test_role_dispatch(self): + out = _extract_phrases("Anna is a senior engineer", "role") + assert "senior engineer" in out + + def test_project_dispatch(self): + out = _extract_phrases("contributes to the pipeline rewrite", "project") + assert any("pipeline rewrite" in p for p in out) + + def test_unknown_kind_empty(self): + assert _extract_phrases("anything", "city") == set() + + +# ── Fuzzy intersect ────────────────────────────────────────────── + + +class TestFuzzyIntersect: + def test_exact_match(self): + a = {"senior engineer", "data scientist"} + b = {"senior engineer", "product manager"} + assert _fuzzy_intersect(a, b) == {"senior engineer"} + + def test_substring_in_one_direction(self): + # "pipeline rewrite" (substring) should match "next-generation + # pipeline rewrite" (superset). 
+        a = {"pipeline rewrite"}
+        b = {"next-generation pipeline rewrite"}
+        assert _fuzzy_intersect(a, b) == {"pipeline rewrite"}
+
+    def test_substring_other_direction(self):
+        a = {"next-generation pipeline rewrite"}
+        b = {"pipeline rewrite"}
+        assert _fuzzy_intersect(a, b) == {"next-generation pipeline rewrite"}
+
+    def test_two_token_overlap_matches(self):
+        a = {"automated incident response tooling"}
+        b = {"incident response system tooling"}
+        # 2+ shared content tokens: "incident", "response", "tooling"
+        assert _fuzzy_intersect(a, b) == {"automated incident response tooling"}
+
+    def test_single_stopword_no_match(self):
+        # Sharing only "the" / "a" should not be enough.
+        a = {"the migration"}
+        b = {"a migration"}
+        # Both reduce to {"migration"} (1 token) after stopword filter — no match.
+        assert _fuzzy_intersect(a, b) == set()
+
+    def test_no_overlap(self):
+        assert _fuzzy_intersect({"alpha beta"}, {"gamma delta"}) == set()
+
+
+# ── Markdown table formatter ─────────────────────────────────────
+
+
+class TestFormatTable:
+    def _result(self, header, rows):
+        return SimpleNamespace(
+            header=[[1, name] for name in header],
+            result_set=rows,
+        )
+
+    def test_renders_with_headers(self):
+        r = self._result(["city", "n"], [["Boston", 10], ["Chicago", 8]])
+        md, parsed, truncated = format_result_as_markdown_table(r)
+        assert "city | n" in md
+        assert "Boston | 10" in md
+        assert "Chicago | 8" in md
+        assert not truncated
+        assert parsed == [{"city": "Boston", "n": 10}, {"city": "Chicago", "n": 8}]
+
+    def test_synthesizes_headers_when_missing(self):
+        r = SimpleNamespace(header=None, result_set=[["x", "y"]])
+        md, parsed, _ = format_result_as_markdown_table(r)
+        assert "col_0 | col_1" in md
+        assert parsed[0] == {"col_0": "x", "col_1": "y"}
+
+    def test_empty_result(self):
+        r = SimpleNamespace(header=[], result_set=[])
+        md, parsed, truncated = format_result_as_markdown_table(r)
+        assert md == "(empty result)"
+        assert parsed == []
+        assert not truncated
+
+    def test_truncation_sentinel(self):
+        rows = [["x", i] for i in range(150)]
+        r = self._result(["item", "n"], rows)
+        md, parsed, truncated = format_result_as_markdown_table(r, cap=50)
+        assert truncated
+        assert len(parsed) == 50
+        assert "showing 50 of 150 rows" in md
+
+    def test_null_and_list_cells(self):
+        r = self._result(["a", "b"], [[None, ["x", "y", "z"]]])
+        md, _, _ = format_result_as_markdown_table(r)
+        assert "(null)" in md
+        assert "x, y, z" in md
+
+
+# ── Numeric coercion (M6 helper) ─────────────────────────────────
+
+
+class TestCoerceNumber:
+    def test_int(self):
+        assert _coerce_number(1995) == 1995.0
+
+    def test_float(self):
+        assert _coerce_number(3.14) == pytest.approx(3.14)
+
+    def test_string_with_number(self):
+        # Date entity names are strings like "1995" — pull the integer.
+        assert _coerce_number("1995") == 1995.0
+
+    def test_string_with_embedded_number(self):
+        assert _coerce_number("1995 is the year Acme was founded") == 1995.0
+
+    def test_none(self):
+        assert _coerce_number(None) is None
+
+    def test_no_digits(self):
+        assert _coerce_number("hello") is None
+
+    def test_negative(self):
+        assert _coerce_number("-42") == -42.0
+
+
+# ── Path-tag contract (R2) ───────────────────────────────────────
+
+
+class TestPathTag:
+    """``_tag_path`` enforces the contract that every result emitted by
+    ``CypherFirstAggregationStrategy`` carries a ``cypher_first_path``
+    label, so operators can route metrics on which sub-path fired."""
+
+    def test_labels_match_known_paths(self):
+        # If we add a new path later, this list should grow in lockstep.
+        assert PATH_NUMERIC_MATH == "numeric_math"
+        assert PATH_SHARED_PROPERTY_HYBRID == "shared_property_hybrid"
+        assert PATH_CYPHER_TABLE == "cypher_table"
+        assert PATH_NEGATION_EMPTY_NO == "negation_empty_no"
+        assert PATH_RAG_FALLBACK == "rag_fallback"
+        assert PATH_RAG_FALLBACK_NUMERIC_FAIL == "rag_fallback_numeric_fail"
+        assert PATH_RAG_FALLBACK_CYPHER_EMPTY == "rag_fallback_cypher_empty"
+
+    def test_tag_path_adds_label_to_empty_metadata(self):
+        result = RawSearchResult(records=[{"section": "x", "content": "y"}],
+                                 metadata={})
+        out = _tag_path(result, PATH_CYPHER_TABLE)
+        assert out.metadata["cypher_first_path"] == PATH_CYPHER_TABLE
+        assert out.metadata["strategy"] == "cypher_first"
+
+    def test_tag_path_overwrites_existing_path_label(self):
+        # When a sub-path delegates to another that already tagged the
+        # result, the outer tag should win — it reflects the actual path
+        # taken from the caller's perspective.
+        result = RawSearchResult(
+            records=[],
+            metadata={"cypher_first_path": "earlier", "extra": 1},
+        )
+        out = _tag_path(result, PATH_RAG_FALLBACK)
+        assert out.metadata["cypher_first_path"] == PATH_RAG_FALLBACK
+        # Other metadata is preserved.
+        assert out.metadata["extra"] == 1
+
+    def test_tag_path_preserves_existing_strategy_label_if_set(self):
+        # If a delegated strategy already tagged itself (e.g. "multi_path"),
+        # don't clobber it — setdefault.
+        result = RawSearchResult(
+            records=[],
+            metadata={"strategy": "multi_path"},
+        )
+        out = _tag_path(result, PATH_RAG_FALLBACK)
+        assert out.metadata["strategy"] == "multi_path"
+        assert out.metadata["cypher_first_path"] == PATH_RAG_FALLBACK
+
+    def test_tag_path_returns_new_object_not_mutating_input(self):
+        original_meta = {"strategy": "multi_path"}
+        result = RawSearchResult(records=[], metadata=original_meta)
+        out = _tag_path(result, PATH_RAG_FALLBACK)
+        # Don't surprise callers by mutating their dict in place.
+        assert "cypher_first_path" not in original_meta
+        assert out.metadata is not original_meta
+
+
+# ── End-to-end routing with mocks (R4) ───────────────────────────
+
+
+class _FakeResult:
+    """Mimics FalkorDB's query_raw return: ``result_set`` rows + a ``header``
+    list of ``[type_int, name]`` pairs (we only use the name)."""
+    def __init__(self, header, result_set):
+        self.header = [[1, name] for name in header]
+        self.result_set = result_set
+
+
+class _FakeFallback:
+    """Stand-in for MultiPathRetrieval — records the call and returns a
+    well-formed RawSearchResult so we can verify delegation + tagging."""
+    def __init__(self, records=None, metadata=None):
+        self._records = records or [{"section": "passages", "content": "## ..."}]
+        self._metadata = metadata or {"strategy": "multi_path"}
+        self.calls = []
+
+    async def _execute(self, query, ctx, **kwargs):
+        self.calls.append(query)
+        return RawSearchResult(records=list(self._records),
+                               metadata=dict(self._metadata))
+
+
+def _make_strategy(*, llm_responses=None, graph_results=None, fallback=None):
+    """Build a CypherFirstAggregationStrategy with stubbed LLM + graph.
+
+    ``llm_responses`` is a sequence of strings; each LLM call pops one.
+    ``graph_results`` is a sequence of _FakeResult objects (or exceptions
+    to raise); each ``query_raw`` call pops one.
+    """
+    from unittest.mock import AsyncMock, MagicMock
+
+    from graphrag_sdk.core.models import LLMResponse
+    from graphrag_sdk.retrieval.strategies.cypher_first import (
+        CypherFirstAggregationStrategy,
+    )
+
+    llm = MagicMock()
+    responses = list(llm_responses or [])
+    async def _ainvoke(_prompt, **_kw):
+        return LLMResponse(content=responses.pop(0) if responses else "")
+    llm.ainvoke = AsyncMock(side_effect=_ainvoke)
+
+    graph = MagicMock()
+    results = list(graph_results or [])
+    async def _query_raw(_cypher):
+        nxt = results.pop(0) if results else _FakeResult([], [])
+        if isinstance(nxt, BaseException):
+            raise nxt
+        return nxt
+    graph.query_raw = AsyncMock(side_effect=_query_raw)
+
+    return CypherFirstAggregationStrategy(
+        graph_store=graph,
+        vector_store=MagicMock(),
+        embedder=MagicMock(),
+        llm=llm,
+        k_candidates=1,  # one candidate is enough for routing tests
+        rag_fallback=fallback or _FakeFallback(),
+    )
+
+
+class TestStrategyRouting:
+    """Behavioural tests: assert which sub-path fires for each intent +
+    graph-state combination by inspecting ``metadata.cypher_first_path``."""
+
+    async def test_rag_intent_delegates_to_fallback(self):
+        # "Who is the lighthouse keeper?" doesn't match any aggregation
+        # pattern — strategy must hand the query to the fallback verbatim
+        # and tag the result with PATH_RAG_FALLBACK.
+        from graphrag_sdk.core.context import Context
+        fallback = _FakeFallback()
+        strat = _make_strategy(fallback=fallback)
+        ctx = Context()
+        result = await strat._execute("Who is the lighthouse keeper?", ctx)
+        assert result.metadata["cypher_first_path"] == PATH_RAG_FALLBACK
+        assert fallback.calls == ["Who is the lighthouse keeper?"]
+
+    async def test_aggregation_with_cypher_rows_takes_cypher_table(self):
+        # Multi-candidate cypher returns a non-empty table → strategy
+        # picks the table path and emits a single "cypher_results" item.
+        from graphrag_sdk.core.context import Context
+        cypher_code = (
+            "```cypher\n"
+            "MATCH (p:Person)-[:RELATES]-(o:Organization)\n"
+            "RETURN o.name AS org, count(DISTINCT p) AS n\n"
+            "ORDER BY n DESC LIMIT 5\n"
+            "```"
+        )
+        strat = _make_strategy(
+            llm_responses=[cypher_code],
+            graph_results=[
+                _FakeResult(["org", "n"], [["Acme", 10], ["Globex", 8]])
+            ],
+        )
+        result = await strat._execute(
+            "Which org has the most employees mentioned?", Context(),
+        )
+        assert result.metadata["cypher_first_path"] == PATH_CYPHER_TABLE
+        assert result.metadata["cypher_rows"] == 2
+        assert "Acme | 10" in result.records[0]["content"]
+
+    async def test_numeric_intent_does_python_arithmetic(self):
+        # Question asks for an average; cypher returns raw year values;
+        # strategy computes 1989.9 in Python and tags PATH_NUMERIC_MATH.
+        from graphrag_sdk.core.context import Context
+        cypher_code = (
+            "```cypher\nMATCH (d:Date) RETURN d.name AS year\n```"
+        )
+        years = [[str(y)] for y in
+                 [1939, 1968, 1973, 1984, 1995, 1998, 2003, 2009, 2011, 2019]]
+        strat = _make_strategy(
+            llm_responses=[cypher_code],
+            graph_results=[_FakeResult(["year"], years)],
+        )
+        result = await strat._execute(
+            "What is the average year of founding across all 10 organizations?",
+            Context(),
+        )
+        assert result.metadata["cypher_first_path"] == PATH_NUMERIC_MATH
+        assert result.metadata["op"] == "average"
+        assert result.metadata["value"] == pytest.approx(1989.9, abs=0.1)
+
+    async def test_numeric_empty_extraction_falls_back_to_rag(self):
+        # Cypher generated for the numeric path returns 0 numeric values
+        # → strategy falls through to the RAG fallback and tags
+        # PATH_RAG_FALLBACK_NUMERIC_FAIL.
+        from graphrag_sdk.core.context import Context
+        fallback = _FakeFallback()
+        strat = _make_strategy(
+            llm_responses=["```cypher\nMATCH (n:Person) RETURN n.name\n```"],
+            graph_results=[_FakeResult(["name"], [["Alice"], ["Bob"]])],
+            fallback=fallback,
+        )
+        result = await strat._execute(
+            "What is the average year of founding?", Context(),
+        )
+        assert result.metadata["cypher_first_path"] == PATH_RAG_FALLBACK_NUMERIC_FAIL
+        assert fallback.calls  # fallback was invoked
+
+    async def test_negation_existential_empty_returns_no(self):
+        # "Are there any orgs WITHOUT employees?" + cypher returns 0 rows
+        # → strategy emits an authoritative "No" and tags
+        # PATH_NEGATION_EMPTY_NO. No fallback delegation.
+        from graphrag_sdk.core.context import Context
+        cypher_code = (
+            "```cypher\n"
+            "MATCH (o:Organization) WHERE NOT EXISTS { "
+            "MATCH (o)<-[:RELATES]-(p:Person) } RETURN o.name\n"
+            "```"
+        )
+        fallback = _FakeFallback()
+        strat = _make_strategy(
+            llm_responses=[cypher_code],
+            graph_results=[
+                _FakeResult([], []),  # hybrid batch query (no shape match)
+                _FakeResult(["name"], []),  # the actual cypher
+            ],
+            fallback=fallback,
+        )
+        result = await strat._execute(
+            "Are there any organizations for which NO employees are listed?",
+            Context(),
+        )
+        assert result.metadata["cypher_first_path"] == PATH_NEGATION_EMPTY_NO
+        # Negation path must NOT delegate to the fallback.
+        assert fallback.calls == []
+
+    async def test_positive_existential_empty_falls_back_to_rag(self):
+        # "Is there any employee at Acme on observability?" + cypher returns
+        # 0 rows (typed label mismatch) → strategy falls through to RAG and
+        # tags PATH_RAG_FALLBACK_CYPHER_EMPTY.
+        from graphrag_sdk.core.context import Context
+        cypher_code = (
+            "```cypher\n"
+            "MATCH (p:Person)-[:RELATES]-(t:Technology) RETURN p.name\n"
+            "```"
+        )
+        fallback = _FakeFallback()
+        strat = _make_strategy(
+            llm_responses=[cypher_code],
+            graph_results=[_FakeResult(["name"], [])],
+            fallback=fallback,
+        )
+        result = await strat._execute(
+            "Is there any employee at Acme working on observability?",
+            Context(),
+        )
+        assert result.metadata["cypher_first_path"] == PATH_RAG_FALLBACK_CYPHER_EMPTY
+        assert len(fallback.calls) == 1
+
+    async def test_hybrid_warns_when_topology_assumption_violated(self, caplog):
+        # The batched (Org)<-[:RELATES]-(Person) query returns zero tuples
+        # → strategy emits a warning and falls through (returns None from
+        # the hybrid; the rest of _execute keeps going).
+        import logging
+
+        from graphrag_sdk.core.context import Context
+        cypher_code = (
+            "```cypher\nMATCH (p:Person) RETURN p.name AS name\n```"
+        )
+        strat = _make_strategy(
+            llm_responses=[cypher_code],
+            graph_results=[
+                _FakeResult([], []),  # hybrid batch — zero topology tuples
+                _FakeResult(["name"], [["Alice"]]),  # cypher_table cypher
+            ],
+        )
+        with caplog.at_level(logging.WARNING,
+                             logger="graphrag_sdk.retrieval.strategies.cypher_first"):
+            await strat._execute(
+                "Which roles are held by employees at BOTH Acme and Globex?",
+                Context(),
+            )
+        # The topology-violation warning fired.
+        assert any("zero (Organization)<-[:RELATES]-(Person) tuples" in r.message
+                   for r in caplog.records)
+
+
+# ── Pluggable phrase extractor (R8) ──────────────────────────────
+
+
+class TestPhraseExtractor:
+    """Domain-specific extractors can replace the default role/project
+    regexes without forking the strategy."""
+
+    def test_default_extractor_matches_module_regexes(self):
+        from graphrag_sdk.retrieval.strategies.cypher_first import (
+            DefaultPhraseExtractor,
+        )
+        ext = DefaultPhraseExtractor()
+        assert "senior engineer" in ext.extract(
+            "Anna is a senior engineer at Acme.", "role"
+        )
+        assert any("pipeline rewrite" in p for p in ext.extract(
+            "contributes to the pipeline rewrite", "project"
+        ))
+        # Unknown kinds return an empty set, not an exception.
+        assert ext.extract("anything", "city") == set()
+
+    async def test_strategy_uses_custom_extractor_in_hybrid_path(self):
+        """A custom extractor passed to the strategy is consulted by the
+        shared-property hybrid instead of the default regexes."""
+        from graphrag_sdk.core.context import Context
+        from graphrag_sdk.retrieval.strategies.cypher_first import PhraseExtractor
+
+        class _UpperCaseRoleExtractor(PhraseExtractor):
+            """Match exactly the literal strings 'ALPHA' and 'BETA' as roles
+            — clearly distinguishable from the default regex output."""
+            def extract(self, text, kind):
+                if kind == "role":
+                    return {w for w in ("ALPHA", "BETA") if w in text}
+                return set()
+
+        # Both orgs have one person whose chunk text mentions the same
+        # custom token. With the default extractor, none of these phrases
+        # would match (no role suffix). With our custom one, ALPHA is
+        # common to both.
+        from unittest.mock import AsyncMock, MagicMock
+
+        graph = MagicMock()
+        batch_rows = [
+            ["Acme", "Alice", "no description", ["Alice does ALPHA at Acme."]],
+            ["Acme", "Anna", "no description", ["Anna does BETA at Acme."]],
+            ["Globex", "Bob", "no description", ["Bob does ALPHA at Globex."]],
+            ["Globex", "Bea", "no description", ["Bea does GAMMA at Globex."]],
+        ]
+        async def _query_raw(_cypher):
+            return SimpleNamespace(
+                header=[[1, "org"], [1, "person"], [1, "desc"], [1, "chunks"]],
+                result_set=batch_rows,
+            )
+        graph.query_raw = AsyncMock(side_effect=_query_raw)
+
+        from graphrag_sdk.retrieval.strategies.cypher_first import (
+            CypherFirstAggregationStrategy,
+        )
+        strat = CypherFirstAggregationStrategy(
+            graph_store=graph,
+            vector_store=MagicMock(),
+            embedder=MagicMock(),
+            llm=MagicMock(),
+            k_candidates=1,
+            rag_fallback=MagicMock(),
+            phrase_extractor=_UpperCaseRoleExtractor(),
+        )
+        result = await strat._execute(
+            "Which roles are held by employees at BOTH Acme and Globex?",
+            Context(),
+        )
+        # Hybrid fired (not a fallback) and computed the right intersection.
+        assert result.metadata["cypher_first_path"] == PATH_SHARED_PROPERTY_HYBRID
+        assert result.metadata["common"] == ["ALPHA"]
diff --git a/graphrag_sdk/tests/test_cypher_generation.py b/graphrag_sdk/tests/test_cypher_generation.py
index 5869415..ab7e0d6 100644
--- a/graphrag_sdk/tests/test_cypher_generation.py
+++ b/graphrag_sdk/tests/test_cypher_generation.py
@@ -7,6 +7,10 @@
     extract_cypher,
     validate_cypher,
     _sanitize_cypher,
+    _is_pure_aggregation,
+    _should_widen_labels,
+    _split_top_level_commas,
+    _widen_typed_labels,
 )
 
 
@@ -73,6 +77,27 @@ def test_rejects_call_procedures(self):
         errors = validate_cypher("CALL db.labels() YIELD label RETURN label")
         assert any("CALL" in e for e in errors)
 
+    def test_rejects_apoc_function_calls(self):
+        # Function-style apoc / gds / db calls slip past the bare \bCALL\b
+        # check; the dotted-namespace rule should reject them.
+        for snippet in [
+            "MATCH (n) RETURN apoc.text.regexGroups(n.description, '\\d+') AS m",
+            "MATCH (n) RETURN gds.shortest.path(n) AS p",
+            "MATCH (n) WHERE db.idx.fulltext.queryNodes('x', 'y') RETURN n",
+        ]:
+            errors = validate_cypher(snippet)
+            assert any("Unsupported function namespace" in e for e in errors), \
+                f"should reject: {snippet}"
+
+    def test_accepts_bare_builtin_functions(self):
+        # Built-ins are bare (count, toInteger, substring) — no namespace.
+        for snippet in [
+            "MATCH (n:Person) RETURN count(n) AS c",
+            "MATCH (n:Person) RETURN toInteger(n.name) AS i",
+            "MATCH (n:Person) RETURN substring(n.description, 0, 4) AS s",
+        ]:
+            assert validate_cypher(snippet) == [], f"should accept: {snippet}"
+
     def test_rejects_load_csv(self):
         errors = validate_cypher("LOAD CSV FROM 'file:///data.csv' AS row RETURN row")
         assert any("LOAD CSV" in e for e in errors)
@@ -108,14 +133,34 @@ def test_with_allowed(self):
 
 
 class TestSanitizeCypher:
     def test_adds_limit_when_missing(self):
-        result = _sanitize_cypher("MATCH (n) RETURN n")
+        # Non-aggregation query without LIMIT — should get the new default cap.
+        result = _sanitize_cypher("MATCH (n:Person) RETURN n.name")
         assert "LIMIT" in result
+        assert "100" in result  # _DEFAULT_ROW_LIMIT
 
     def test_keeps_existing_limit(self):
         cypher = "MATCH (n) RETURN n LIMIT 10"
         result = _sanitize_cypher(cypher)
         assert result.count("LIMIT") == 1
 
+    def test_does_not_inject_limit_on_pure_aggregation(self):
+        # count / sum / avg without group-by always returns one row — adding
+        # LIMIT is misleading.
+        for cypher in [
+            "MATCH (n:Person) RETURN count(n)",
+            "MATCH (n:Person) RETURN count(DISTINCT n) AS c",
+            "MATCH (n:Person) RETURN avg(toInteger(n.name)) AS a",
+        ]:
+            result = _sanitize_cypher(cypher)
+            assert "LIMIT" not in result, f"should skip LIMIT for: {cypher}"
+
+    def test_injects_limit_on_group_by(self):
+        # Group-by has a non-aggregate dimension in RETURN — LIMIT applies to
+        # the row count of the breakdown.
+        cypher = "MATCH (o:Organization)-[:RELATES]-(p:Person) RETURN o.name, count(p) AS n"
+        result = _sanitize_cypher(cypher)
+        assert "LIMIT" in result
+
     def test_removes_shortest_path(self):
         cypher = "MATCH path = shortestPath((a)-[*]-(b)) RETURN path"
         result = _sanitize_cypher(cypher)
@@ -132,6 +177,106 @@ def test_removes_path_assignment(self):
         assert "path =" not in result and "path=" not in result
 
 
+# ── _is_pure_aggregation ──────────────────────────────────────────
+
+
+class TestIsPureAggregation:
+    def test_count_only(self):
+        assert _is_pure_aggregation("MATCH (n) RETURN count(n)")
+
+    def test_count_distinct_with_alias(self):
+        assert _is_pure_aggregation(
+            "MATCH (n:Person) RETURN count(DISTINCT n) AS person_count"
+        )
+
+    def test_multi_aggregate(self):
+        assert _is_pure_aggregation(
+            "MATCH (n) RETURN count(n) AS c, avg(n.score) AS a"
+        )
+
+    def test_group_by_is_not_pure(self):
+        # A non-aggregate dimension in RETURN means LIMIT actually matters.
+        assert not _is_pure_aggregation(
+            "MATCH (o)-[:RELATES]-(p) RETURN o.name AS org, count(p) AS n"
+        )
+
+    def test_no_aggregate_is_not_pure(self):
+        assert not _is_pure_aggregation("MATCH (n:Person) RETURN n.name")
+
+    def test_aggregate_with_order_by_limit_clause(self):
+        # ORDER BY / LIMIT after RETURN must not confuse the regex.
+        assert _is_pure_aggregation(
+            "MATCH (n) RETURN count(n) ORDER BY count(n)"
+        )
+
+
+# ── _split_top_level_commas ───────────────────────────────────────
+
+
+class TestSplitTopLevelCommas:
+    def test_simple(self):
+        assert _split_top_level_commas("a, b, c") == ["a", "b", "c"]
+
+    def test_nested_parens_preserved(self):
+        # commas inside count(DISTINCT a, b) must NOT split projections
+        assert _split_top_level_commas("count(DISTINCT a, b) AS c, d") == [
+            "count(DISTINCT a, b) AS c",
+            "d",
+        ]
+
+
+# ── _widen_typed_labels / _should_widen_labels ────────────────────
+
+
+class TestWidenTypedLabels:
+    def test_widens_single_label(self):
+        cypher = "MATCH (p:Person) RETURN p.name"
+        assert _widen_typed_labels(cypher) == "MATCH (p:__Entity__) RETURN p.name"
+
+    def test_widens_multiple_labels(self):
+        cypher = (
+            "MATCH (p:Person)-[:RELATES]-(o:Organization) "
+            "WHERE o.name CONTAINS 'Acme' RETURN p.name"
+        )
+        widened = _widen_typed_labels(cypher)
+        assert ":Person" not in widened
+        assert ":Organization" not in widened
+        assert widened.count(":__Entity__") == 2
+
+    def test_does_not_widen_structural_labels(self):
+        # Chunk / Document / __Entity__ already are structural — leave alone.
+        cypher = "MATCH (c:Chunk)-[:PART_OF]->(d:Document) RETURN c.text"
+        assert _widen_typed_labels(cypher) == cypher
+
+    def test_idempotent(self):
+        cypher = "MATCH (p:Person) RETURN p.name"
+        once = _widen_typed_labels(cypher)
+        twice = _widen_typed_labels(once)
+        assert once == twice
+
+
+class TestShouldWidenLabels:
+    def test_widens_when_name_filter_present(self):
+        # A typed label + name filter = LLM is using label as routing hint, not
+        # as the filter itself — safe to widen.
+        assert _should_widen_labels(
+            "MATCH (p:Person)-[:RELATES]-(o:Organization) "
+            "WHERE o.name CONTAINS 'Acme' RETURN p.name"
+        )
+
+    def test_skips_when_label_is_the_filter(self):
+        # Pure aggregation over a labeled variable with NO name predicate —
+        # widening would change semantics ("count Persons" → "count Entities").
+        assert not _should_widen_labels(
+            "MATCH (p:Person) RETURN count(DISTINCT p) AS c"
+        )
+
+    def test_skips_when_no_typed_label(self):
+        assert not _should_widen_labels(
+            "MATCH (e:__Entity__) RETURN count(e)"
+        )
+
+
 # ── execute_cypher_retrieval ──────────────────────────────────────
 
 
@@ -148,9 +293,13 @@ async def test_returns_empty_on_generation_failure(self):
         mock_llm.ainvoke = AsyncMock(return_value=LLMResponse(content="I don't know"))
         mock_graph = MagicMock()
 
-        facts, entities = await execute_cypher_retrieval(mock_graph, mock_llm, "test?")
+        facts, entities, metadata = await execute_cypher_retrieval(
+            mock_graph, mock_llm, "test?"
+        )
         assert facts == []
         assert entities == {}
+        assert metadata["cypher_fallback"] is None
+        assert metadata["cypher_rows"] == 0
 
     async def test_returns_empty_on_execution_error(self):
         """When Cypher execution fails, should return empty results."""
@@ -169,9 +318,12 @@ async def test_returns_empty_on_execution_error(self):
         mock_graph = MagicMock()
         mock_graph.query_raw = AsyncMock(side_effect=Exception("connection error"))
 
-        facts, entities = await execute_cypher_retrieval(mock_graph, mock_llm, "test?")
+        facts, entities, metadata = await execute_cypher_retrieval(
+            mock_graph, mock_llm, "test?"
+        )
         assert facts == []
         assert entities == {}
+        assert metadata["cypher_fallback"] is None
 
     async def test_parses_result_rows(self):
         """Successful execution should parse rows into facts and entities."""
@@ -192,8 +344,76 @@ async def test_parses_result_rows(self):
         mock_graph = MagicMock()
         mock_graph.query_raw = AsyncMock(return_value=result_mock)
 
-        facts, entities = await execute_cypher_retrieval(mock_graph, mock_llm, "test?")
+        facts, entities, metadata = await execute_cypher_retrieval(
+            mock_graph, mock_llm, "test?"
+        )
         assert len(facts) == 2
         assert "Alice" in facts[0]
         assert "alice" in entities
         assert "bob" in entities
+        assert metadata["cypher_rows"] == 2
+        assert metadata["cypher_fallback"] is None
+
+    async def test_label_widen_fires_on_zero_rows(self):
+        """When typed-label query returns 0 rows AND a name filter is present,
+        the fallback should rewrite typed labels to __Entity__ and re-execute."""
+        from unittest.mock import AsyncMock, MagicMock
+        from graphrag_sdk.core.models import LLMResponse
+        from graphrag_sdk.retrieval.strategies.cypher_generation import (
+            execute_cypher_retrieval,
+        )
+
+        # Original cypher: typed label + name filter — gating allows widen.
+        original = (
+            "MATCH (p:Person)-[:RELATES]-(t:Technology) "
+            "WHERE t.name CONTAINS 'observability' RETURN p.name LIMIT 25"
+        )
+        mock_llm = MagicMock()
+        mock_llm.ainvoke = AsyncMock(
+            return_value=LLMResponse(content=f"```cypher\n{original}\n```")
+        )
+
+        empty_result = MagicMock()
+        empty_result.result_set = []
+        widened_result = MagicMock()
+        widened_result.result_set = [["Carla Okafor"]]
+
+        mock_graph = MagicMock()
+        # First call: original (typed) → empty. Second call: widened → hit.
+        mock_graph.query_raw = AsyncMock(side_effect=[empty_result, widened_result])
+
+        facts, entities, metadata = await execute_cypher_retrieval(
+            mock_graph, mock_llm, "is there an Acme employee on observability?"
+        )
+        assert facts == ["Carla Okafor"]
+        assert metadata["cypher_fallback"] == "label_widened"
+        assert ":__Entity__" in metadata["cypher"]
+        assert mock_graph.query_raw.await_count == 2
+
+    async def test_label_widen_skipped_for_pure_aggregation(self):
+        """For 'count Persons' (label IS the filter), widening would change
+        semantics — fallback must skip and return empty."""
+        from unittest.mock import AsyncMock, MagicMock
+        from graphrag_sdk.core.models import LLMResponse
+        from graphrag_sdk.retrieval.strategies.cypher_generation import (
+            execute_cypher_retrieval,
+        )
+
+        cypher = "MATCH (p:Person) RETURN count(DISTINCT p) AS c"
+        mock_llm = MagicMock()
+        mock_llm.ainvoke = AsyncMock(
+            return_value=LLMResponse(content=f"```cypher\n{cypher}\n```")
+        )
+
+        empty_result = MagicMock()
+        empty_result.result_set = []
+        mock_graph = MagicMock()
+        mock_graph.query_raw = AsyncMock(return_value=empty_result)
+
+        facts, entities, metadata = await execute_cypher_retrieval(
+            mock_graph, mock_llm, "how many people?"
+        )
+        assert facts == []
+        assert metadata["cypher_fallback"] is None
+        # Only one execution — fallback was correctly skipped.
+        assert mock_graph.query_raw.await_count == 1
diff --git a/graphrag_sdk/tests/test_facade.py b/graphrag_sdk/tests/test_facade.py
index 2e4a22a..e252437 100644
--- a/graphrag_sdk/tests/test_facade.py
+++ b/graphrag_sdk/tests/test_facade.py
@@ -1104,3 +1104,66 @@ async def test_ingest_input_validation_runs_before_config_probe(
     # Probe and DB query must not have fired.
     g._graph_store.query_raw.assert_not_called()
     g.embedder.aembed_query.assert_not_called()
+
+
+# ── Cypher authority rule injection (R1) ─────────────────────────
+
+
+class TestCypherAuthorityRuleInjection:
+    """The cypher-authority rule is appended to the system prompt only when
+    the retriever produced an authoritative-results section. Callers who
+    don't use cypher retrieval keep the unchanged base prompt."""
+
+    def test_helper_returns_false_for_empty_items(self):
+        from graphrag_sdk.api.main import _has_authoritative_cypher_results
+        assert not _has_authoritative_cypher_results(RetrieverResult(items=[]))
+
+    def test_helper_returns_false_for_non_cypher_items(self):
+        from graphrag_sdk.api.main import _has_authoritative_cypher_results
+        items = [
+            RetrieverResultItem(
+                content="## Source Document Passages\n- foo",
+                metadata={"section": "passages"},
+            ),
+            RetrieverResultItem(
+                content="## Key Entities\n- Acme",
+                metadata={"section": "entities"},
+            ),
+        ]
+        assert not _has_authoritative_cypher_results(RetrieverResult(items=items))
+
+    def test_helper_returns_true_on_section_metadata(self):
+        from graphrag_sdk.api.main import _has_authoritative_cypher_results
+        item = RetrieverResultItem(
+            content="",
+            metadata={"section": "cypher_results"},
+        )
+        assert _has_authoritative_cypher_results(RetrieverResult(items=[item]))
+
+    def test_helper_returns_true_on_heading_marker_in_content(self):
+        # Defensive path — third-party strategies that don't tag metadata
+        # but use the canonical heading still trigger the rule.
+        from graphrag_sdk.api.main import _has_authoritative_cypher_results
+        item = RetrieverResultItem(
+            content=(
+                "## Authoritative Graph Query Results "
+                "(deterministic; trust over passages on counts and aggregates)\n"
+                "- Boston | 10"
+            ),
+            metadata={},
+        )
+        assert _has_authoritative_cypher_results(RetrieverResult(items=[item]))
+
+    def test_base_system_prompt_does_not_contain_rule_8(self):
+        # The rule has been moved out of the base prompts. Callers who
+        # don't surface cypher results keep the original 7-rule prompt.
+        from graphrag_sdk.api.main import (
+            _CYPHER_AUTH_RULE,
+            _RAG_SYSTEM_PROMPT,
+            _RAG_SYSTEM_PROMPT_DELIMITED,
+        )
+        for base in (_RAG_SYSTEM_PROMPT, _RAG_SYSTEM_PROMPT_DELIMITED):
+            assert "Authoritative Graph Query Results" not in base
+            assert "\n8." not in base
+        # The addendum, on its own, mentions the heading.
+        assert "Authoritative Graph Query Results" in _CYPHER_AUTH_RULE