diff --git a/graphrag_sdk/src/graphrag_sdk/__init__.py b/graphrag_sdk/src/graphrag_sdk/__init__.py
index a7c8ff3..81ed8f5 100644
--- a/graphrag_sdk/src/graphrag_sdk/__init__.py
+++ b/graphrag_sdk/src/graphrag_sdk/__init__.py
@@ -96,6 +96,11 @@
from graphrag_sdk.retrieval.reranking_strategies.base import RerankingStrategy
from graphrag_sdk.retrieval.reranking_strategies.cosine import CosineReranker
from graphrag_sdk.retrieval.strategies.base import RetrievalStrategy
+from graphrag_sdk.retrieval.strategies.cypher_first import (
+ CypherFirstAggregationStrategy,
+ DefaultPhraseExtractor,
+ PhraseExtractor,
+)
from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval
# ── Storage ─────────────────────────────────────────────────────
@@ -160,7 +165,10 @@
"SemanticResolution",
# Retrieval
"CosineReranker",
+ "CypherFirstAggregationStrategy",
+ "DefaultPhraseExtractor",
"MultiPathRetrieval",
+ "PhraseExtractor",
"RerankingStrategy",
"RetrievalStrategy",
# Storage
diff --git a/graphrag_sdk/src/graphrag_sdk/api/main.py b/graphrag_sdk/src/graphrag_sdk/api/main.py
index 13c4635..dc0c34d 100644
--- a/graphrag_sdk/src/graphrag_sdk/api/main.py
+++ b/graphrag_sdk/src/graphrag_sdk/api/main.py
@@ -92,6 +92,27 @@
"or is NOT true, preserve that meaning."
)
# Optional addendum appended to the system prompt only when the retriever
# produced an "Authoritative Graph Query Results" section (currently emitted
# by ``CypherFirstAggregationStrategy`` and by ``MultiPathRetrieval`` with
# ``enable_cypher=True``). Kept out of the base prompts so callers who never
# use cypher retrieval don't pay the prompt-token cost or risk subtle
# behavior changes on factoid questions.
# NOTE(review): the addendum is numbered "8" on the assumption that the
# base system prompt's numbered rules end at 7 — confirm if the base
# prompt gains or loses rules.
_CYPHER_AUTH_RULE = (
    "\n8. When the context contains a section labeled "
    "'Authoritative Graph Query Results', that section is computed "
    "deterministically from the knowledge graph. For quantitative or "
    "'which has the most/fewest/exactly N' questions, prefer those "
    "numbers over prose passages, even if passages mention other "
    "entities more frequently."
)

# Marker the retriever uses in section content / metadata to signal that an
# authoritative cypher result is present. Kept as a module-level constant so
# tests can pin the contract.
# ``_CYPHER_RESULTS_SECTION`` is compared against the ``section`` key of
# item metadata; ``_AUTH_RESULTS_HEADING_MARKER`` is the substring checked
# inside item content (defensive path for untagged third-party strategies).
_CYPHER_RESULTS_SECTION = "cypher_results"
_AUTH_RESULTS_HEADING_MARKER = "Authoritative Graph Query Results"
+
_RAG_PROMPT = "\n{context}\n\n\nQuestion: {question}\n\nAnswer:"
# Matches a literal ```` closing tag (case-insensitive, whitespace
@@ -105,6 +126,23 @@ def _neutralize_context_close_tag(text: str) -> str:
return _CONTEXT_CLOSE_RE.sub(" context>", text)
+def _has_authoritative_cypher_results(retriever_result: RetrieverResult) -> bool:
+ """True when any retrieved item is an authoritative graph-query result.
+
+ Used to decide whether to append the cypher-authority rule to the
+ default system prompt. We check both the section metadata (preferred,
+ cheap) and the heading marker in the content (defensive — covers
+ third-party strategies that don't tag the metadata).
+ """
+ for item in retriever_result.items or []:
+ section = (item.metadata or {}).get("section") if item.metadata else None
+ if section == _CYPHER_RESULTS_SECTION:
+ return True
+ if isinstance(item.content, str) and _AUTH_RESULTS_HEADING_MARKER in item.content:
+ return True
+ return False
+
+
_QUESTION_REWRITE_PROMPT = (
"Given the conversation history, rewrite the user's last question "
"as a standalone question that includes all entity names, dates, "
@@ -709,6 +747,12 @@ async def completion(
context_str = "\n---\n".join(item.content for item in retriever_result.items)
default_system_content = _RAG_SYSTEM_PROMPT
+ # Conditionally append the cypher-authority rule. Only fires when the
+ # retriever produced an authoritative-results section — keeps the base
+ # system prompt unchanged for callers who don't use cypher retrieval.
+ if _has_authoritative_cypher_results(retriever_result):
+ default_system_content = default_system_content + _CYPHER_AUTH_RULE
+
# Step 4: Build messages — unified path for single-turn and multi-turn.
# If history starts with role="system", honor it as-is (trust the
# consumer). Otherwise inject the SDK's default instructions.
diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/__init__.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/__init__.py
index 8841387..4a5cc1c 100644
--- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/__init__.py
+++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/__init__.py
@@ -1,6 +1,17 @@
# GraphRAG SDK — Retrieval: Strategies
from graphrag_sdk.retrieval.strategies.base import RetrievalStrategy
+from graphrag_sdk.retrieval.strategies.cypher_first import (
+ CypherFirstAggregationStrategy,
+ DefaultPhraseExtractor,
+ PhraseExtractor,
+)
from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval
-__all__ = ["RetrievalStrategy", "MultiPathRetrieval"]
+__all__ = [
+ "CypherFirstAggregationStrategy",
+ "DefaultPhraseExtractor",
+ "MultiPathRetrieval",
+ "PhraseExtractor",
+ "RetrievalStrategy",
+]
diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_first.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_first.py
new file mode 100644
index 0000000..7f10be3
--- /dev/null
+++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_first.py
@@ -0,0 +1,1061 @@
+# GraphRAG SDK — Retrieval: Cypher-First Aggregation Strategy
+# Routes aggregation/quantitative questions through a deterministic graph-
+# query path that treats Cypher as the answer source (not just another
+# retrieval signal). Free-text properties (roles, projects) that aren't
+# captured as typed entities are recovered by parsing Person chunk text at
+# retrieval time.
+#
+# Non-aggregation questions delegate to a fallback strategy (default:
+# MultiPathRetrieval). The strategy is therefore safe as the top-level
+# strategy on GraphRAG — RAG questions still get the existing pipeline.
+#
+# Background: traditional RAG retrieves prose evidence and lets the LLM
+# synthesize. For "how many X" / "which X has the most Y" / "BOTH A and B"
+# questions, prose evidence is the wrong shape — the answer wants exact
+# counts or set operations. Cypher gives us those deterministically, but
+# only when the underlying graph has the right structure. This strategy
+# wraps Cypher generation in three mechanisms that make it reliable on
+# noisy, real-world graphs:
+#
+# 1. Multi-candidate Cypher with row-count selection (M2): K parallel
+# samples, execute all, pick the one with the most rows. Beats LLM
+# stochasticity without serial retries.
+# 2. Column-named markdown table formatting (M3): uses FalkorDB's
+# ``result.header`` so the synthesizer sees ``acme_count=10`` not the
+# lossy ``10 | 7 | True``.
+# 3. Description + chunk-text fuzzy hybrid (M5): for "shared X" / "BOTH
+# A and B" questions where X is a free-text property (role, project)
+# not extracted as a typed node, parse Person chunks via regex and
+# compute the set operation in Python with fuzzy token matching.
+#
+# Plus a deterministic numeric-math path (M6) — RETURN raw values, then
+# average / sum / median in Python — and a negation-existential branch
+# (M7) that treats an empty Cypher result as the definitive "No" when the
+# question shape demands it.
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import re
+import statistics
+from abc import ABC, abstractmethod
+from typing import Any
+
+from graphrag_sdk.core.context import Context
+from graphrag_sdk.core.models import (
+ RawSearchResult,
+ RetrieverResult,
+ RetrieverResultItem,
+)
+from graphrag_sdk.core.providers import Embedder, LLMInterface
+from graphrag_sdk.retrieval.strategies.base import RetrievalStrategy
+from graphrag_sdk.retrieval.strategies.cypher_generation import (
+ SCHEMA_PROMPT,
+ _sanitize_cypher,
+ extract_cypher,
+ validate_cypher,
+)
+from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval
+
+logger = logging.getLogger(__name__)
+
+
+# ─────────────────────────────────────────────────────────────────
+# Intent classification (M1)
+# ─────────────────────────────────────────────────────────────────
+
+_AGG_INTENT_PATTERNS = [
+ r"\bhow\s+many\b", r"\bhow\s+much\b",
+ r"\bwhich\s+\w+\b",
+ r"\bwhat\s+(?:is\s+the\s+)?(?:average|mean|median|total|sum|count|number)\b",
+ r"\baverage\b", r"\bmedian\b", r"\btotal\b",
+ r"\bcount\s+of\b", r"\bnumber\s+of\b",
+ r"\blist\s+(?:all|the|every)\b",
+ r"\blist\s+\w+\s+(?:that|with|where|who)\b",
+ r"\bare\s+there\s+any\b", r"\bis\s+there\s+any\b",
+ # `more X than` / `fewer X than` — allow up to 4 words between
+ r"\bmore(?:\s+\S+){1,4}\s+than\b",
+ r"\bfewer(?:\s+\S+){0,4}\s+than\b",
+ r"\bless(?:\s+\S+){0,4}\s+than\b",
+ r"\bboth\s+\w+(?:\s+\S+)*?\s+and\s+\w+\b",
+ r"\bbetween\s+\w+\s+and\s+\w+\b",
+ r"\bexactly\s+\d+\b", r"\bat\s+least\s+\d+\b",
+ r"\b(?:does|do|has|have|is|are)\s+(?:[A-Z]\w*\s+){1,4}(?:have|has|contain|own|run|host)\b",
+ r"\bwho\s+(?:works|live|is)\s+(?:at|in|on)\b",
+]
+_AGG_INTENT_RE = re.compile("|".join(_AGG_INTENT_PATTERNS), re.IGNORECASE)
+
# Numeric-aggregation trigger: an arithmetic word (average/median/total/…)
# followed within a couple of tokens by a numeric-ish noun.
_NUMERIC_AGG_RE = re.compile(
    r"\b(average|mean|median|total|sum)\b\s+(?:of\s+)?(?:the\s+)?(?:\w+\s+)?"
    r"\b(year|years|age|amount|price|cost|revenue|number|count|"
    r"founding|founded|salary|salaries)\b",
    re.IGNORECASE,
)

# A leading auxiliary/copula verb means the question expects yes/no.
_YES_NO_RE = re.compile(
    r"^\s*(?:is|are|was|were|do|does|did|has|have|had|can|could|will|"
    r"would|should|may|might|are\s+there|is\s+there)\b",
    re.IGNORECASE,
)

# A leading interrogative/imperative means the question expects a listing.
_WHICH_LIST_RE = re.compile(
    r"^\s*(?:which|what|list|name|identify|enumerate)\b",
    re.IGNORECASE,
)


def detect_aggregation_intent(question: str) -> str:
    """Return ``"numeric_math"``, ``"aggregation"``, or ``"rag"``.

    ``"numeric_math"`` triggers the Python-arithmetic path; ``"aggregation"``
    triggers the Cypher-first table path; ``"rag"`` falls back to the
    standard retrieval strategy.
    """
    # Numeric-math is checked first: it is the more specific signal.
    if _NUMERIC_AGG_RE.search(question):
        return "numeric_math"
    return "aggregation" if _AGG_INTENT_RE.search(question) else "rag"


def is_yes_no(question: str) -> bool:
    """True when the question opens in a yes/no (auxiliary-verb) shape."""
    return _YES_NO_RE.match(question) is not None


def is_which_list(question: str) -> bool:
    """True when the question opens with a which/list-style interrogative."""
    return _WHICH_LIST_RE.match(question) is not None
+
+
+def _is_negation_existential(question: str) -> bool:
+ """True for "are there any X without/no Y?" — empty Cypher = definitive No."""
+ has_negation = bool(
+ re.search(r"\b(?:no|not|without|never|none)\b", question, re.IGNORECASE)
+ )
+ has_existential = bool(
+ re.search(r"\b(?:any|are\s+there|is\s+there)\b", question, re.IGNORECASE)
+ )
+ return has_negation and has_existential
+
+
+# ─────────────────────────────────────────────────────────────────
+# Free-text phrase extraction for description hybrid (M5)
+# ─────────────────────────────────────────────────────────────────
+
+# Roles end in a closed set of professional suffixes — anchoring on those
+# keeps noisy phrases ("the cross-region replication initiative") from
+# getting captured.
+_ROLE_RE = re.compile(
+ r"\b(?:as|is)\s+(?:an?\s+)?"
+ r"((?:[a-z][\w\s]*?)?\b(?:engineer|scientist|manager|architect|researcher|"
+ r"developer|analyst|specialist|designer|consultant)s?)\b",
+ re.IGNORECASE,
+)
+
+_PROJECT_RE = re.compile(
+ r"\bcontribut(?:es|ing)\s+to\s+"
+ r"(?:the\s+|a\s+|an\s+)?"
+ r"(.+?)"
+ r"(?=\s+and\s+(?:is|active)|\s+based\s+in|,|\.\s|\.$|$)",
+ re.IGNORECASE,
+)
+
+_BOTH_AB_RE = re.compile(
+ r"\bboth\s+([\w\s]+?)\s+and\s+([\w\s'-]+?)(?:\?|\.|$|\bof\b|\bat\b|\bwith\b)",
+ re.IGNORECASE,
+)
+
+_SAME_AS_RE = re.compile(
+ r"\b(?:same\s+\w+\s+as|share\s+(?:an?\s+)?\w+\s+with|"
+ r"share\s+(?:any\s+of\s+)?\w+\s+with)"
+ r"\s+(?:someone\s+(?:at|in|from)\s+)?([\w\s]+?)(?:\?|\.|$)",
+ re.IGNORECASE,
+)
+
+
+def _extract_roles(text: str) -> set[str]:
+ out: set[str] = set()
+ for m in _ROLE_RE.finditer(text or ""):
+ role = re.sub(r"\s+", " ", m.group(1).strip().lower())
+ if 3 < len(role) < 60:
+ out.add(role)
+ return out
+
+
+def _extract_projects(text: str) -> set[str]:
+ out: set[str] = set()
+ for m in _PROJECT_RE.finditer(text or ""):
+ proj = re.sub(r"\s+", " ", m.group(1).strip().lower())
+ proj = re.sub(r"\s+(?:and|who|while)$", "", proj)
+ if 5 < len(proj) < 100:
+ out.add(proj)
+ return out
+
+
+def _extract_phrases(text: str, kind: str) -> set[str]:
+ if kind == "role":
+ return _extract_roles(text)
+ if kind == "project":
+ return _extract_projects(text)
+ return set()
+
+
+# ─────────────────────────────────────────────────────────────────
+# Pluggable phrase extractor (R8)
+# ─────────────────────────────────────────────────────────────────
+
+
class PhraseExtractor(ABC):
    """Pluggable phrase extractor for the shared-property hybrid path.

    The default implementation targets the prose patterns the SDK's
    ``GraphExtraction`` pipeline produces (``"works at X as a "``,
    ``"contributes to "``). Domain-specific subclasses can
    override ``extract`` to recognize medical / legal / e-commerce /
    non-English vocabularies without forking the strategy.

    ``extract`` is invoked once per person-associated sentence during the
    hybrid path's batch scan, so implementations should stay cheap.

    Pass an instance via ``CypherFirstAggregationStrategy(..., phrase_extractor=MyExtractor())``.
    """

    @abstractmethod
    def extract(self, text: str, kind: str) -> set[str]:
        """Return phrases of ``kind`` found in ``text``.

        ``kind`` is typically ``"role"`` or ``"project"`` but extractors
        may define additional kinds. Unknown kinds should return an empty
        set rather than raise.
        """
+
+
class DefaultPhraseExtractor(PhraseExtractor):
    """Stock extractor wired into ``CypherFirstAggregationStrategy``.

    Recognizes the role/project prose shapes emitted by the SDK's default
    ``GraphExtraction`` pipeline; the concrete patterns live in the
    module-level ``_ROLE_RE`` and ``_PROJECT_RE`` regexes, and the role
    suffix vocabulary is enumerated in the strategy docstring.
    """

    def extract(self, text: str, kind: str) -> set[str]:
        """Delegate to the module-level phrase-extraction dispatcher."""
        return _extract_phrases(text, kind)
+
+
+def _detect_property_kind(question: str) -> str | None:
+ q = question.lower()
+ if re.search(r"\b(?:role|roles|job|jobs|title|titles|position|positions)\b", q):
+ return "role"
+ if re.search(
+ r"\b(?:project|projects|initiative|initiatives|work\s+on|works\s+on|"
+ r"working\s+on|contribute|contributes|same\s+thing)\b",
+ q,
+ ):
+ return "project"
+ return None
+
+
+def _fuzzy_intersect(a: set[str], b: set[str]) -> set[str]:
+ """Return phrases from ``a`` that fuzzy-match any phrase in ``b``.
+
+ Fuzzy = case-insensitive substring in either direction OR ≥2 shared
+ content tokens. Catches cases where extraction paraphrased a project
+ name, e.g. "next-generation pipeline rewrite" vs "pipeline rewrite".
+ """
+ stop = {"the", "a", "an", "to", "of", "for", "in", "on", "and", "or"}
+
+ def _tokens(s: str) -> set[str]:
+ return {t.lower() for t in re.findall(r"\w+", s)
+ if t.lower() not in stop and len(t) > 2}
+
+ out: set[str] = set()
+ for x in a:
+ xt = _tokens(x)
+ if not xt:
+ continue
+ for y in b:
+ if x == y or x in y or y in x:
+ out.add(x)
+ break
+ yt = _tokens(y)
+ if len(xt & yt) >= 2:
+ out.add(x)
+ break
+ return out
+
+
+# ─────────────────────────────────────────────────────────────────
+# Cypher result → markdown table (M3)
+# ─────────────────────────────────────────────────────────────────
+
def format_result_as_markdown_table(
    result: Any,
    *,
    cap: int = 100,
) -> tuple[str, list[dict[str, Any]], bool]:
    """Render a FalkorDB result_set as a markdown table with column headers.

    Returns ``(table_md, parsed_rows, truncated)``. ``parsed_rows`` is a list
    of dicts keyed by column name so callers can post-process structured
    data without re-parsing the markdown.
    """
    all_rows = getattr(result, "result_set", None)
    if not all_rows:
        return "(empty result)", [], False

    # Column names: FalkorDB header entries are (type, name) pairs; fall
    # back to positional col_N names when no usable header is present.
    names: list[str] = []
    for entry in getattr(result, "header", None) or []:
        if isinstance(entry, (list, tuple)) and len(entry) >= 2:
            names.append(str(entry[1]))
        else:
            names.append(str(entry))
    if not names:
        names = [f"col_{i}" for i in range(len(all_rows[0]))]

    visible = all_rows[:cap]
    truncated = len(all_rows) > cap

    def _col_name(i: int) -> str:
        # Rows wider than the header still get deterministic keys.
        return names[i] if i < len(names) else f"col_{i}"

    parsed_rows: list[dict[str, Any]] = [
        {_col_name(i): cell for i, cell in enumerate(row)} for row in visible
    ]

    def _render(cell: Any) -> str:
        if cell is None:
            return "(null)"
        if isinstance(cell, list):
            return ", ".join(str(item) for item in cell)
        return str(cell)

    sep = " | "
    lines = [sep.join(names), sep.join(["---"] * len(names))]
    lines.extend(sep.join(_render(cell) for cell in row) for row in visible)
    if truncated:
        lines.append(
            f"... (showing {len(visible)} of {len(all_rows)} rows; "
            "result truncated)"
        )
    return "\n".join(lines), parsed_rows, truncated
+
+
+# ─────────────────────────────────────────────────────────────────
+# Schema prompt enrichment for aggregation generation
+# ─────────────────────────────────────────────────────────────────
+
# Appended to SCHEMA_PROMPT when generating aggregation cypher: worked
# patterns for set-intersection, comparison, raw-value, and list shapes.
# Runtime prompt text — the examples are load-bearing; edit with care.
_AGG_SCHEMA_SUFFIX = """

## Additional rules for AGGREGATION questions

- For "BOTH X AND Y" / set-intersection over a shared property (e.g. "roles
  held at both A and B"), use TWO separate matches against DIFFERENT
  people, joined on the shared entity:
      MATCH (p1:Person)-[:RELATES]-(o1:Organization),
            (p1)-[:RELATES]-(prop:__Entity__),
            (p2:Person)-[:RELATES]-(o2:Organization),
            (p2)-[:RELATES]-(prop)
      WHERE o1.name CONTAINS '...' AND o2.name CONTAINS '...'
      RETURN DISTINCT prop.name AS shared_value
  The shared `prop` variable IS the intersection.
- For "more X than Y" comparison questions, RETURN both counts AS named
  columns (e.g. RETURN count(...) AS acme_count, count(...) AS initech_count)
  so the answer is unambiguous.
- For "average / total of YEARS or NUMBERS", just RETURN the raw values
  (e.g. d.name for Date entities). Arithmetic happens outside cypher.
- For "which X have/work/share Y" list questions, RETURN DISTINCT only the
  X you're asked about — do not add extra columns.
- Always alias every RETURN column with a descriptive name.
- Prefer RETURNing one row per group in group-by patterns rather than
  packing multiple counts into a single row.
"""

# Second-pass hint: steers cypher generation toward `description CONTAINS`
# filters when the asked-about property is free text, not a typed node.
_DESC_HINT_SUFFIX = (
    "\n\nIMPORTANT: Person entities have rich free-text in their "
    "`description` property — phrases like 'works at X as a senior "
    "engineer' or 'contributes to internal tooling for observability'. "
    "If the question asks about ROLES, JOBS, PROJECTS, or other free-text "
    "properties that are NOT first-class entities in the schema, prefer "
    "filtering on `p.description CONTAINS '...'` over matching typed "
    "structural edges.\n"
)
+
+
+# ─────────────────────────────────────────────────────────────────
+# Aggregation-mode answer-side directive
+# (synthesized into the retrieved item content; the existing system
+# prompt rule 8 in api/main.py already tells the LLM to trust the
+# "Authoritative Graph Query Results" section.)
+# ─────────────────────────────────────────────────────────────────
+
+_AUTH_HEADING = (
+ "## Authoritative Graph Query Results "
+ "(deterministic; trust over passages on counts and aggregates)"
+)
+
+
+def _wrap_authoritative(body: str, *, source_note: str = "") -> str:
+ note = (
+ f"\nSource: {source_note}." if source_note else
+ "\nSource: text-to-Cypher run against the knowledge graph."
+ )
+ return f"{_AUTH_HEADING}{note}\n\n{body}"
+
+
# Canonical labels for the sub-path metadata key ``cypher_first_path``.
# Operators / metrics dashboards can group on these to see which path
# fires for each query.
PATH_NUMERIC_MATH = "numeric_math"
PATH_SHARED_PROPERTY_HYBRID = "shared_property_hybrid"
PATH_CYPHER_TABLE = "cypher_table"
PATH_NEGATION_EMPTY_NO = "negation_empty_no"
PATH_RAG_FALLBACK = "rag_fallback"
PATH_RAG_FALLBACK_NUMERIC_FAIL = "rag_fallback_numeric_fail"
PATH_RAG_FALLBACK_CYPHER_EMPTY = "rag_fallback_cypher_empty"


def _tag_path(result: RawSearchResult, path: str) -> RawSearchResult:
    """Return a copy of ``result`` whose metadata carries the sub-path label.

    Applied both to results this strategy builds itself and to results
    returned by the delegated RAG fallback, so operators see one uniform
    signal regardless of which branch served the query.
    """
    labeled = dict(result.metadata or {})
    labeled.setdefault("strategy", "cypher_first")
    labeled["cypher_first_path"] = path
    return RawSearchResult(records=result.records, metadata=labeled)
+
+
+# ─────────────────────────────────────────────────────────────────
+# Sub-paths
+#
+# Each path is a small, focused class with a single ``maybe_handle()``
+# method that either produces a final ``RawSearchResult`` or returns
+# ``None`` to defer to the next path. The strategy's ``_execute()``
+# dispatches by intent and iterates the relevant paths in order.
+# Splitting this way makes the routing trivial to follow, each path
+# trivially unit-testable in isolation, and adding new shapes (medical /
+# legal / e-commerce) a matter of dropping in a new path class.
+# ─────────────────────────────────────────────────────────────────
+
+
class _AggregationPath(ABC):
    """Base class for CypherFirstAggregationStrategy sub-paths.

    Holds a reference to the parent strategy so subclasses can reach the
    shared LLM / graph / fallback / k_candidates state without dragging
    around a long argument list.
    """

    def __init__(self, strategy: CypherFirstAggregationStrategy) -> None:
        # Parent strategy; sub-paths read its ``_llm`` / ``_graph`` /
        # ``_fallback`` / ``_k`` / ``_phrase_extractor`` attributes directly.
        self._s = strategy

    @abstractmethod
    async def maybe_handle(
        self,
        query: str,
        ctx: Context,
    ) -> RawSearchResult | None:
        """Return a final retrieval result, or ``None`` to defer."""
+
+
class _RagDelegationPath(_AggregationPath):
    """Forwards intent="rag" queries to the fallback strategy unchanged."""

    async def maybe_handle(
        self,
        query: str,
        ctx: Context,
    ) -> RawSearchResult | None:
        # Delegate untouched, then stamp the uniform sub-path label so
        # operators can tell which branch served the query.
        delegated = await self._s._fallback._execute(query, ctx)
        return _tag_path(delegated, PATH_RAG_FALLBACK)
+
+
class _NumericMathPath(_AggregationPath):
    """For "average / total / sum of YEARS / NUMBERS" — extract values via
    Cypher, do the arithmetic in Python. Avoids LLM-arithmetic errors."""

    async def maybe_handle(
        self,
        query: str,
        ctx: Context,
    ) -> RawSearchResult | None:
        # Ask the LLM for a cypher that RETURNs raw per-row numbers; all
        # aggregation arithmetic happens in Python further down.
        extraction_prompt = SCHEMA_PROMPT.format(
            question=(
                f"Generate a cypher that returns the RAW NUMERIC VALUES "
                f"needed to answer this question (one value per row). "
                f"Do NOT compute averages or sums in cypher; just return "
                f"the raw numbers. Use Date entities if the question is "
                f"about years.\n\n"
                f"Question: {query}"
            )
        )
        cypher: str | None = None
        values: list[float] = []
        try:
            resp = await self._s._llm.ainvoke(extraction_prompt)
            cypher = extract_cypher(resp.content)
            errors = validate_cypher(cypher) if cypher else ["empty"]
            if errors:
                logger.debug("Numeric-math cypher validation failed: %s", errors)
                cypher = None
            else:
                cypher = _sanitize_cypher(cypher)
                result = await self._s._graph.query_raw(cypher)
                # Scan every cell: the generated cypher may return extra
                # columns; _coerce_number filters out non-numeric cells.
                for row in (result.result_set or []):
                    for cell in row:
                        v = _coerce_number(cell)
                        if v is not None:
                            values.append(v)
        except Exception as exc:
            # Best-effort: any LLM / graph failure simply demotes the
            # query to the RAG fallback below.
            logger.debug("Numeric-math extraction failed: %s", exc)

        if not values:
            # Fall back to standard retrieval — the LLM may still be able
            # to extract the numbers from chunks.
            ctx.log("CypherFirst numeric_math: no values extracted, "
                    "falling back to RAG")
            return _tag_path(
                await self._s._fallback._execute(query, ctx),
                PATH_RAG_FALLBACK_NUMERIC_FAIL,
            )

        # Choose the arithmetic op from question keywords; "average" is
        # the default for any remaining numeric-math question.
        q_lower = query.lower()
        if "median" in q_lower:
            ans = statistics.median(values)
            op = "median"
        elif "total" in q_lower or "sum" in q_lower:
            ans = float(sum(values))
            op = "sum"
        else:  # average / mean / default for "average-like" questions
            ans = sum(values) / len(values)
            op = "average"

        # Render whole numbers without a trailing ".0"; one decimal otherwise.
        ans_str = f"{int(ans)}" if ans == int(ans) else f"{ans:.1f}"
        body = (
            f"computed {op} = {ans_str}\n"
            f"source_values ({len(values)} rows): "
            f"{', '.join(str(int(v)) if v == int(v) else f'{v:.2f}' for v in values)}\n"
            f"cypher: {cypher}"
        )
        return RawSearchResult(
            records=[{
                "section": "cypher_results",
                "content": _wrap_authoritative(
                    body,
                    source_note="numeric extraction + Python arithmetic",
                ),
            }],
            metadata={
                "strategy": "cypher_first",
                "cypher_first_path": PATH_NUMERIC_MATH,
                "op": op,
                "value": ans,
                "n_values": len(values),
                "cypher": cypher,
            },
        )
+
+
class _SharedPropertyHybridPath(_AggregationPath):
    """For "BOTH A and B" / "same X as Z" questions over free-text
    properties (role, project) that aren't first-class entities. Parses
    Person chunks via regex and computes the set operation in Python
    with fuzzy token matching."""

    async def maybe_handle(
        self,
        query: str,
        ctx: Context,
    ) -> RawSearchResult | None:
        # Gate 1: the question must name a supported free-text kind.
        kind = _detect_property_kind(query)
        if kind is None:
            return None
        # Gate 2: it must match one of the two supported question shapes.
        shape1 = _BOTH_AB_RE.search(query)
        shape2 = _SAME_AS_RE.search(query)
        if not (shape1 or shape2):
            return None

        # Single batched round-trip: one row per (org, person) pair with
        # the person's description plus every chunk text they appear in.
        batch_cypher = (
            "MATCH (o:Organization)<-[:RELATES]-(p:Person) "
            "OPTIONAL MATCH (p)-[:MENTIONED_IN]->(c:Chunk) "
            "RETURN o.name AS org, p.name AS person, "
            "       p.description AS desc, collect(DISTINCT c.text) AS chunks"
        )
        try:
            batch_res = await self._s._graph.query_raw(batch_cypher)
        except Exception as exc:
            logger.debug("Shared-property hybrid batch query failed: %s", exc)
            return None

        # Topology check: if the batched ``(Org)<-[:RELATES]-(Person)`` query
        # returns zero tuples, the graph doesn't match the assumptions M5
        # was tuned on (Person ↔ Organization edges + MENTIONED_IN chunks).
        # Surface this loudly once per call so operators using custom
        # extractors get a fast signal rather than silent wrong answers.
        if not (batch_res.result_set or []):
            logger.warning(
                "CypherFirst shared-property hybrid found zero "
                "(Organization)<-[:RELATES]-(Person) tuples; falling through. "
                "If your graph uses different edge shapes or doesn't extract "
                "Person/Organization labels, this hybrid will never fire — "
                "see the strategy docstring's 'Assumptions and known limits' "
                "section."
            )
            return None

        # org name → set of phrases of ``kind`` extracted for its people.
        org_phrase_map: dict[str, set[str]] = {}
        for row in (batch_res.result_set or []):
            org = row[0] or ""
            person = row[1] or ""
            desc = row[2] or ""
            chunks = row[3] or []
            text_blob = desc + "\n" + "\n".join(chunks)
            phrases = org_phrase_map.setdefault(org, set())
            for sent in re.split(r"(?<=[.\n])\s+", text_blob):
                # Sentence-restrict to this person to avoid cross-paragraph
                # contamination from chunks that contain multiple people.
                if person and person.split()[0] not in sent:
                    continue
                phrases |= self._s._phrase_extractor.extract(sent, kind)

        def _gather(org_name: str) -> set[str]:
            # Union of phrases over every stored org whose name contains
            # the asked-about org string (case-insensitive substring).
            out: set[str] = set()
            for name, phrases in org_phrase_map.items():
                if org_name.lower() in (name or "").lower():
                    out |= phrases
            return out

        if shape1:
            org_a, org_b = (shape1.group(1).strip(), shape1.group(2).strip())
            if not org_a or not org_b:
                return None
            a_phrases = _gather(org_a)
            b_phrases = _gather(org_b)
            common = sorted(_fuzzy_intersect(a_phrases, b_phrases))
            if not common:
                # Don't assert an empty intersection from a regex-only
                # signal — returning None defers to later paths.
                return None
            ctx.log(f"CypherFirst hybrid shape1: {len(common)} shared {kind}s")
            body = (
                f"The {kind}s held by employees at both {org_a} and "
                f"{org_b}: " + ", ".join(common)
            )
            return RawSearchResult(
                records=[{
                    "section": "cypher_results",
                    "content": _wrap_authoritative(
                        body,
                        source_note=(
                            "Person chunks + description regex; "
                            "fuzzy-intersected by content tokens"
                        ),
                    ),
                }],
                metadata={
                    "strategy": "cypher_first",
                    "cypher_first_path": PATH_SHARED_PROPERTY_HYBRID,
                    "shape": "both_a_and_b",
                    "kind": kind,
                    "common": common,
                },
            )

        # shape2: "same X as <target>" — list orgs sharing a phrase with it.
        target = shape2.group(1).strip().rstrip(",")
        target_phrases = _gather(target)
        if not target_phrases:
            return None
        sharing: list[str] = []
        for org_name, org_phrases in org_phrase_map.items():
            if not org_name:
                continue
            # Skip the target org itself (substring match either way).
            if (
                target.lower() in org_name.lower()
                or org_name.lower() in target.lower()
            ):
                continue
            if _fuzzy_intersect(org_phrases, target_phrases):
                sharing.append(org_name)
        sharing.sort()
        if not sharing:
            return None
        ctx.log(f"CypherFirst hybrid shape2: {len(sharing)} sharing orgs")
        body = (
            "The organizations that have at least one employee working on "
            f"the same {kind} as someone at {target}: " + ", ".join(sharing)
        )
        return RawSearchResult(
            records=[{
                "section": "cypher_results",
                "content": _wrap_authoritative(
                    body,
                    source_note=(
                        "Person chunks + description regex; fuzzy-matched "
                        "across orgs"
                    ),
                ),
            }],
            metadata={
                "strategy": "cypher_first",
                "cypher_first_path": PATH_SHARED_PROPERTY_HYBRID,
                "shape": "same_as",
                "kind": kind,
                "sharing": sharing,
            },
        )
+
+
+class _MultiCandidateCypherPath(_AggregationPath):
+ """Generates K parallel cypher candidates, executes them all, renders
+ the highest-row-count result as a markdown table with column headers.
+
+ Also handles the empty-result branches:
+ - negation-existential ("are there any X without Y?") → return No
+ - everything else → delegate to the RAG fallback
+ Always returns a result (never ``None``) — this path is the last
+ line of defense for aggregation intent.
+ """
+
+ async def maybe_handle(
+ self,
+ query: str,
+ ctx: Context,
+ ) -> RawSearchResult | None:
+ # Pass 1: structural.
+ candidates = await self._generate_k_candidates(query)
+ cypher, table_md, parsed, truncated = await self._execute_and_pick(
+ candidates,
+ )
+
+ # Pass 2 (description hint): if pass 1 was sparse for a "which X"
+ # or "shared X" question, try again with the description hint
+ # enabled. Cheap because cypher-gen runs in parallel.
+ expects_many = is_which_list(query) or re.search(
+ r"\bboth\b|\bshared\b|\bsame\b|\bcommon\b|\bin\s+common\b",
+ query, re.IGNORECASE,
+ )
+ if expects_many and (parsed is None or len(parsed) < 3):
+ more = await self._generate_k_candidates(query, with_desc_hint=True)
+ combined = list({*(candidates or []), *(more or [])})
+ cypher2, table_md2, parsed2, truncated2 = await self._execute_and_pick(combined)
+ if parsed2 and len(parsed2) > len(parsed or []):
+ cypher, table_md, parsed, truncated = (
+ cypher2, table_md2, parsed2, truncated2,
+ )
+
+ rows = len(parsed) if parsed else 0
+ ctx.log(f"CypherFirst cypher_table: {rows} rows from "
+ f"{len(candidates)} candidates")
+
+ if cypher and parsed:
+ directive = ""
+ if is_which_list(query):
+ directive = (
+ "\nNOTE: This is a 'which / list' question. Enumerate "
+ "EVERY DISTINCT VALUE from the first column in your "
+ "answer — do not summarize, truncate, or pick a "
+ "subset unless the question explicitly asked for the "
+ "top/most/fewest one."
+ )
+ body = table_md + directive
+ return RawSearchResult(
+ records=[{
+ "section": "cypher_results",
+ "content": _wrap_authoritative(body),
+ }],
+ metadata={
+ "strategy": "cypher_first",
+ "cypher_first_path": PATH_CYPHER_TABLE,
+ "cypher": cypher,
+ "cypher_rows": rows,
+ "cypher_truncated": truncated,
+ },
+ )
+
+ # Cypher returned 0 / no candidate succeeded.
+ if is_yes_no(query) and _is_negation_existential(query):
+ ctx.log("CypherFirst empty-result branch: negation-existential = No")
+ return RawSearchResult(
+ records=[{
+ "section": "cypher_results",
+ "content": _wrap_authoritative(
+ "No matching items: the cypher query returned 0 "
+ "rows. For a negation-existential question of "
+ "this shape, that means no such items exist.",
+ source_note="Cypher returned 0 rows (definitive)",
+ ),
+ }],
+ metadata={
+ "strategy": "cypher_first",
+ "cypher_first_path": PATH_NEGATION_EMPTY_NO,
+ "cypher": cypher,
+ },
+ )
+
+ # Vector fallback for everything else.
+ ctx.log("CypherFirst cypher empty — falling back to RAG")
+ return _tag_path(
+ await self._s._fallback._execute(query, ctx),
+ PATH_RAG_FALLBACK_CYPHER_EMPTY,
+ )
+
    async def _generate_k_candidates(
        self,
        query: str,
        *,
        with_desc_hint: bool = False,
    ) -> list[str]:
        """Sample up to ``self._s._k`` candidate Cypher queries in parallel.

        Each sample prompts the LLM with the schema prompt plus the
        aggregation suffix (and, optionally, the description hint),
        extracts the fenced cypher, validates it, and sanitizes it.
        A sample that fails any step is dropped — there is no per-sample
        retry; diversity comes from drawing ``self._s._k`` samples.

        Args:
            query: The user's natural-language question.
            with_desc_hint: When ``True``, append ``_DESC_HINT_SUFFIX`` to
                the prompt (used by the caller's second pass when the first
                pass returned too few rows for a list-style question).

        Returns:
            De-duplicated candidate cypher strings in first-seen order;
            may be empty if every sample failed.
        """
        prompt = SCHEMA_PROMPT.format(question=query) + _AGG_SCHEMA_SUFFIX
        if with_desc_hint:
            prompt += _DESC_HINT_SUFFIX

        async def _one() -> str | None:
            # One independent LLM sample. Any failure — LLM error, no
            # cypher block in the response, validation errors — yields
            # None so the enclosing gather() never raises.
            try:
                resp = await self._s._llm.ainvoke(prompt)
                cypher = extract_cypher(resp.content)
                if not cypher:
                    return None
                errors = validate_cypher(cypher)
                if errors:
                    return None
                return _sanitize_cypher(cypher)
            except Exception as exc:
                logger.debug("Candidate generation failed: %s", exc)
                return None

        results = await asyncio.gather(*[_one() for _ in range(self._s._k)])
        # Dedupe while preserving order.
        seen: set[str] = set()
        out: list[str] = []
        for c in results:
            if c and c not in seen:
                seen.add(c)
                out.append(c)
        return out
+
    async def _execute_and_pick(
        self,
        candidates: list[str],
    ) -> tuple[str | None, str, list[dict[str, Any]], bool]:
        """Run all candidates in parallel; pick the one with most rows.

        Ties on row count are broken by column count of the first row.
        Candidates that raise during execution are silently discarded;
        a candidate returning 0 rows still competes (with score 0).

        Args:
            candidates: Sanitized cypher strings to execute.

        Returns ``(cypher, table_md, parsed_rows, truncated)``. ``cypher``
        is ``None`` — with a placeholder string in the table slot — when
        there were no candidates or none executed successfully.
        """
        if not candidates:
            return None, "(no candidate cypher)", [], False
        results = await asyncio.gather(
            *[self._s._graph.query_raw(c) for c in candidates],
            return_exceptions=True,
        )
        scored: list[tuple[int, int, str, Any]] = []
        for cypher, res in zip(candidates, results):
            if isinstance(res, BaseException):
                # Execution failure for this candidate only — drop it and
                # let the surviving candidates compete.
                continue
            rows = len(res.result_set) if res.result_set else 0
            cols = len(res.result_set[0]) if (res.result_set and res.result_set[0]) else 0
            scored.append((rows, cols, cypher, res))
        if not scored:
            return None, "(no candidate executed successfully)", [], False
        # Sort descending on (rows, cols). The key deliberately ignores
        # slots 2-3 so the non-comparable result objects never get compared.
        scored.sort(key=lambda t: (t[0], t[1]), reverse=True)
        _, _, best_cypher, best_result = scored[0]
        table_md, parsed, truncated = format_result_as_markdown_table(best_result)
        return best_cypher, table_md, parsed, truncated
+
+
+# ─────────────────────────────────────────────────────────────────
+# Strategy
+# ─────────────────────────────────────────────────────────────────
+
class CypherFirstAggregationStrategy(RetrievalStrategy):
    """Aggregation-aware retrieval strategy.

    Routes each question by detected intent:

    - ``"numeric_math"`` — RETURN raw values via Cypher, then compute the
      ``average``/``sum``/``median`` in Python. Avoids LLM-arithmetic errors.
    - ``"aggregation"`` — multi-candidate Cypher, pick the highest-row-count
      result, render as a markdown table with column headers. Optionally
      run a description+chunk-text fuzzy hybrid for "shared X" / "BOTH A
      and B" shapes before falling back to vector retrieval on empty.
    - ``"rag"`` — delegate to ``rag_fallback`` (default: ``MultiPathRetrieval``).

    Safe as the top-level strategy on ``GraphRAG``: non-aggregation
    questions get the existing pipeline unchanged.

    Every returned :class:`RawSearchResult` carries a ``cypher_first_path``
    metadata key whose value is one of ``PATH_*`` module constants — useful
    for operator dashboards that want to see which sub-path fired.

    Assumptions and known limits
    ----------------------------
    The shared-property hybrid (M5) was tuned on graphs produced by the
    SDK's default ``GraphExtraction`` pipeline. It makes the following
    assumptions; when they're violated, the hybrid silently returns
    ``None`` and the strategy falls back to the multi-candidate Cypher
    path (which still works, just without free-text recovery):

    - **Graph topology.** Organizations are connected to Persons via
      ``[:RELATES]``, and Persons are connected to ``Chunk`` nodes via
      ``[:MENTIONED_IN]``. This is the canonical shape the SDK builds.
      Custom extractors that use different edge types or skip chunk
      provenance will not benefit from M5 — the strategy logs a warning
      and continues.
    - **Prose shape.** Role and project values are extracted by regex
      from Person descriptions and chunk text. The regexes target the
      phrasing patterns ``"works at X as a "`` and ``"contributes
      to "``. Domains whose prose departs from these patterns
      (medical / legal / e-commerce / non-English) will not match — the
      result is empty role/project sets, not wrong answers, but the
      hybrid won't help.
    - **Role vocabulary.** The role extractor anchors on the suffixes
      ``engineer | scientist | manager | architect | researcher |
      developer | analyst | specialist | designer | consultant``. Other
      job titles ("director", "VP", "lead") won't match.

    Accuracy ceiling
    ----------------
    The strategy faithfully returns what is in the graph. Duplicate
    entities ("Wayne En" vs "Wayne Enterprises"), chunk-boundary-truncated
    names, or non-deduplicated short-form references ("Carla" vs "Carla
    Okafor") all flow through into the answer. Cypher counts will be
    inflated; "which X" lists will contain duplicates. These are
    extraction-quality issues — not strategy bugs — and should be
    addressed in the ingestion pipeline (resolver, coref, dedup).

    Args:
        graph_store: Required for Cypher execution.
        vector_store: Required for the RAG fallback.
        embedder: Required for the RAG fallback.
        llm: Required for Cypher generation + synthesis.
        k_candidates: Number of parallel Cypher samples per aggregation
            question. Default 3 — enough to surface alternate structural
            interpretations without burning latency. Clamped to >= 1.
        rag_fallback: Strategy used for non-aggregation intent. If
            ``None``, a fresh ``MultiPathRetrieval`` is constructed
            internally.
        phrase_extractor: Pluggable extractor for the shared-property
            hybrid path. If ``None``, a ``DefaultPhraseExtractor`` (the
            English-prose vocabulary described above) is used.

    Example::

        strategy = CypherFirstAggregationStrategy(
            graph_store=rag._graph_store,
            vector_store=rag._vector_store,
            embedder=embedder,
            llm=llm,
        )
        async with GraphRAG(
            connection=conn, llm=llm, embedder=embedder,
            embedding_dimension=256,
            retrieval_strategy=strategy,
        ) as rag:
            ...
    """

    def __init__(
        self,
        graph_store: Any,
        vector_store: Any,
        embedder: Embedder,
        llm: LLMInterface,
        *,
        k_candidates: int = 3,
        rag_fallback: RetrievalStrategy | None = None,
        phrase_extractor: PhraseExtractor | None = None,
    ) -> None:
        super().__init__(graph_store=graph_store, vector_store=vector_store)
        self._embedder = embedder
        self._llm = llm
        # Clamp to at least one sample — k_candidates <= 0 would otherwise
        # silently disable candidate generation.
        self._k = max(1, k_candidates)
        self._fallback = rag_fallback or MultiPathRetrieval(
            graph_store=graph_store,
            vector_store=vector_store,
            embedder=embedder,
            llm=llm,
        )
        # Pluggable phrase extractor for the shared-property hybrid path.
        # Override with a domain-specific subclass to recognize roles /
        # projects beyond the default English-prose vocabulary.
        self._phrase_extractor = phrase_extractor or DefaultPhraseExtractor()
        # Sub-paths — each handles one shape of question. The order in
        # which they're consulted is encoded in ``_execute`` below; the
        # paths themselves don't know about each other.
        self._rag_path = _RagDelegationPath(self)
        self._numeric_path = _NumericMathPath(self)
        self._hybrid_path = _SharedPropertyHybridPath(self)
        self._cypher_table_path = _MultiCandidateCypherPath(self)

    # -- Template Method hook -------------------------------------

    async def _execute(
        self,
        query: str,
        ctx: Context,
        **kwargs: Any,
    ) -> RawSearchResult:
        """Dispatch *query* to the sub-path matching its detected intent.

        See the class docstring for the intent -> path map.
        """
        intent = detect_aggregation_intent(query)
        ctx.log(f"CypherFirst intent={intent}")

        if intent == "rag":
            # Non-aggregation questions don't benefit from any of the
            # cypher-first mechanics — hand straight to the fallback.
            return await self._rag_path.maybe_handle(query, ctx)

        if intent == "numeric_math":
            return await self._numeric_path.maybe_handle(query, ctx)

        # intent == "aggregation":
        # Try the description+chunk hybrid for "shared X" shapes first —
        # when it fires it generally produces better answers than the
        # multi-candidate cypher because chunks preserve original corpus
        # phrasing that extraction may have summarized away.
        hybrid_result = await self._hybrid_path.maybe_handle(query, ctx)
        if hybrid_result is not None:
            return hybrid_result

        # The multi-candidate cypher path always returns a result; it
        # internally handles its own empty / negation / fallback branches.
        return await self._cypher_table_path.maybe_handle(query, ctx)

    # NOTE: the per-path logic (numeric math, shared-property hybrid,
    # multi-candidate cypher, etc.) lives in the ``_AggregationPath``
    # subclasses above this strategy. Keeping each path in its own class
    # keeps the dispatch above readable and makes it trivial to swap one
    # implementation out (e.g., a medical-prose phrase extractor) without
    # touching the strategy itself.

    # -- Custom _format ------------------------------------------

    def _format(self, raw: RawSearchResult) -> RetrieverResult:
        """Render section records as markdown content items, preserving the
        cypher metadata so callers / metrics can see which mode fired."""
        items: list[RetrieverResultItem] = []
        for record in raw.records:
            # Records are normally dicts carrying "section"/"content";
            # tolerate stray non-dict records by stringifying them.
            content = record.get("content", "") if isinstance(record, dict) else str(record)
            section = record.get("section", "") if isinstance(record, dict) else ""
            if content:
                items.append(
                    RetrieverResultItem(
                        content=content,
                        metadata={"section": section},
                    )
                )
        return RetrieverResult(items=items, metadata=raw.metadata)
+
+
+
+def _coerce_number(cell: Any) -> float | None:
+ """Extract a single numeric value from a cypher result cell.
+
+ Accepts ints/floats directly; for strings, pulls the first integer or
+ float substring (catches "1995" inside "1995 is the year ...").
+ """
+ if cell is None:
+ return None
+ if isinstance(cell, (int, float)):
+ return float(cell)
+ m = re.search(r"-?\d+(?:\.\d+)?", str(cell))
+ return float(m.group()) if m else None
diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py
index 0dd3d28..5d4232b 100644
--- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py
+++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py
@@ -41,6 +41,30 @@
re.IGNORECASE,
)
# Default row cap auto-injected when the LLM's query lacks a LIMIT.
# Pure aggregations (count/sum/avg over no group-by) return one row, so
# they skip injection entirely; group-by lists need enough rows that a 10-org
# breakdown isn't truncated.
_DEFAULT_ROW_LIMIT = 100

# Aggregate function names recognized by ``_is_pure_aggregation``. A RETURN
# projection counts as "aggregate" iff it STARTS with one of these names
# followed by "(" — see ``_AGG_FN_RE`` below (anchored with ^, so nested
# aggregate calls deeper in an expression don't match).
_AGG_FN_NAMES = ("count", "sum", "avg", "min", "max", "collect",
                 "stdev", "percentileCont", "percentileDisc")
_AGG_FN_RE = re.compile(
    r"^\s*(?:" + "|".join(_AGG_FN_NAMES) + r")\s*\(",
    re.IGNORECASE,
)

# Detects FUNCTION-style calls under a dotted namespace, e.g.
# ``apoc.text.regexGroups(...)``, ``gds.shortest.path(...)``, ``db.idx.fulltext.queryNodes(...)``.
# FalkorDB does not implement APOC/GDS/db plugins, so any dotted-namespace
# function call silently returns 0 rows at execution. We reject these in the
# validator so the existing retry-with-feedback loop can correct the query.
# Note: ``\bCALL\b`` already catches procedure-style invocations; this regex
# specifically targets the function-style pattern that slips through.
# NOTE(review): the pattern matches ANY ``ident.ident(`` sequence, so a
# hypothetical property-then-call expression would also be rejected —
# presumably fine for generated read queries, but confirm if false
# positives show up in validator logs.
_DOTTED_FN_RE = re.compile(
    r"\b([a-zA-Z_]\w*)\.([a-zA-Z_][\w.]*)\s*\(",
)
+
# ── Schema prompt ────────────────────────────────────────────────
SCHEMA_PROMPT = """\
@@ -81,6 +105,8 @@
- To find connections: `MATCH (a)-[:RELATES]-(b)` with entity name filters
- To count: `RETURN count(DISTINCT e)` or `RETURN count(r)`
- To list all of a type: `MATCH (e:Technology) RETURN e.name, e.description LIMIT 25`
+- For "BOTH X AND Y" questions, use two MATCH clauses sharing the same
+ variable to express set intersection — never UNION.
## Examples
@@ -122,11 +148,23 @@
LIMIT 20
```
-Question: "What organizations are related to the technology?"
+Question: "Which city has the most employees mentioned?"
+Note: a Person and a Location are typically NOT directly connected — they
+share an organization. Use a 2-hop traversal through the intermediary so the
+group-by works on the real graph topology:
```cypher
-MATCH (o:Organization)-[r:RELATES]-(t:Technology)
-RETURN o.name AS organization, t.name AS technology, r.rel_type AS relation, r.fact AS evidence
-LIMIT 20
+MATCH (p:Person)-[:RELATES]-(o:Organization)-[:RELATES]-(l:Location)
+RETURN l.name AS city, count(DISTINCT p) AS employee_count
+ORDER BY employee_count DESC
+LIMIT 5
+```
+
+Question: "Who works at BOTH Acme and Globex?"
+```cypher
+MATCH (p:Person)-[:RELATES]-(o1:Organization),
+ (p)-[:RELATES]-(o2:Organization)
+WHERE o1.name CONTAINS 'Acme' AND o2.name CONTAINS 'Globex'
+RETURN DISTINCT p.name AS person
```
## Your task
@@ -159,6 +197,53 @@ def extract_cypher(text: str) -> str:
# ── Cypher sanitization ─────────────────────────────────────────
+def _split_top_level_commas(s: str) -> list[str]:
+ """Split on commas that aren't inside parentheses/brackets/braces."""
+ out: list[str] = []
+ buf: list[str] = []
+ depth = 0
+ for ch in s:
+ if ch in "([{":
+ depth += 1
+ buf.append(ch)
+ elif ch in ")]}":
+ depth = max(0, depth - 1)
+ buf.append(ch)
+ elif ch == "," and depth == 0:
+ out.append("".join(buf).strip())
+ buf = []
+ else:
+ buf.append(ch)
+ if buf:
+ out.append("".join(buf).strip())
+ return [p for p in out if p]
+
+
def _is_pure_aggregation(cypher: str) -> bool:
    """True iff the FINAL RETURN clause projects only aggregate functions.

    Pure aggregations (e.g., ``RETURN count(p)``) always return exactly one
    row, so auto-injecting LIMIT is a no-op. Group-by patterns
    (``RETURN o.name, count(p)``) are NOT pure — at least one projection is
    a non-aggregate dimension that the LIMIT would actually apply to.
    """
    # Capture every RETURN body (stopping at ORDER BY / SKIP / LIMIT / ";" /
    # end of string) and keep only the last one — that's the final clause.
    bodies = re.findall(
        r"\bRETURN\b\s+(.+?)(?=\bORDER\s+BY\b|\bSKIP\b|\bLIMIT\b|;|$)",
        cypher,
        re.IGNORECASE | re.DOTALL,
    )
    if not bodies:
        return False
    final_body = bodies[-1].strip()
    if not final_body:
        return False
    projections = _split_top_level_commas(final_body)
    return bool(projections) and all(
        _AGG_FN_RE.match(item) is not None for item in projections
    )
+
+
def _sanitize_cypher(cypher: str) -> str:
"""Fix common LLM-generated Cypher issues before execution.
@@ -174,9 +259,13 @@ def _sanitize_cypher(cypher: str) -> str:
# Remove path variable assignments: "path = MATCH" -> "MATCH"
cypher = re.sub(r"\bpath\s*=\s*", "", cypher, flags=re.IGNORECASE)
- # Add LIMIT if missing (prevent runaway scans)
+ # Inject LIMIT only when the LLM didn't provide one AND the query isn't a
+ # single-row pure aggregation. Pure aggregations don't benefit from a cap;
+ # group-by lists do, but the previous default (25) was too small to fit a
+ # full 10-org breakdown. Skip on aggregations to avoid a misleading no-op.
if not re.search(r"\bLIMIT\b", cypher, re.IGNORECASE):
- cypher = cypher.rstrip().rstrip(";") + "\nLIMIT 25"
+ if not _is_pure_aggregation(cypher):
+ cypher = cypher.rstrip().rstrip(";") + f"\nLIMIT {_DEFAULT_ROW_LIMIT}"
return cypher
@@ -220,6 +309,18 @@ def validate_cypher(cypher: str) -> list[str]:
if re.search(r"\bLOAD\s+CSV\b", cypher_norm, re.IGNORECASE):
errors.append("LOAD CSV is not allowed in generated queries")
+ # Reject dotted-namespace function calls (apoc.*, gds.*, db.*).
+ # FalkorDB doesn't implement these plugins; the call silently returns 0
+ # rows at execution. Surfacing it here lets the retry loop regenerate.
+ for ns, _ in _DOTTED_FN_RE.findall(cypher_norm):
+ errors.append(
+ f"Unsupported function namespace '{ns}.*' "
+ "(FalkorDB does not implement APOC/GDS/db plugin functions). "
+ "Use only built-in Cypher functions like count, sum, avg, "
+ "labels, toInteger, substring, etc."
+ )
+ break # one error is enough — the LLM only needs to fix the pattern
+
# No write operations
if _WRITE_KEYWORDS.search(cypher_norm):
errors.append("Write operation detected — query must be read-only")
@@ -287,43 +388,52 @@ async def generate_cypher(
return None
-async def execute_cypher_retrieval(
- graph_store: Any,
- llm: Any,
- question: str,
- *,
- max_retries: int = 3,
-) -> tuple[list[str], dict[str, dict]]:
- """Full text-to-cypher retrieval: generate -> validate -> execute -> parse.
+# Matches a typed entity label inside a node pattern so we can widen it to
+# ``__Entity__`` on a 0-row retry. Captures the prefix (open paren + optional
+# variable + colon) so ``re.sub`` keeps the surrounding shape intact.
+_TYPED_NODE_LABEL_RE = re.compile(
+ r"(\(\s*\w*\s*:)(" + "|".join(re.escape(l) for l in _ENTITY_LABELS) + r")\b"
+)
- Results are intended to go DIRECTLY to the final LLM context
- (as a dedicated "Cypher Query Results" section), NOT through
- the cosine reranker.
- Returns:
- fact_strings: Formatted rows from Cypher execution.
- entities: Dict of entity_id -> {name, description}.
def _widen_typed_labels(cypher: str) -> str:
    """Swap typed entity labels (``:Person`` etc.) inside node patterns to
    ``:__Entity__``. Used when a typed-label query returned 0 rows because
    the extractor labelled the entity differently than the LLM expected.

    Only labels baked into ``_TYPED_NODE_LABEL_RE`` (the ``_ENTITY_LABELS``
    set) are rewritten; relationship types and property names are untouched.
    The ``\\1`` backreference preserves the "(var:" prefix captured by the
    pattern, so the surrounding node-pattern shape is kept intact.
    """
    return _TYPED_NODE_LABEL_RE.sub(r"\1__Entity__", cypher)
- On any failure, returns empty results (silent degradation).
- """
- cypher = await generate_cypher(llm, question, max_retries=max_retries)
- if not cypher:
- return [], {}
- try:
- result = await graph_store.query_raw(cypher)
- except Exception as exc:
- logger.debug("Cypher execution failed: %s — query: %s", exc, cypher)
- return [], {}
def _should_widen_labels(cypher: str) -> bool:
    """Gate for the 0-row label-widen fallback.

    Skip widening when the typed label IS the filter — i.e. the RETURN
    aggregates over a labeled variable AND the query has no ``WHERE … CONTAINS``
    name predicate. In that case the user is asking "how many Persons?" and
    widening would change the semantics. Otherwise (typical case: typed label
    present alongside a name predicate or non-aggregate RETURN), widen.

    Conservative: when in doubt we skip the fallback rather than risk a
    semantically-different query.
    """
    if not _TYPED_NODE_LABEL_RE.search(cypher):
        return False  # nothing to widen
    # Any "WHERE … CONTAINS" predicate means a name filter is doing the real
    # selection, so relaxing the label is safe.
    has_contains_filter = bool(
        re.search(r"\bWHERE\b.*\bCONTAINS\b", cypher, re.IGNORECASE | re.DOTALL)
    )
    if has_contains_filter:
        return True
    # No name filter — check whether the RETURN is an aggregate over a labeled
    # variable. If so, widening turns "count Persons" into "count Entities".
    if _is_pure_aggregation(cypher):
        return False
    return True
- # Parse results into readable fact lines and entity dict
+
+def _parse_cypher_result_set(result_set: Any) -> tuple[list[str], dict[str, dict]]:
+ """Turn a FalkorDB result_set into (fact_strings, entities)."""
fact_strings: list[str] = []
entities: dict[str, dict] = {}
-
- for row in result.result_set:
+ for row in result_set:
parts: list[str] = []
for val in row:
if val is None:
@@ -333,11 +443,7 @@ async def execute_cypher_retrieval(
parts.append(s)
if not parts:
continue
-
- line = " | ".join(parts)
- fact_strings.append(line)
-
- # Extract entity names (strings that look like names, not numbers/lists)
+ fact_strings.append(" | ".join(parts))
for val in row:
if (
isinstance(val, str)
@@ -349,11 +455,75 @@ async def execute_cypher_retrieval(
eid = val.strip().lower().replace(" ", "_")
if eid and eid not in entities:
entities[eid] = {"name": val.strip(), "description": ""}
+ return fact_strings, entities
+
+
async def execute_cypher_retrieval(
    graph_store: Any,
    llm: Any,
    question: str,
    *,
    max_retries: int = 3,
) -> tuple[list[str], dict[str, dict], dict[str, Any]]:
    """Full text-to-cypher retrieval: generate -> validate -> execute -> parse.

    Results are intended to go DIRECTLY to the final LLM context
    (as a dedicated "Cypher Query Results" section), NOT through
    the cosine reranker.

    Args:
        graph_store: Store exposing ``query_raw`` for Cypher execution.
        llm: LLM used by ``generate_cypher`` for query generation.
        question: The user's natural-language question.
        max_retries: Passed through to ``generate_cypher``'s
            regenerate-with-feedback loop.

    Returns:
        fact_strings: Formatted rows from Cypher execution.
        entities: Dict of entity_id -> {name, description}.
        metadata: Dict capturing diagnostic signal — ``cypher`` (final query
            executed), ``cypher_rows`` (row count), ``cypher_fallback``
            (``"label_widened"`` if the 0-row fallback fired, else ``None``).

    On any failure, returns ``([], {}, metadata)`` — never raises.
    """
    metadata: dict[str, Any] = {
        "cypher": None,
        "cypher_rows": 0,
        "cypher_fallback": None,
    }

    cypher = await generate_cypher(llm, question, max_retries=max_retries)
    if not cypher:
        return [], {}, metadata
    metadata["cypher"] = cypher

    try:
        result = await graph_store.query_raw(cypher)
    except Exception as exc:
        logger.debug("Cypher execution failed: %s — query: %s", exc, cypher)
        return [], {}, metadata

    # 0-row recovery: try once with typed labels widened to __Entity__.
    # The LLM's structural reasoning (joins, filters, aggregations) is preserved
    # — only the label predicate is relaxed. No second LLM call.
    if not result.result_set and _should_widen_labels(cypher):
        widened = _widen_typed_labels(cypher)
        if widened != cypher:
            logger.debug("Cypher 0-row retry with widened labels: %s", widened[:120])
            try:
                widened_result = await graph_store.query_raw(widened)
            except Exception as exc:
                # Widened attempt failing is non-fatal — keep the original
                # (empty) result and fall through.
                logger.debug("Widened cypher execution failed: %s", exc)
            else:
                if widened_result.result_set:
                    result = widened_result
                    metadata["cypher"] = widened
                    metadata["cypher_fallback"] = "label_widened"

    if not result.result_set:
        return [], {}, metadata

    fact_strings, entities = _parse_cypher_result_set(result.result_set)
    metadata["cypher_rows"] = len(fact_strings)
    logger.debug(
        "Cypher retrieval: %d facts, %d entities from: %s",
        len(fact_strings),
        len(entities),
        (metadata["cypher"] or "")[:120],
    )
    return fact_strings, entities, metadata
diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py
index d7bcbd9..cc6e0b7 100644
--- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py
+++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py
@@ -194,6 +194,7 @@ async def _execute(
query_vector = await self._embedder.aembed_query(query)
# 3. RELATES vector search + Text-to-Cypher (parallel when enabled)
+ cypher_metadata: dict[str, Any] = {}
if self._enable_cypher:
results = await asyncio.gather(
search_relates_edges(self._vector, query_vector, self._rel_top_k),
@@ -205,11 +206,11 @@ async def _execute(
fact_strings_scored, rel_entities = [], {}
else:
fact_strings_scored, rel_entities = results[0]
- # Unpack Cypher results
+ # Unpack Cypher results (3-tuple: facts, entities, metadata)
cypher_facts: list[str] = []
cypher_entities: dict[str, dict] = {}
if not isinstance(results[1], BaseException):
- cypher_facts, cypher_entities = results[1]
+ cypher_facts, cypher_entities, cypher_metadata = results[1]
else:
fact_strings_scored, rel_entities = await search_relates_edges(
self._vector, query_vector, self._rel_top_k
@@ -303,6 +304,7 @@ async def _execute(
source_passages,
q_type_hint,
cypher_results=cypher_facts if cypher_facts else None,
+ cypher_metadata=cypher_metadata or None,
)
def _format(self, raw: RawSearchResult) -> RetrieverResult:
diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py
index 6b2ad9a..d6742a2 100644
--- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py
+++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py
@@ -13,6 +13,12 @@
logger = logging.getLogger(__name__)
+# Maximum cypher result rows passed to the final LLM context. Was 20 — too low
+# for group-by lists across the whole graph (e.g. a 10-org breakdown was being
+# silently truncated). Pure aggregations return one row anyway, so a higher cap
+# is essentially free.
+_CYPHER_RESULT_CAP = 100
+
def cosine_sim(a: list[float], b: list[float]) -> float:
"""Cosine similarity between two float vectors."""
@@ -145,11 +151,16 @@ def assemble_raw_result(
source_passages: list[str],
q_type_hint: str = "",
cypher_results: list[str] | None = None,
+ cypher_metadata: dict[str, Any] | None = None,
) -> RawSearchResult:
"""Build structured RawSearchResult with section records.
``cypher_results`` are placed in their own section and are NOT
subject to cosine reranking — they go directly to the final LLM.
+
+ ``cypher_metadata`` (if non-empty) is merged into the result metadata
+ so callers and metrics can see whether the 0-row fallback fired,
+ whether truncation occurred, and the final cypher that ran.
"""
records: list[dict[str, Any]] = []
@@ -162,13 +173,28 @@ def assemble_raw_result(
}
)
- # Cypher Query Results (direct to LLM — not reranked)
+ # Authoritative Graph Query Results (direct to LLM — not reranked).
+ # Heading is deliberately worded to signal authority over prose passages on
+ # quantitative questions; pair with system-prompt rule 8 in api/main.py.
+ truncated = False
if cypher_results:
+ shown = cypher_results[:_CYPHER_RESULT_CAP]
+ body_lines = [f"- {r}" for r in shown]
+ if len(cypher_results) > _CYPHER_RESULT_CAP:
+ truncated = True
+ body_lines.append(
+ f"- … (showing {len(shown)} of {len(cypher_results)} rows; "
+ "result was truncated)"
+ )
records.append(
{
"section": "cypher_results",
- "content": "## Graph Query Results\n"
- + "\n".join(f"- {r}" for r in cypher_results[:20]),
+ "content": (
+ "## Authoritative Graph Query Results "
+ "(deterministic; trust over passages on counts and aggregates)\n"
+ "Source: text-to-Cypher run against the knowledge graph.\n"
+ + "\n".join(body_lines)
+ ),
}
)
@@ -218,7 +244,18 @@ def assemble_raw_result(
}
)
+ metadata: dict[str, Any] = {"strategy": "multi_path"}
+ if cypher_metadata:
+ # Surface fallback firing rate, truncation, and the final cypher to
+ # callers / metrics. Don't overwrite if multiple keys conflict — the
+ # caller-supplied values win since they reflect the actual execution.
+ metadata.update(
+ {f"cypher_{k}" if not k.startswith("cypher") else k: v
+ for k, v in cypher_metadata.items()}
+ )
+ metadata["cypher_truncated"] = truncated
+
return RawSearchResult(
records=records,
- metadata={"strategy": "multi_path"},
+ metadata=metadata,
)
diff --git a/graphrag_sdk/tests/test_cypher_first.py b/graphrag_sdk/tests/test_cypher_first.py
new file mode 100644
index 0000000..412529f
--- /dev/null
+++ b/graphrag_sdk/tests/test_cypher_first.py
@@ -0,0 +1,669 @@
+"""Tests for the cypher_first strategy.
+
+Pure-Python helpers (intent classifier, regex extractors, fuzzy intersect,
+table formatter, numeric coercion) are exercised directly; the strategy
+class itself, which needs a graph + LLM, is driven with mocks further down.
+"""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+from graphrag_sdk.core.models import RawSearchResult
+from graphrag_sdk.retrieval.strategies.cypher_first import (
+ PATH_CYPHER_TABLE,
+ PATH_NEGATION_EMPTY_NO,
+ PATH_NUMERIC_MATH,
+ PATH_RAG_FALLBACK,
+ PATH_RAG_FALLBACK_CYPHER_EMPTY,
+ PATH_RAG_FALLBACK_NUMERIC_FAIL,
+ PATH_SHARED_PROPERTY_HYBRID,
+ _coerce_number,
+ _detect_property_kind,
+ _extract_phrases,
+ _extract_projects,
+ _extract_roles,
+ _fuzzy_intersect,
+ _is_negation_existential,
+ _tag_path,
+ detect_aggregation_intent,
+ format_result_as_markdown_table,
+ is_which_list,
+ is_yes_no,
+)
+
+# ── Intent classifier ────────────────────────────────────────────
+
+
+class TestDetectIntent:
+ def test_count_questions_are_aggregation(self):
+ for q in [
+ "How many people work at Acme Corp?",
+ "How many distinct organizations are mentioned?",
+ "Count of employees by org?",
+ ]:
+ assert detect_aggregation_intent(q) == "aggregation", q
+
+ def test_which_most_is_aggregation(self):
+ for q in [
+ "Which city has the most employees mentioned?",
+ "Which organization has the fewest employees?",
+ "Which orgs share a project with Acme?",
+ ]:
+ assert detect_aggregation_intent(q) == "aggregation", q
+
+ def test_more_than_multi_word(self):
+ # The intent regex must catch up to 4 words between `more` and `than`.
+ q = "Does Acme Corp have more employees mentioned than Initech Systems?"
+ assert detect_aggregation_intent(q) == "aggregation"
+
+ def test_both_a_and_b_is_aggregation(self):
+ q = "Which job roles are held by employees at BOTH Acme Corp and Initech Systems?"
+ assert detect_aggregation_intent(q) == "aggregation"
+
+ def test_existential_is_aggregation(self):
+ for q in [
+ "Are there any organizations with no employees listed?",
+ "Is there any employee at Acme on observability tooling?",
+ ]:
+ assert detect_aggregation_intent(q) == "aggregation", q
+
+ def test_average_year_is_numeric_math(self):
+ q = "What is the average year of founding across all 10 organizations?"
+ assert detect_aggregation_intent(q) == "numeric_math"
+
+ def test_total_revenue_is_numeric_math(self):
+ assert detect_aggregation_intent("What is the total revenue?") == "numeric_math"
+
+ def test_factoid_is_rag(self):
+ for q in [
+ "Who is the lighthouse keeper?",
+ "Tell me about Acme Corp.",
+ "What did the professor discover?",
+ ]:
+ assert detect_aggregation_intent(q) == "rag", q
+
+
+class TestShapeDetectors:
+ def test_yes_no_starts(self):
+ assert is_yes_no("Is there any X?")
+ assert is_yes_no("Does Acme have more employees than Initech?")
+ assert is_yes_no("Are there any orgs without employees?")
+
+ def test_yes_no_negative(self):
+ assert not is_yes_no("Which X has the most Y?")
+ assert not is_yes_no("How many people?")
+
+ def test_which_list_starts(self):
+ assert is_which_list("Which cities have offices?")
+ assert is_which_list("List the orgs with 5 employees.")
+ assert is_which_list("Name the people at Acme.")
+
+ def test_which_list_negative(self):
+ assert not is_which_list("How many people work here?")
+ assert not is_which_list("Is there any employee at Acme?")
+
+ def test_negation_existential(self):
+ assert _is_negation_existential(
+ "Are there any organizations with NO employees listed?"
+ )
+ assert _is_negation_existential(
+ "Is there any company without an office?"
+ )
+ assert not _is_negation_existential(
+ "Is there any employee at Acme who works on observability?"
+ )
+ assert not _is_negation_existential("How many people work here?")
+
+
+# ── Property kind + extractors (M5 hybrid) ───────────────────────
+
+
+class TestPropertyKind:
+ def test_role_keywords(self):
+ for q in [
+ "What roles are at Acme?",
+ "Which jobs are common?",
+ "List the titles at Initech.",
+ ]:
+ assert _detect_property_kind(q) == "role", q
+
+ def test_project_keywords(self):
+ for q in [
+ "Who works on observability?",
+ "Which projects are shared?",
+ "What initiatives is Acme contributing to?",
+ ]:
+ assert _detect_property_kind(q) == "project", q
+
+ def test_other_returns_none(self):
+ assert _detect_property_kind("Where is Acme based?") is None
+ assert _detect_property_kind("How many people work there?") is None
+
+
+class TestRoleExtractor:
+ def test_basic_role_match(self):
+ desc = "Anna Reyes is a senior engineer at Acme Corp."
+ roles = _extract_roles(desc)
+ assert "senior engineer" in roles
+
+ def test_multiword_role(self):
+ desc = "Uma Patel is a site reliability engineer at Initech Systems."
+ assert "site reliability engineer" in _extract_roles(desc)
+
+ def test_applied_scientist(self):
+ desc = "Cyrus Doss is an applied scientist at Massive Dynamic."
+ assert "applied scientist" in _extract_roles(desc)
+
+ def test_no_role_no_match(self):
+ # No professional-suffix word → no extraction.
+ assert _extract_roles("Alice is based in Boston.") == set()
+
+ def test_short_phrase_filtered(self):
+        # 'engineer' alone (no qualifier) may or may not pass the length
+        # gate depending on the regex details, so we deliberately accept
+        # either outcome here and only pin down the type / no-crash contract.
+ result = _extract_roles("Bob works as an engineer.")
+ # accept either outcome; just ensure no crash
+ assert isinstance(result, set)
+
+
+class TestProjectExtractor:
+ def test_basic_project_match(self):
+ desc = (
+ "Anna Reyes is a senior engineer at Acme Corp, "
+ "contributing to the cross-region replication initiative "
+ "and active in the cloud infrastructure community."
+ )
+ projects = _extract_projects(desc)
+ assert any("cross-region replication initiative" in p for p in projects)
+
+ def test_contributes_to_variant(self):
+ desc = "Mira Jansen contributes to a migration to managed services."
+ projects = _extract_projects(desc)
+ assert any("migration to managed services" in p for p in projects)
+
+ def test_no_contributes_no_match(self):
+ assert _extract_projects("Pavel works at Wayne.") == set()
+
+
+class TestExtractPhrases:
+ def test_role_dispatch(self):
+ out = _extract_phrases("Anna is a senior engineer", "role")
+ assert "senior engineer" in out
+
+ def test_project_dispatch(self):
+ out = _extract_phrases("contributes to the pipeline rewrite", "project")
+ assert any("pipeline rewrite" in p for p in out)
+
+ def test_unknown_kind_empty(self):
+ assert _extract_phrases("anything", "city") == set()
+
+
+# ── Fuzzy intersect ──────────────────────────────────────────────
+
+
+class TestFuzzyIntersect:
+ def test_exact_match(self):
+ a = {"senior engineer", "data scientist"}
+ b = {"senior engineer", "product manager"}
+ assert _fuzzy_intersect(a, b) == {"senior engineer"}
+
+ def test_substring_in_one_direction(self):
+ # "pipeline rewrite" (substring) should match "next-generation
+ # pipeline rewrite" (superset).
+ a = {"pipeline rewrite"}
+ b = {"next-generation pipeline rewrite"}
+ assert _fuzzy_intersect(a, b) == {"pipeline rewrite"}
+
+ def test_substring_other_direction(self):
+ a = {"next-generation pipeline rewrite"}
+ b = {"pipeline rewrite"}
+ assert _fuzzy_intersect(a, b) == {"next-generation pipeline rewrite"}
+
+ def test_two_token_overlap_matches(self):
+ a = {"automated incident response tooling"}
+ b = {"incident response system tooling"}
+ # 2+ shared content tokens: "incident", "response", "tooling"
+ assert _fuzzy_intersect(a, b) == {"automated incident response tooling"}
+
+ def test_single_stopword_no_match(self):
+ # Sharing only "the" / "a" should not be enough.
+ a = {"the migration"}
+ b = {"a migration"}
+ # Both reduce to {"migration"} (1 token) after stopword filter — no match.
+ assert _fuzzy_intersect(a, b) == set()
+
+ def test_no_overlap(self):
+ assert _fuzzy_intersect({"alpha beta"}, {"gamma delta"}) == set()
+
+
+# ── Markdown table formatter ─────────────────────────────────────
+
+
+class TestFormatTable:
+ def _result(self, header, rows):
+ return SimpleNamespace(
+ header=[[1, name] for name in header],
+ result_set=rows,
+ )
+
+ def test_renders_with_headers(self):
+ r = self._result(["city", "n"], [["Boston", 10], ["Chicago", 8]])
+ md, parsed, truncated = format_result_as_markdown_table(r)
+ assert "city | n" in md
+ assert "Boston | 10" in md
+ assert "Chicago | 8" in md
+ assert not truncated
+ assert parsed == [{"city": "Boston", "n": 10}, {"city": "Chicago", "n": 8}]
+
+ def test_synthesizes_headers_when_missing(self):
+ r = SimpleNamespace(header=None, result_set=[["x", "y"]])
+ md, parsed, _ = format_result_as_markdown_table(r)
+ assert "col_0 | col_1" in md
+ assert parsed[0] == {"col_0": "x", "col_1": "y"}
+
+ def test_empty_result(self):
+ r = SimpleNamespace(header=[], result_set=[])
+ md, parsed, truncated = format_result_as_markdown_table(r)
+ assert md == "(empty result)"
+ assert parsed == []
+ assert not truncated
+
+ def test_truncation_sentinel(self):
+ rows = [["x", i] for i in range(150)]
+ r = self._result(["item", "n"], rows)
+ md, parsed, truncated = format_result_as_markdown_table(r, cap=50)
+ assert truncated
+ assert len(parsed) == 50
+ assert "showing 50 of 150 rows" in md
+
+ def test_null_and_list_cells(self):
+ r = self._result(["a", "b"], [[None, ["x", "y", "z"]]])
+ md, _, _ = format_result_as_markdown_table(r)
+ assert "(null)" in md
+ assert "x, y, z" in md
+
+
+# ── Numeric coercion (M6 helper) ─────────────────────────────────
+
+
+class TestCoerceNumber:
+ def test_int(self):
+ assert _coerce_number(1995) == 1995.0
+
+ def test_float(self):
+ assert _coerce_number(3.14) == pytest.approx(3.14)
+
+ def test_string_with_number(self):
+ # Date entity names are strings like "1995" — pull the integer.
+ assert _coerce_number("1995") == 1995.0
+
+ def test_string_with_embedded_number(self):
+ assert _coerce_number("1995 is the year Acme was founded") == 1995.0
+
+ def test_none(self):
+ assert _coerce_number(None) is None
+
+ def test_no_digits(self):
+ assert _coerce_number("hello") is None
+
+ def test_negative(self):
+ assert _coerce_number("-42") == -42.0
+
+
+# ── Path-tag contract (R2) ───────────────────────────────────────
+
+
+class TestPathTag:
+ """``_tag_path`` enforces the contract that every result emitted by
+ ``CypherFirstAggregationStrategy`` carries a ``cypher_first_path``
+ label, so operators can route metrics on which sub-path fired."""
+
+ def test_labels_match_known_paths(self):
+ # If we add a new path later, this list should grow in lockstep.
+ assert PATH_NUMERIC_MATH == "numeric_math"
+ assert PATH_SHARED_PROPERTY_HYBRID == "shared_property_hybrid"
+ assert PATH_CYPHER_TABLE == "cypher_table"
+ assert PATH_NEGATION_EMPTY_NO == "negation_empty_no"
+ assert PATH_RAG_FALLBACK == "rag_fallback"
+ assert PATH_RAG_FALLBACK_NUMERIC_FAIL == "rag_fallback_numeric_fail"
+ assert PATH_RAG_FALLBACK_CYPHER_EMPTY == "rag_fallback_cypher_empty"
+
+ def test_tag_path_adds_label_to_empty_metadata(self):
+ result = RawSearchResult(records=[{"section": "x", "content": "y"}],
+ metadata={})
+ out = _tag_path(result, PATH_CYPHER_TABLE)
+ assert out.metadata["cypher_first_path"] == PATH_CYPHER_TABLE
+ assert out.metadata["strategy"] == "cypher_first"
+
+ def test_tag_path_overwrites_existing_path_label(self):
+ # When a sub-path delegates to another that already tagged the
+ # result, the outer tag should win — it reflects the actual path
+ # taken from the caller's perspective.
+ result = RawSearchResult(
+ records=[],
+ metadata={"cypher_first_path": "earlier", "extra": 1},
+ )
+ out = _tag_path(result, PATH_RAG_FALLBACK)
+ assert out.metadata["cypher_first_path"] == PATH_RAG_FALLBACK
+ # Other metadata is preserved.
+ assert out.metadata["extra"] == 1
+
+ def test_tag_path_preserves_existing_strategy_label_if_set(self):
+ # If a delegated strategy already tagged itself (e.g. "multi_path"),
+ # don't clobber it — setdefault.
+ result = RawSearchResult(
+ records=[],
+ metadata={"strategy": "multi_path"},
+ )
+ out = _tag_path(result, PATH_RAG_FALLBACK)
+ assert out.metadata["strategy"] == "multi_path"
+ assert out.metadata["cypher_first_path"] == PATH_RAG_FALLBACK
+
+ def test_tag_path_returns_new_object_not_mutating_input(self):
+ original_meta = {"strategy": "multi_path"}
+ result = RawSearchResult(records=[], metadata=original_meta)
+ out = _tag_path(result, PATH_RAG_FALLBACK)
+ # Don't surprise callers by mutating their dict in place.
+ assert "cypher_first_path" not in original_meta
+ assert out.metadata is not original_meta
+
+
+# ── End-to-end routing with mocks (R4) ───────────────────────────
+
+
+class _FakeResult:
+ """Mimics FalkorDB's query_raw return: ``result_set`` rows + a ``header``
+ list of ``[type_int, name]`` pairs (we only use the name)."""
+ def __init__(self, header, result_set):
+ self.header = [[1, name] for name in header]
+ self.result_set = result_set
+
+
+class _FakeFallback:
+ """Stand-in for MultiPathRetrieval — records the call and returns a
+ well-formed RawSearchResult so we can verify delegation + tagging."""
+ def __init__(self, records=None, metadata=None):
+ self._records = records or [{"section": "passages", "content": "## ..."}]
+ self._metadata = metadata or {"strategy": "multi_path"}
+ self.calls = []
+
+ async def _execute(self, query, ctx, **kwargs):
+ self.calls.append(query)
+ return RawSearchResult(records=list(self._records),
+ metadata=dict(self._metadata))
+
+
+def _make_strategy(*, llm_responses=None, graph_results=None, fallback=None):
+ """Build a CypherFirstAggregationStrategy with stubbed LLM + graph.
+
+ ``llm_responses`` is a sequence of strings; each LLM call pops one.
+ ``graph_results`` is a sequence of _FakeResult objects (or exceptions
+ to raise); each ``query_raw`` call pops one.
+ """
+ from unittest.mock import AsyncMock, MagicMock
+
+ from graphrag_sdk.core.models import LLMResponse
+ from graphrag_sdk.retrieval.strategies.cypher_first import (
+ CypherFirstAggregationStrategy,
+ )
+
+ llm = MagicMock()
+ responses = list(llm_responses or [])
+ async def _ainvoke(_prompt, **_kw):
+ return LLMResponse(content=responses.pop(0) if responses else "")
+ llm.ainvoke = AsyncMock(side_effect=_ainvoke)
+
+ graph = MagicMock()
+ results = list(graph_results or [])
+ async def _query_raw(_cypher):
+ nxt = results.pop(0) if results else _FakeResult([], [])
+ if isinstance(nxt, BaseException):
+ raise nxt
+ return nxt
+ graph.query_raw = AsyncMock(side_effect=_query_raw)
+
+ return CypherFirstAggregationStrategy(
+ graph_store=graph,
+ vector_store=MagicMock(),
+ embedder=MagicMock(),
+ llm=llm,
+ k_candidates=1, # one candidate is enough for routing tests
+ rag_fallback=fallback or _FakeFallback(),
+ )
+
+
+class TestStrategyRouting:
+ """Behavioural tests: assert which sub-path fires for each intent +
+ graph-state combination by inspecting ``metadata.cypher_first_path``."""
+
+ async def test_rag_intent_delegates_to_fallback(self):
+ # "Who is the lighthouse keeper?" doesn't match any aggregation
+ # pattern — strategy must hand the query to the fallback verbatim
+ # and tag the result with PATH_RAG_FALLBACK.
+ from graphrag_sdk.core.context import Context
+ fallback = _FakeFallback()
+ strat = _make_strategy(fallback=fallback)
+ ctx = Context()
+ result = await strat._execute("Who is the lighthouse keeper?", ctx)
+ assert result.metadata["cypher_first_path"] == PATH_RAG_FALLBACK
+ assert fallback.calls == ["Who is the lighthouse keeper?"]
+
+ async def test_aggregation_with_cypher_rows_takes_cypher_table(self):
+ # Multi-candidate cypher returns a non-empty table → strategy
+ # picks the table path and emits a single "cypher_results" item.
+ from graphrag_sdk.core.context import Context
+ cypher_code = (
+ "```cypher\n"
+ "MATCH (p:Person)-[:RELATES]-(o:Organization)\n"
+ "RETURN o.name AS org, count(DISTINCT p) AS n\n"
+ "ORDER BY n DESC LIMIT 5\n"
+ "```"
+ )
+ strat = _make_strategy(
+ llm_responses=[cypher_code],
+ graph_results=[
+ _FakeResult(["org", "n"], [["Acme", 10], ["Globex", 8]])
+ ],
+ )
+ result = await strat._execute(
+ "Which org has the most employees mentioned?", Context(),
+ )
+ assert result.metadata["cypher_first_path"] == PATH_CYPHER_TABLE
+ assert result.metadata["cypher_rows"] == 2
+ assert "Acme | 10" in result.records[0]["content"]
+
+ async def test_numeric_intent_does_python_arithmetic(self):
+ # Question asks for an average; cypher returns raw year values;
+ # strategy computes 1989.9 in Python and tags PATH_NUMERIC_MATH.
+ from graphrag_sdk.core.context import Context
+ cypher_code = (
+ "```cypher\nMATCH (d:Date) RETURN d.name AS year\n```"
+ )
+ years = [[str(y)] for y in
+ [1939, 1968, 1973, 1984, 1995, 1998, 2003, 2009, 2011, 2019]]
+ strat = _make_strategy(
+ llm_responses=[cypher_code],
+ graph_results=[_FakeResult(["year"], years)],
+ )
+ result = await strat._execute(
+ "What is the average year of founding across all 10 organizations?",
+ Context(),
+ )
+ assert result.metadata["cypher_first_path"] == PATH_NUMERIC_MATH
+ assert result.metadata["op"] == "average"
+ assert result.metadata["value"] == pytest.approx(1989.9, abs=0.1)
+
+ async def test_numeric_empty_extraction_falls_back_to_rag(self):
+ # Cypher generated for the numeric path returns 0 numeric values
+ # → strategy falls through to the RAG fallback and tags
+ # PATH_RAG_FALLBACK_NUMERIC_FAIL.
+ from graphrag_sdk.core.context import Context
+ fallback = _FakeFallback()
+ strat = _make_strategy(
+ llm_responses=["```cypher\nMATCH (n:Person) RETURN n.name\n```"],
+ graph_results=[_FakeResult(["name"], [["Alice"], ["Bob"]])],
+ fallback=fallback,
+ )
+ result = await strat._execute(
+ "What is the average year of founding?", Context(),
+ )
+ assert result.metadata["cypher_first_path"] == PATH_RAG_FALLBACK_NUMERIC_FAIL
+ assert fallback.calls # fallback was invoked
+
+ async def test_negation_existential_empty_returns_no(self):
+ # "Are there any orgs WITHOUT employees?" + cypher returns 0 rows
+ # → strategy emits an authoritative "No" and tags
+ # PATH_NEGATION_EMPTY_NO. No fallback delegation.
+ from graphrag_sdk.core.context import Context
+ cypher_code = (
+ "```cypher\n"
+ "MATCH (o:Organization) WHERE NOT EXISTS { "
+ "MATCH (o)<-[:RELATES]-(p:Person) } RETURN o.name\n"
+ "```"
+ )
+ fallback = _FakeFallback()
+ strat = _make_strategy(
+ llm_responses=[cypher_code],
+ graph_results=[
+ _FakeResult([], []), # hybrid batch query (no shape match)
+ _FakeResult(["name"], []), # the actual cypher
+ ],
+ fallback=fallback,
+ )
+ result = await strat._execute(
+ "Are there any organizations for which NO employees are listed?",
+ Context(),
+ )
+ assert result.metadata["cypher_first_path"] == PATH_NEGATION_EMPTY_NO
+ # Negation path must NOT delegate to the fallback.
+ assert fallback.calls == []
+
+ async def test_positive_existential_empty_falls_back_to_rag(self):
+ # "Is there any employee at Acme on observability?" + cypher returns
+ # 0 rows (typed label mismatch) → strategy falls through to RAG and
+ # tags PATH_RAG_FALLBACK_CYPHER_EMPTY.
+ from graphrag_sdk.core.context import Context
+ cypher_code = (
+ "```cypher\n"
+ "MATCH (p:Person)-[:RELATES]-(t:Technology) RETURN p.name\n"
+ "```"
+ )
+ fallback = _FakeFallback()
+ strat = _make_strategy(
+ llm_responses=[cypher_code],
+ graph_results=[_FakeResult(["name"], [])],
+ fallback=fallback,
+ )
+ result = await strat._execute(
+ "Is there any employee at Acme working on observability?",
+ Context(),
+ )
+ assert result.metadata["cypher_first_path"] == PATH_RAG_FALLBACK_CYPHER_EMPTY
+ assert len(fallback.calls) == 1
+
+ async def test_hybrid_warns_when_topology_assumption_violated(self, caplog):
+ # The batched (Org)<-[:RELATES]-(Person) query returns zero tuples
+ # → strategy emits a warning and falls through (returns None from
+ # the hybrid; the rest of _execute keeps going).
+ import logging
+
+ from graphrag_sdk.core.context import Context
+ cypher_code = (
+ "```cypher\nMATCH (p:Person) RETURN p.name AS name\n```"
+ )
+ strat = _make_strategy(
+ llm_responses=[cypher_code],
+ graph_results=[
+ _FakeResult([], []), # hybrid batch — zero topology tuples
+ _FakeResult(["name"], [["Alice"]]), # cypher_table cypher
+ ],
+ )
+ with caplog.at_level(logging.WARNING,
+ logger="graphrag_sdk.retrieval.strategies.cypher_first"):
+ await strat._execute(
+ "Which roles are held by employees at BOTH Acme and Globex?",
+ Context(),
+ )
+ # The topology-violation warning fired.
+ assert any("zero (Organization)<-[:RELATES]-(Person) tuples" in r.message
+ for r in caplog.records)
+
+
+# ── Pluggable phrase extractor (R8) ──────────────────────────────
+
+
+class TestPhraseExtractor:
+ """Domain-specific extractors can replace the default role/project
+ regexes without forking the strategy."""
+
+ def test_default_extractor_matches_module_regexes(self):
+ from graphrag_sdk.retrieval.strategies.cypher_first import (
+ DefaultPhraseExtractor,
+ )
+ ext = DefaultPhraseExtractor()
+ assert "senior engineer" in ext.extract(
+ "Anna is a senior engineer at Acme.", "role"
+ )
+ assert any("pipeline rewrite" in p for p in ext.extract(
+ "contributes to the pipeline rewrite", "project"
+ ))
+ # Unknown kinds return an empty set, not an exception.
+ assert ext.extract("anything", "city") == set()
+
+ async def test_strategy_uses_custom_extractor_in_hybrid_path(self):
+ """A custom extractor passed to the strategy is consulted by the
+ shared-property hybrid instead of the default regexes."""
+ from graphrag_sdk.core.context import Context
+ from graphrag_sdk.retrieval.strategies.cypher_first import PhraseExtractor
+
+ class _UpperCaseRoleExtractor(PhraseExtractor):
+ """Match exactly the literal strings 'ALPHA' and 'BETA' as roles
+ — clearly distinguishable from the default regex output."""
+ def extract(self, text, kind):
+ if kind == "role":
+ return {w for w in ("ALPHA", "BETA") if w in text}
+ return set()
+
+ # Both orgs have one person whose chunk text mentions the same
+ # custom token. With the default extractor, none of these phrases
+ # would match (no role suffix). With our custom one, ALPHA is
+ # common to both.
+ from unittest.mock import AsyncMock, MagicMock
+
+ graph = MagicMock()
+ batch_rows = [
+ ["Acme", "Alice", "no description", ["Alice does ALPHA at Acme."]],
+ ["Acme", "Anna", "no description", ["Anna does BETA at Acme."]],
+ ["Globex", "Bob", "no description", ["Bob does ALPHA at Globex."]],
+ ["Globex", "Bea", "no description", ["Bea does GAMMA at Globex."]],
+ ]
+ async def _query_raw(_cypher):
+ return SimpleNamespace(
+ header=[[1, "org"], [1, "person"], [1, "desc"], [1, "chunks"]],
+ result_set=batch_rows,
+ )
+ graph.query_raw = AsyncMock(side_effect=_query_raw)
+
+ from graphrag_sdk.retrieval.strategies.cypher_first import (
+ CypherFirstAggregationStrategy,
+ )
+ strat = CypherFirstAggregationStrategy(
+ graph_store=graph,
+ vector_store=MagicMock(),
+ embedder=MagicMock(),
+ llm=MagicMock(),
+ k_candidates=1,
+ rag_fallback=MagicMock(),
+ phrase_extractor=_UpperCaseRoleExtractor(),
+ )
+ result = await strat._execute(
+ "Which roles are held by employees at BOTH Acme and Globex?",
+ Context(),
+ )
+ # Hybrid fired (not a fallback) and computed the right intersection.
+ assert result.metadata["cypher_first_path"] == PATH_SHARED_PROPERTY_HYBRID
+ assert result.metadata["common"] == ["ALPHA"]
diff --git a/graphrag_sdk/tests/test_cypher_generation.py b/graphrag_sdk/tests/test_cypher_generation.py
index 5869415..ab7e0d6 100644
--- a/graphrag_sdk/tests/test_cypher_generation.py
+++ b/graphrag_sdk/tests/test_cypher_generation.py
@@ -7,6 +7,10 @@
extract_cypher,
validate_cypher,
_sanitize_cypher,
+ _is_pure_aggregation,
+ _should_widen_labels,
+ _split_top_level_commas,
+ _widen_typed_labels,
)
@@ -73,6 +77,27 @@ def test_rejects_call_procedures(self):
errors = validate_cypher("CALL db.labels() YIELD label RETURN label")
assert any("CALL" in e for e in errors)
+ def test_rejects_apoc_function_calls(self):
+ # Function-style apoc / gds / db calls slip past the bare \bCALL\b
+ # check; the dotted-namespace rule should reject them.
+ for snippet in [
+ "MATCH (n) RETURN apoc.text.regexGroups(n.description, '\\d+') AS m",
+ "MATCH (n) RETURN gds.shortest.path(n) AS p",
+ "MATCH (n) WHERE db.idx.fulltext.queryNodes('x', 'y') RETURN n",
+ ]:
+ errors = validate_cypher(snippet)
+ assert any("Unsupported function namespace" in e for e in errors), \
+ f"should reject: {snippet}"
+
+ def test_accepts_bare_builtin_functions(self):
+ # Built-ins are bare (count, toInteger, substring) — no namespace.
+ for snippet in [
+ "MATCH (n:Person) RETURN count(n) AS c",
+ "MATCH (n:Person) RETURN toInteger(n.name) AS i",
+ "MATCH (n:Person) RETURN substring(n.description, 0, 4) AS s",
+ ]:
+ assert validate_cypher(snippet) == [], f"should accept: {snippet}"
+
def test_rejects_load_csv(self):
errors = validate_cypher("LOAD CSV FROM 'file:///data.csv' AS row RETURN row")
assert any("LOAD CSV" in e for e in errors)
@@ -108,14 +133,34 @@ def test_with_allowed(self):
class TestSanitizeCypher:
def test_adds_limit_when_missing(self):
- result = _sanitize_cypher("MATCH (n) RETURN n")
+ # Non-aggregation query without LIMIT — should get the new default cap.
+ result = _sanitize_cypher("MATCH (n:Person) RETURN n.name")
assert "LIMIT" in result
+ assert "100" in result # _DEFAULT_ROW_LIMIT
def test_keeps_existing_limit(self):
cypher = "MATCH (n) RETURN n LIMIT 10"
result = _sanitize_cypher(cypher)
assert result.count("LIMIT") == 1
+ def test_does_not_inject_limit_on_pure_aggregation(self):
+ # count / sum / avg without group-by always returns one row — adding
+ # LIMIT is misleading.
+ for cypher in [
+ "MATCH (n:Person) RETURN count(n)",
+ "MATCH (n:Person) RETURN count(DISTINCT n) AS c",
+ "MATCH (n:Person) RETURN avg(toInteger(n.name)) AS a",
+ ]:
+ result = _sanitize_cypher(cypher)
+ assert "LIMIT" not in result, f"should skip LIMIT for: {cypher}"
+
+ def test_injects_limit_on_group_by(self):
+ # Group-by has a non-aggregate dimension in RETURN — LIMIT applies to
+ # the row count of the breakdown.
+ cypher = "MATCH (o:Organization)-[:RELATES]-(p:Person) RETURN o.name, count(p) AS n"
+ result = _sanitize_cypher(cypher)
+ assert "LIMIT" in result
+
def test_removes_shortest_path(self):
cypher = "MATCH path = shortestPath((a)-[*]-(b)) RETURN path"
result = _sanitize_cypher(cypher)
@@ -132,6 +177,106 @@ def test_removes_path_assignment(self):
assert "path =" not in result and "path=" not in result
+# ── _is_pure_aggregation ──────────────────────────────────────────
+
+
+class TestIsPureAggregation:
+ def test_count_only(self):
+ assert _is_pure_aggregation("MATCH (n) RETURN count(n)")
+
+ def test_count_distinct_with_alias(self):
+ assert _is_pure_aggregation(
+ "MATCH (n:Person) RETURN count(DISTINCT n) AS person_count"
+ )
+
+ def test_multi_aggregate(self):
+ assert _is_pure_aggregation(
+ "MATCH (n) RETURN count(n) AS c, avg(n.score) AS a"
+ )
+
+ def test_group_by_is_not_pure(self):
+ # A non-aggregate dimension in RETURN means LIMIT actually matters.
+ assert not _is_pure_aggregation(
+ "MATCH (o)-[:RELATES]-(p) RETURN o.name AS org, count(p) AS n"
+ )
+
+ def test_no_aggregate_is_not_pure(self):
+ assert not _is_pure_aggregation("MATCH (n:Person) RETURN n.name")
+
+ def test_aggregate_with_order_by_limit_clause(self):
+ # ORDER BY / LIMIT after RETURN must not confuse the regex.
+ assert _is_pure_aggregation(
+ "MATCH (n) RETURN count(n) ORDER BY count(n)"
+ )
+
+
+# ── _split_top_level_commas ───────────────────────────────────────
+
+
+class TestSplitTopLevelCommas:
+ def test_simple(self):
+ assert _split_top_level_commas("a, b, c") == ["a", "b", "c"]
+
+ def test_nested_parens_preserved(self):
+ # commas inside count(DISTINCT a, b) must NOT split projections
+ assert _split_top_level_commas("count(DISTINCT a, b) AS c, d") == [
+ "count(DISTINCT a, b) AS c",
+ "d",
+ ]
+
+
+# ── _widen_typed_labels / _should_widen_labels ────────────────────
+
+
+class TestWidenTypedLabels:
+ def test_widens_single_label(self):
+ cypher = "MATCH (p:Person) RETURN p.name"
+ assert _widen_typed_labels(cypher) == "MATCH (p:__Entity__) RETURN p.name"
+
+ def test_widens_multiple_labels(self):
+ cypher = (
+ "MATCH (p:Person)-[:RELATES]-(o:Organization) "
+ "WHERE o.name CONTAINS 'Acme' RETURN p.name"
+ )
+ widened = _widen_typed_labels(cypher)
+ assert ":Person" not in widened
+ assert ":Organization" not in widened
+ assert widened.count(":__Entity__") == 2
+
+ def test_does_not_widen_structural_labels(self):
+ # Chunk / Document / __Entity__ already are structural — leave alone.
+ cypher = "MATCH (c:Chunk)-[:PART_OF]->(d:Document) RETURN c.text"
+ assert _widen_typed_labels(cypher) == cypher
+
+ def test_idempotent(self):
+ cypher = "MATCH (p:Person) RETURN p.name"
+ once = _widen_typed_labels(cypher)
+ twice = _widen_typed_labels(once)
+ assert once == twice
+
+
+class TestShouldWidenLabels:
+ def test_widens_when_name_filter_present(self):
+ # A typed label + name filter = LLM is using label as routing hint, not
+ # as the filter itself — safe to widen.
+ assert _should_widen_labels(
+ "MATCH (p:Person)-[:RELATES]-(o:Organization) "
+ "WHERE o.name CONTAINS 'Acme' RETURN p.name"
+ )
+
+ def test_skips_when_label_is_the_filter(self):
+ # Pure aggregation over a labeled variable with NO name predicate —
+ # widening would change semantics ("count Persons" → "count Entities").
+ assert not _should_widen_labels(
+ "MATCH (p:Person) RETURN count(DISTINCT p) AS c"
+ )
+
+ def test_skips_when_no_typed_label(self):
+ assert not _should_widen_labels(
+ "MATCH (e:__Entity__) RETURN count(e)"
+ )
+
+
# ── execute_cypher_retrieval ──────────────────────────────────────
@@ -148,9 +293,13 @@ async def test_returns_empty_on_generation_failure(self):
mock_llm.ainvoke = AsyncMock(return_value=LLMResponse(content="I don't know"))
mock_graph = MagicMock()
- facts, entities = await execute_cypher_retrieval(mock_graph, mock_llm, "test?")
+ facts, entities, metadata = await execute_cypher_retrieval(
+ mock_graph, mock_llm, "test?"
+ )
assert facts == []
assert entities == {}
+ assert metadata["cypher_fallback"] is None
+ assert metadata["cypher_rows"] == 0
async def test_returns_empty_on_execution_error(self):
"""When Cypher execution fails, should return empty results."""
@@ -169,9 +318,12 @@ async def test_returns_empty_on_execution_error(self):
mock_graph = MagicMock()
mock_graph.query_raw = AsyncMock(side_effect=Exception("connection error"))
- facts, entities = await execute_cypher_retrieval(mock_graph, mock_llm, "test?")
+ facts, entities, metadata = await execute_cypher_retrieval(
+ mock_graph, mock_llm, "test?"
+ )
assert facts == []
assert entities == {}
+ assert metadata["cypher_fallback"] is None
async def test_parses_result_rows(self):
"""Successful execution should parse rows into facts and entities."""
@@ -192,8 +344,76 @@ async def test_parses_result_rows(self):
mock_graph = MagicMock()
mock_graph.query_raw = AsyncMock(return_value=result_mock)
- facts, entities = await execute_cypher_retrieval(mock_graph, mock_llm, "test?")
+ facts, entities, metadata = await execute_cypher_retrieval(
+ mock_graph, mock_llm, "test?"
+ )
assert len(facts) == 2
assert "Alice" in facts[0]
assert "alice" in entities
assert "bob" in entities
+ assert metadata["cypher_rows"] == 2
+ assert metadata["cypher_fallback"] is None
+
+ async def test_label_widen_fires_on_zero_rows(self):
+ """When a typed-label query returns 0 rows AND a name filter is present,
+ the fallback should rewrite typed labels to __Entity__ and re-execute."""
+ from unittest.mock import AsyncMock, MagicMock
+ from graphrag_sdk.core.models import LLMResponse
+ from graphrag_sdk.retrieval.strategies.cypher_generation import (
+ execute_cypher_retrieval,
+ )
+
+ # Original cypher: typed label + name filter — gating allows widening.
+ original = (
+ "MATCH (p:Person)-[:RELATES]-(t:Technology) "
+ "WHERE t.name CONTAINS 'observability' RETURN p.name LIMIT 25"
+ )
+ mock_llm = MagicMock()
+ mock_llm.ainvoke = AsyncMock(
+ return_value=LLMResponse(content=f"```cypher\n{original}\n```")
+ )
+
+ empty_result = MagicMock()
+ empty_result.result_set = []
+ widened_result = MagicMock()
+ widened_result.result_set = [["Carla Okafor"]]
+
+ mock_graph = MagicMock()
+ # First call: original (typed) → empty. Second call: widened → hit.
+ mock_graph.query_raw = AsyncMock(side_effect=[empty_result, widened_result])
+
+ facts, entities, metadata = await execute_cypher_retrieval(
+ mock_graph, mock_llm, "is there an Acme employee on observability?"
+ )
+ assert facts == ["Carla Okafor"]
+ assert metadata["cypher_fallback"] == "label_widened"
+ assert ":__Entity__" in metadata["cypher"]
+ assert mock_graph.query_raw.await_count == 2
+
+ async def test_label_widen_skipped_for_pure_aggregation(self):
+ """For 'count Persons' (label IS the filter), widening would change
+ semantics — fallback must skip and return empty."""
+ from unittest.mock import AsyncMock, MagicMock
+ from graphrag_sdk.core.models import LLMResponse
+ from graphrag_sdk.retrieval.strategies.cypher_generation import (
+ execute_cypher_retrieval,
+ )
+
+ cypher = "MATCH (p:Person) RETURN count(DISTINCT p) AS c"
+ mock_llm = MagicMock()
+ mock_llm.ainvoke = AsyncMock(
+ return_value=LLMResponse(content=f"```cypher\n{cypher}\n```")
+ )
+
+ empty_result = MagicMock()
+ empty_result.result_set = []
+ mock_graph = MagicMock()
+ mock_graph.query_raw = AsyncMock(return_value=empty_result)
+
+ facts, entities, metadata = await execute_cypher_retrieval(
+ mock_graph, mock_llm, "how many people?"
+ )
+ assert facts == []
+ assert metadata["cypher_fallback"] is None
+ # Only one execution — fallback was correctly skipped.
+ assert mock_graph.query_raw.await_count == 1
diff --git a/graphrag_sdk/tests/test_facade.py b/graphrag_sdk/tests/test_facade.py
index 2e4a22a..e252437 100644
--- a/graphrag_sdk/tests/test_facade.py
+++ b/graphrag_sdk/tests/test_facade.py
@@ -1104,3 +1104,66 @@ async def test_ingest_input_validation_runs_before_config_probe(
# Probe and DB query must not have fired.
g._graph_store.query_raw.assert_not_called()
g.embedder.aembed_query.assert_not_called()
+
+
+# ── Cypher authority rule injection (R1) ──────────────────────────
+
+
+class TestCypherAuthorityRuleInjection:
+ """The cypher-authority rule is appended to the system prompt only when
+ the retriever produced an authoritative-results section. Callers who
+ don't use cypher retrieval keep the unchanged base prompt."""
+
+ def test_helper_returns_false_for_empty_items(self):
+ from graphrag_sdk.api.main import _has_authoritative_cypher_results
+ assert not _has_authoritative_cypher_results(RetrieverResult(items=[]))
+
+ def test_helper_returns_false_for_non_cypher_items(self):
+ from graphrag_sdk.api.main import _has_authoritative_cypher_results
+ items = [
+ RetrieverResultItem(
+ content="## Source Document Passages\n- foo",
+ metadata={"section": "passages"},
+ ),
+ RetrieverResultItem(
+ content="## Key Entities\n- Acme",
+ metadata={"section": "entities"},
+ ),
+ ]
+ assert not _has_authoritative_cypher_results(RetrieverResult(items=items))
+
+ def test_helper_returns_true_on_section_metadata(self):
+ from graphrag_sdk.api.main import _has_authoritative_cypher_results
+ item = RetrieverResultItem(
+ content="",
+ metadata={"section": "cypher_results"},
+ )
+ assert _has_authoritative_cypher_results(RetrieverResult(items=[item]))
+
+ def test_helper_returns_true_on_heading_marker_in_content(self):
+ # Defensive path — third-party strategies that don't tag metadata
+ # but use the canonical heading still trigger the rule.
+ from graphrag_sdk.api.main import _has_authoritative_cypher_results
+ item = RetrieverResultItem(
+ content=(
+ "## Authoritative Graph Query Results "
+ "(deterministic; trust over passages on counts and aggregates)\n"
+ "- Boston | 10"
+ ),
+ metadata={},
+ )
+ assert _has_authoritative_cypher_results(RetrieverResult(items=[item]))
+
+ def test_base_system_prompt_does_not_contain_rule_8(self):
+ # The rule has been moved out of the base prompts. Callers who
+ # don't surface cypher results keep the original 7-rule prompt.
+ from graphrag_sdk.api.main import (
+ _CYPHER_AUTH_RULE,
+ _RAG_SYSTEM_PROMPT,
+ _RAG_SYSTEM_PROMPT_DELIMITED,
+ )
+ for base in (_RAG_SYSTEM_PROMPT, _RAG_SYSTEM_PROMPT_DELIMITED):
+ assert "Authoritative Graph Query Results" not in base
+ assert "\n8." not in base
+ # The addendum, on its own, mentions the heading.
+ assert "Authoritative Graph Query Results" in _CYPHER_AUTH_RULE