diff --git a/automatic-evaluation-pipeline/eval_frames.py b/automatic-evaluation-pipeline/eval_frames.py new file mode 100644 index 00000000..5156f11f --- /dev/null +++ b/automatic-evaluation-pipeline/eval_frames.py @@ -0,0 +1,832 @@ +""" +Evaluate OpenRAG on the FRAMES benchmark (google/frames-benchmark). + +FRAMES contains 824 multi-hop questions requiring 2-15 Wikipedia articles each. +Run `setup_frames.py` first to download articles and index them into OpenRAG. + +Modes: + - Default: queries OpenRAG (retriever + LLM generation) + - --no-rag: bypasses retrieval, sends prompt directly to the LLM without chunks + - --oracle: bypasses retrieval, provides gold Wikipedia articles to the LLM + - --gold-workspaces: creates one workspace per question, attaches only gold files + +Usage: + cd automatic-evaluation-pipeline + python eval_frames.py [--partition FRAMES] [--output results.json] [--limit N] + python eval_frames.py --no-rag + python eval_frames.py --oracle + python eval_frames.py --gold-workspaces [--partition FRAMES] + +Environment variables (from .env): + APP_URL, APP_PORT, AUTH_TOKEN, MODEL, BASE_URL, API_KEY + MODEL_JUDGE, BASE_URL_JUDGE, API_KEY_JUDGE +""" + +import argparse +import ast +import asyncio +import json +import os +import re +import string +import warnings +from collections import defaultdict +from pathlib import Path +from urllib.parse import unquote, urlparse + +import httpx +from dotenv import load_dotenv +from langchain_openai import ChatOpenAI +from loguru import logger +from pydantic import BaseModel, Field +from tqdm.asyncio import tqdm + +load_dotenv() + +warnings.filterwarnings("ignore", message=r".*Pydantic serializer warnings.*") + +# ─── Env / config ──────────────────────────────────────────────────────────── + +APP_URL = os.environ.get("APP_URL", "localhost") +APP_PORT = os.environ.get("APP_PORT", "8080") +AUTH_TOKEN = os.environ.get("AUTH_TOKEN", "sk-1234") +MODEL = os.environ.get("MODEL") +BASE_URL = os.environ.get("BASE_URL") 
+API_KEY = os.environ.get("API_KEY") +JUDGE_MODEL = os.environ.get("MODEL_JUDGE", MODEL) +JUDGE_BASE_URL = os.environ.get("BASE_URL_JUDGE", BASE_URL) +JUDGE_API_KEY = os.environ.get("API_KEY_JUDGE", API_KEY) + +OPENRAG_BASE_URL = f"http://{APP_URL}:{APP_PORT}" + +MAX_RETRIES = 3 +_RETRY_BACKOFF = [2, 5, 10] + + +async def _http_with_retry( + client: httpx.AsyncClient, method: str, url: str, *, max_retries: int = MAX_RETRIES, **kwargs +) -> httpx.Response: + """HTTP request with retry on 5xx errors and network errors.""" + for attempt in range(max_retries + 1): + try: + if method == "GET": + resp = await client.get(url, **kwargs) + else: + resp = await client.post(url, **kwargs) + if resp.status_code >= 500 and attempt < max_retries: + wait = _RETRY_BACKOFF[min(attempt, len(_RETRY_BACKOFF) - 1)] + logger.debug(f"HTTP {resp.status_code} for {url}, retry {attempt+1}/{max_retries} in {wait}s") + await asyncio.sleep(wait) + continue + resp.raise_for_status() + return resp + except httpx.HTTPStatusError: + raise + except Exception: + if attempt == max_retries: + raise + wait = _RETRY_BACKOFF[min(attempt, len(_RETRY_BACKOFF) - 1)] + logger.debug(f"Network error for {url}, retry {attempt+1}/{max_retries} in {wait}s") + await asyncio.sleep(wait) + + +async def _check_openrag_health(client: httpx.AsyncClient) -> bool: + """Check OpenRAG is reachable. 
Returns True if up, False otherwise.""" + try: + resp = await client.get(f"{OPENRAG_BASE_URL}/health_check") + resp.raise_for_status() + print("OpenRAG API is up.") + return True + except Exception as e: + print(f"ERROR: Cannot reach OpenRAG at {OPENRAG_BASE_URL}: {e}") + return False + + +# ─── Dataset caching ──────────────────────────────────────────────────────── + +DATASET_CACHE = Path(__file__).parent / "frames_dataset.json" +DEFAULT_UNANSWERABLE_EXACT_MATCH_ANSWERS = [ + "unanswerable", + "cannot answer", + "i cannot answer", + "the answer is not in the provided context", + "the answer is not in the context", + "not enough information in the provided context", +] + + +def _default_dataset_cache(dataset_path: str | None) -> Path: + if dataset_path: + path = Path(dataset_path) + return path if path.is_absolute() else Path(__file__).parent / dataset_path + return DATASET_CACHE + + +def load_dataset_cached(limit: int | None = None, dataset_path: str | None = None) -> list[dict]: + """Load dataset from local cache path, or download FRAMES and cache it there.""" + cache_path = _default_dataset_cache(dataset_path) + if cache_path.exists(): + logger.info(f"Loading dataset from cache ({cache_path.name})...") + with open(cache_path, encoding="utf-8") as f: + dataset = json.load(f) + else: + logger.info("Downloading FRAMES benchmark from HuggingFace (first time only)...") + from datasets import load_dataset + hf_dataset = load_dataset("google/frames-benchmark", split="test") + dataset = [dict(row) for row in hf_dataset] + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w", encoding="utf-8") as f: + json.dump(dataset, f, ensure_ascii=False, indent=2) + logger.info(f"Cached {len(dataset)} questions to {cache_path.name}") + + logger.info(f"Loaded {len(dataset)} questions.") + if limit: + dataset = dataset[:limit] + logger.info(f"Limited to {len(dataset)} questions.") + return dataset + + +# ─── Wikipedia helpers 
─────────────────────────────────────────────────────── + + +def extract_title_from_url(url: str) -> str | None: + """Extract the Wikipedia article title from a URL.""" + parsed = urlparse(url) + if "wikipedia.org" not in parsed.netloc: + return None + match = re.match(r"/wiki/(.+)", parsed.path) + if match: + return unquote(match.group(1)).replace("_", " ") + return None + + +def parse_wiki_links(row) -> list[str]: + """Extract Wikipedia URLs from a FRAMES dataset row.""" + if row.get("wiki_links") is None: + return [] + + def _split_embedded_wiki_urls(value: str) -> list[str]: + pattern = r"https?://(?:[\w.-]+\.)?wikipedia\.org/wiki/" + starts = [m.start() for m in re.finditer(pattern, value)] + if not starts: + return [] + if len(starts) == 1: + return [value.strip().strip(",")] + + urls = [] + for idx, start in enumerate(starts): + end = starts[idx + 1] if idx + 1 < len(starts) else len(value) + candidate = value[start:end].strip().strip(",").strip("'").strip('"') + if candidate: + urls.append(candidate) + return urls + + wiki_links = row.get("wiki_links") + if not wiki_links: + return [] + if isinstance(wiki_links, str): + try: + links = ast.literal_eval(wiki_links) + except (ValueError, SyntaxError): + try: + links = json.loads(wiki_links) + except json.JSONDecodeError: + links = [l.strip() for l in wiki_links.split(",") if l.strip()] + else: + links = wiki_links + + normalized_links = [] + for link in links: + if not isinstance(link, str) or "wikipedia.org" not in link: + continue + normalized_links.extend(_split_embedded_wiki_urls(link)) + + seen = set() + deduped = [] + for link in normalized_links: + if link not in seen: + seen.add(link) + deduped.append(link) + return deduped + + +def _safe_filename(title: str) -> str: + """Convert a Wikipedia title to a safe filename.""" + safe_name = re.sub(r'[^\w\s\-()]', '', title).strip() + safe_name = re.sub(r'\s+', '_', safe_name) + if not safe_name: + safe_name = f"article_{hash(title) % 10**8}" + return 
safe_name + + +def _strip_html_for_oracle(html: str) -> str: + """Lightweight HTML cleanup for oracle context.""" + text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE) + text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE) + text = re.sub(r'<!--.*?-->', '', text, flags=re.DOTALL) + text = re.sub(r'<(\w+)\s[^>]*?(/?)>', r'<\1\2>', text) + text = re.sub(r'<(\w+)>\s*</\1>', '', text) + text = re.sub(r'\n{3,}', '\n\n', text) + return text.strip() + + +def load_articles_from_disk(rows: list[dict], docs_dir: Path) -> dict[str, str]: + """Load Wikipedia articles from local .md or .html files for oracle mode.""" + all_titles: set[str] = set() + for row in rows: + for url in parse_wiki_links(row): + title = extract_title_from_url(url) + if title: + all_titles.add(title) + + # .md takes priority when both exist; .html is stripped of markup before use. + md_files = {p.stem: p for p in docs_dir.glob("*.md")} + html_files = {p.stem: p for p in docs_dir.glob("*.html")} + + articles = {} + missing = [] + for title in all_titles: + safe_name = _safe_filename(title) + if safe_name in md_files: + articles[title] = md_files[safe_name].read_text(encoding="utf-8") + elif safe_name in html_files: + articles[title] = _strip_html_for_oracle(html_files[safe_name].read_text(encoding="utf-8")) + else: + missing.append(title) + + print(f"Loaded {len(articles)} articles from disk ({len(missing)} missing).") + if missing: + logger.info(f"Missing articles: {missing[:20]}{'...'
if len(missing) > 20 else ''}") + return articles + + +# ─── OpenRAG API calls ──────────────────────────────────────────────────────── + + +async def _fetch_chunk_content( + client: httpx.AsyncClient, + chunk_url: str, + headers: dict, +) -> str: + """Fetch the text content of a single chunk via its extract URL.""" + try: + resp = await _http_with_retry(client, "GET", chunk_url, headers=headers) + return resp.json().get("page_content", "") + except Exception as e: + logger.debug(f"Failed to fetch chunk {chunk_url}: {e}") + return "" + + +async def query_openrag_answer( + question: str, + partition: str, + semaphore: asyncio.Semaphore, + workspace: str | None = None, +) -> tuple[str | None, str]: + """Call OpenRAG chat completions for a single question.""" + headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} + + async with semaphore: + async with httpx.AsyncClient(timeout=300) as client: + payload = { + "model": f"openrag-{partition}", + "messages": [{"role": "user", "content": question}], + "temperature": 0.1, + "max_tokens": 512, + "stream": False, + } + if workspace: + payload["metadata"] = {"workspace": workspace} + + try: + chat_resp = await _http_with_retry( + client, "POST", f"{OPENRAG_BASE_URL}/v1/chat/completions", + json=payload, headers=headers, + ) + body = chat_resp.json() + generated_answer = body["choices"][0]["message"]["content"] + except Exception as e: + logger.debug(f"Chat completions error for question {question!r}: {e}") + return None, "" + + sources_content = "" + try: + extra = body.get("extra", "") + if isinstance(extra, str) and extra: + extra = json.loads(extra) + sources = extra.get("sources", []) if isinstance(extra, dict) else [] + chunk_urls = [s["chunk_url"] for s in sources if "chunk_url" in s] + + if chunk_urls: + chunk_tasks = [_fetch_chunk_content(client, url, headers) for url in chunk_urls] + chunks = await asyncio.gather(*chunk_tasks) + sources_content = "\n\n".join(c for c in chunks if c) + except Exception as e: + 
logger.debug(f"Failed to extract sources for question {question!r}: {e}") + + return generated_answer, sources_content + + +def _build_gold_name(prefix: str, row_index: int) -> str: + """Build a deterministic name for a per-question gold workspace.""" + return f"{prefix}-q{row_index:04d}" + + +def _get_gold_file_ids(row: dict) -> list[str]: + """Return the normalized gold file ids referenced by one FRAMES row.""" + file_ids = [] + for url in parse_wiki_links(row): + title = extract_title_from_url(url) + if title: + file_ids.append(_safe_filename(title)) + return file_ids + + +async def ensure_workspace( + client: httpx.AsyncClient, + partition: str, + workspace_id: str, +) -> bool: + """Create a workspace if needed.""" + headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} + payload = {"workspace_id": workspace_id, "display_name": workspace_id} + resp = await client.post(f"{OPENRAG_BASE_URL}/partition/{partition}/workspaces", json=payload, headers=headers) + if resp.status_code == 201: + return True + if resp.status_code == 409: + return False + resp.raise_for_status() + return False + + +async def add_files_to_workspace( + client: httpx.AsyncClient, + partition: str, + workspace_id: str, + file_ids: list[str], +) -> None: + """Attach files to a workspace.""" + headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} + resp = await client.post( + f"{OPENRAG_BASE_URL}/partition/{partition}/workspaces/{workspace_id}/files", + json={"file_ids": file_ids}, + headers=headers, + ) + resp.raise_for_status() + + +async def prepare_gold_workspace_for_question( + client: httpx.AsyncClient, + row_index: int, + row: dict, + source_partition: str, + workspace_prefix: str, +) -> tuple[str, list[str]]: + """Create one question workspace and attach its gold files.""" + workspace_id = _build_gold_name(workspace_prefix, row_index) + file_ids = _get_gold_file_ids(row) + await ensure_workspace(client, source_partition, workspace_id) + if file_ids: + await add_files_to_workspace(client, 
source_partition, workspace_id, file_ids) + return workspace_id, file_ids + + +async def get_available_openrag_partitions(client: httpx.AsyncClient) -> list[str]: + """List partitions exposed as OpenRAG models.""" + headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} + resp = await _http_with_retry(client, "GET", f"{OPENRAG_BASE_URL}/v1/models", headers=headers) + models = resp.json().get("data", []) + return [ + m["id"].removeprefix("openrag-") + for m in models + if m.get("id", "").startswith("openrag-") and m["id"] != "openrag-all" + ] + + +async def get_available_workspaces(client: httpx.AsyncClient, partition: str) -> list[str]: + """List workspaces accessible in a partition.""" + headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} + resp = await _http_with_retry( + client, "GET", f"{OPENRAG_BASE_URL}/partition/{partition}/workspaces", headers=headers + ) + return [ws["workspace_id"] for ws in resp.json().get("workspaces", []) if ws.get("workspace_id")] + + +# ─── Oracle / No-RAG modes ────────────────────────────────────────────────── + +_oracle_system_prompt = """\ +You are a helpful assistant. Answer the user's question based on the provided Wikipedia articles. \ +The articles are provided as cleaned HTML — pay close attention to tables and structured data. \ +Use the information from the articles to reason step by step and provide an accurate, concise answer. 
+ +If the answer cannot be determined from the provided context, reply with exactly: +unanswerable +""" + + +def build_oracle_context( + row, + articles: dict[str, str], + max_chars: int = 0, +) -> str: + """Build the oracle context string from gold Wikipedia articles.""" + parts = [] + for url in parse_wiki_links(row): + title = extract_title_from_url(url) + if title and title in articles: + parts.append((title, articles[title])) + + if not parts: + return "" + + if max_chars > 0: + overhead = sum(len(f"=== {t} ===\n\n\n") for t, _ in parts) + budget = max(max_chars - overhead, len(parts) * 200) + per_article = budget // len(parts) + return "\n\n".join( + f"=== {title} ===\n{content[:per_article]}" for title, content in parts + ) + + return "\n\n".join(f"=== {title} ===\n{content}" for title, content in parts) + + +async def query_oracle_answer( + question: str, + oracle_context: str, + semaphore: asyncio.Semaphore, + llm: ChatOpenAI, +) -> tuple[str | None, str]: + """Oracle mode: send question + gold Wikipedia articles directly to the LLM.""" + user_message = ( + f"Here are the relevant Wikipedia articles:\n\n{oracle_context}\n\n" + f"Question: {question}\n\nAnswer the question based on the articles above. Be concise." 
+ ) + async with semaphore: + try: + response = await llm.ainvoke([ + {"role": "system", "content": _oracle_system_prompt}, + {"role": "user", "content": user_message}, + ]) + return response.content, oracle_context + except Exception as e: + logger.debug(f"Oracle LLM error for question {question!r}: {e}") + return None, oracle_context + + +async def query_llm_answer( + question: str, + semaphore: asyncio.Semaphore, + llm: ChatOpenAI, +) -> tuple[str | None, str]: + """No-RAG mode: send the raw question directly to the LLM.""" + async with semaphore: + try: + response = await llm.ainvoke([{"role": "user", "content": question}]) + return response.content, "" + except Exception as e: + logger.debug(f"No-RAG LLM error for question {question!r}: {e}") + return None, "" + + +# ─── LLM accuracy judge ───────────────────────────────────────────────────── + +_accuracy_judge_system_prompt = """\ +You are an impartial factual evaluator. You will be given a question, the \ +expected correct answer (gold), whether the gold answer is unanswerable, and a generated answer. + +Your task is to determine if the generated answer is **factually correct** \ +with respect to the gold answer. The generated answer does NOT need to be \ +word-for-word identical — it just needs to convey the same factual information. + +Special rule for unanswerable questions: +- If the gold answer is unanswerable, then the generated answer is correct only if it clearly abstains \ + or states that the answer cannot be determined from the provided context. +- If the gold answer is unanswerable and the generated answer provides a specific factual answer, it is incorrect. 
+ +You MUST reply with ONLY a JSON object in this exact format (no other text): +{"correct": true/false, "justification": ""} +""" + + +class _AccuracyJudgeResponse(BaseModel): + correct: bool = Field(..., description="Whether the generated answer is factually correct") + justification: str = Field(..., description="One-sentence justification") + + +def _make_accuracy_judge() -> ChatOpenAI: + return ChatOpenAI( + model=JUDGE_MODEL, + base_url=JUDGE_BASE_URL, + api_key=JUDGE_API_KEY, + temperature=0.0, + max_tokens=256, + ).with_structured_output(_AccuracyJudgeResponse) + + +async def accuracy_judge( + question: str, + gold_answer: str, + generated_answer: str, + is_unanswerable: bool, + semaphore: asyncio.Semaphore, + judge: ChatOpenAI, +) -> tuple[bool, str]: + """Run the accuracy judge, returns (correct, justification).""" + user_message = ( + f"Question:\n{question}\n\n" + f"Gold answer:\n{gold_answer}\n\n" + f"Gold is unanswerable:\n{is_unanswerable}\n\n" + f"Generated answer:\n{generated_answer}" + ) + async with semaphore: + try: + response = await judge.ainvoke([ + {"role": "system", "content": _accuracy_judge_system_prompt}, + {"role": "user", "content": user_message}, + ]) + return response.correct, response.justification + except Exception as e: + logger.debug(f"Accuracy judge error: {e}") + return False, "" + + +async def run_accuracy_judging(results: list[dict], concurrency: int = 10) -> None: + """Run the accuracy judge on all results with a generated answer.""" + judge_semaphore = asyncio.Semaphore(concurrency) + acc_judge = _make_accuracy_judge() + + to_judge = [(i, r) for i, r in enumerate(results) if r["generated_answer"]] + acc_tasks = [ + accuracy_judge( + r["question"], + r["gold_answer"], + r["generated_answer"], + r.get("is_unanswerable", False), + judge_semaphore, + acc_judge, + ) + for _, r in to_judge + ] + acc_scores = await tqdm.gather(*acc_tasks, desc="Accuracy judge") + + for r in results: + r.setdefault("accuracy", False) + 
r.setdefault("accuracy_justification", "") + for (i, _), (correct, justification) in zip(to_judge, acc_scores): + results[i]["accuracy"] = correct + results[i]["accuracy_justification"] = justification + + +# ─── Result helpers ────────────────────────────────────────────────────────── + + +def _build_result(row, generated_answer, sources_content, mode, **extra) -> dict: + """Build a single result entry.""" + expected_exact_match_answers = row.get("expected_exact_match_answers") + if not expected_exact_match_answers: + expected_exact_match_answers = [row["Answer"]] + entry = { + "question": row["Prompt"], + "gold_answer": row["Answer"], + "expected_exact_match_answers": expected_exact_match_answers, + "generated_answer": generated_answer or "", + "sources_content": sources_content, + "reasoning_types": row.get("reasoning_types", ""), + "dataset_name": row.get("dataset_name", "FRAMES"), + "is_unanswerable": row.get("is_unanswerable", False), + "mode": mode, + } + entry.update(extra) + return entry + + +def _normalize_exact_match_text(text: str) -> str: + text = text.casefold().strip() + text = text.translate(str.maketrans("", "", string.punctuation)) + text = " ".join(text.split()) + return text + + +def _compute_exact_match(generated_answer: str, expected_answers: list[str]) -> bool: + generated_norm = _normalize_exact_match_text(generated_answer) + if not generated_norm: + return False + return any(generated_norm == _normalize_exact_match_text(answer) for answer in expected_answers if answer) + + +def annotate_exact_match(results: list[dict]) -> None: + for result in results: + expected = result.get("expected_exact_match_answers", [result.get("gold_answer", "")]) + result["exact_match"] = _compute_exact_match(result.get("generated_answer", ""), expected) + + +def _print_summary(results: list[dict], total_questions: int) -> None: + """Print accuracy summary and breakdown by reasoning type.""" + valid = [r for r in results if r["generated_answer"]] + mode = 
results[0].get("mode", "rag") if results else "rag" + mode_display = { + "oracle": "ORACLE", "no_rag": "NO_RAG", + "gold_workspaces": "GOLD_WORKSPACES", "rag": "RAG", + }.get(mode, mode.upper()) + + print(f"\n{'='*60}") + print(f" Mode: {mode_display}") + print(f"{'='*60}") + + acc_results = [r for r in valid if "accuracy" in r] + if acc_results: + correct = sum(1 for r in acc_results if r["accuracy"]) + print(f" Accuracy: {correct / len(acc_results):.1%} ({correct}/{len(acc_results)})") + else: + print(" Accuracy: N/A") + + em_results = [r for r in valid if "exact_match" in r] + if em_results: + exact = sum(1 for r in em_results if r["exact_match"]) + print(f" Exact match: {exact / len(em_results):.1%} ({exact}/{len(em_results)})") + ua = [r for r in em_results if r.get("is_unanswerable")] + if ua: + ua_exact = sum(1 for r in ua if r["exact_match"]) + print(f" EM unanswerable: {ua_exact / len(ua):.1%} ({ua_exact}/{len(ua)})") + else: + print(" Exact match: N/A") + + print(f" Total questions: {total_questions} | Valid answers: {len(valid)}") + print(f"{'='*60}") + + type_groups: dict[str, list[dict]] = defaultdict(list) + for r in valid: + rtype = r.get("reasoning_types", "unknown") or "unknown" + type_groups[rtype].append(r) + + if len(type_groups) > 1 or (len(type_groups) == 1 and "unknown" not in type_groups): + print(f"\n{'TYPE':<35} {'COUNT':>6} {'ACCURACY':>10}") + print(f"{'-'*55}") + for rtype in sorted(type_groups.keys()): + group = type_groups[rtype] + acc_group = [r for r in group if "accuracy" in r] + acc = sum(1 for r in acc_group if r["accuracy"]) / len(acc_group) if acc_group else 0 + print(f"{rtype:<35} {len(group):>6} {acc:>9.1%}") + print(f"{'='*55}") + + +# ─── Main ───────────────────────────────────────────────────────────────────── + + +async def main() -> None: + parser = argparse.ArgumentParser(description="Evaluate OpenRAG on the FRAMES benchmark.") + parser.add_argument("--partition", default="FRAMES") + parser.add_argument("--output", 
default="results_frames.json") + parser.add_argument("--limit", type=int, default=None) + parser.add_argument("--concurrency", type=int, default=4) + parser.add_argument("--no-rag", action="store_true", default=False, + help="Bypass OpenRAG, send question directly to LLM") + parser.add_argument("--oracle", action="store_true", default=False, + help="Provide gold Wikipedia articles directly to LLM (upper bound)") + parser.add_argument("--gold-workspaces", action="store_true", default=False, + help="Create one workspace per question with only its gold files") + parser.add_argument("--gold-workspace-prefix", default=None, + help="Prefix for per-question gold workspaces (default: -goldws)") + parser.add_argument("--reuse-gold-workspaces", action="store_true", default=False, + help="Reuse existing gold workspaces, skip creation/file attachment") + parser.add_argument("--max-context-chars", type=int, default=20000, + help="Max chars for oracle context (default: 20000). 0 = no limit.") + parser.add_argument("--docs-dir", default="./frames_docs", + help="Directory with .md Wikipedia articles for oracle mode (default: ./frames_docs)") + parser.add_argument("--judge-concurrency", type=int, default=10) + parser.add_argument("--from-results", metavar="FILE", default=None, + help="Load existing results JSON, skip generation, recompute metrics") + parser.add_argument("--dataset-path", default=None, + help="Dataset JSON path relative to automatic-evaluation-pipeline/ or absolute") + args = parser.parse_args() + + # Mutual exclusivity + exclusive_modes = [name for name, flag in [ + ("--no-rag", args.no_rag), ("--oracle", args.oracle), ("--gold-workspaces", args.gold_workspaces), + ] if flag] + if len(exclusive_modes) > 1: + parser.error(f"{' and '.join(exclusive_modes)} are mutually exclusive") + if args.reuse_gold_workspaces and not args.gold_workspaces: + parser.error("--reuse-gold-workspaces requires --gold-workspaces") + + dataset = load_dataset_cached(limit=args.limit, 
dataset_path=args.dataset_path) + docs_dir = Path(args.docs_dir) + if not docs_dir.is_absolute(): + docs_dir = (Path(__file__).parent / docs_dir).resolve() + + # ── Load existing results or generate ──────────────────────────────────── + if args.from_results: + from_path = Path(args.from_results) if Path(args.from_results).is_absolute() else Path(__file__).parent / args.from_results + print(f"Loading existing results from {from_path} ...") + with open(from_path, encoding="utf-8") as f: + results: list[dict] = json.load(f) + annotate_exact_match(results) + await run_accuracy_judging(results, concurrency=args.judge_concurrency) + + elif args.no_rag: + print(f"Evaluating {len(dataset)} questions [NO_RAG]") + llm = ChatOpenAI(model=MODEL, base_url=BASE_URL, api_key=API_KEY, temperature=0.2, max_tokens=512) + sem = asyncio.Semaphore(args.concurrency) + raw = await tqdm.gather( + *[query_llm_answer(row["Prompt"], sem, llm) for row in dataset], + desc="Querying LLM (no RAG)", + ) + results = [_build_result(row, ans, src, "no_rag") for row, (ans, src) in zip(dataset, raw)] + annotate_exact_match(results) + await run_accuracy_judging(results, concurrency=args.judge_concurrency) + + elif args.oracle: + print(f"Evaluating {len(dataset)} questions [ORACLE]") + if not docs_dir.exists(): + print(f"ERROR: docs dir {docs_dir} not found. 
Run setup_frames.py first.") + return + oracle_articles = load_articles_from_disk(dataset, docs_dir) + llm = ChatOpenAI(model=MODEL, base_url=BASE_URL, api_key=API_KEY, temperature=0.2, max_tokens=512) + sem = asyncio.Semaphore(args.concurrency) + raw = await tqdm.gather( + *[ + query_oracle_answer( + row["Prompt"], + build_oracle_context( + row, + oracle_articles, + max_chars=args.max_context_chars, + ), + sem, llm, + ) + for row in dataset + ], + desc="Querying LLM (oracle)", + ) + results = [_build_result(row, ans, src, "oracle") for row, (ans, src) in zip(dataset, raw)] + annotate_exact_match(results) + await run_accuracy_judging(results, concurrency=args.judge_concurrency) + + elif args.gold_workspaces: + workspace_prefix = args.gold_workspace_prefix or f"{args.partition}-goldws" + reuse = args.reuse_gold_workspaces + print(f"Evaluating {len(dataset)} questions [GOLD_WORKSPACES{'_REUSE' if reuse else ''}, partition={args.partition}]") + + async with httpx.AsyncClient(timeout=600) as client: + if not await _check_openrag_health(client): + return + available = await get_available_openrag_partitions(client) + if args.partition not in available: + print(f"ERROR: partition '{args.partition}' not available. Available: {', '.join(sorted(available)) or '(none)'}") + return + + if reuse: + prepared = [ + (_build_gold_name(workspace_prefix, i), _get_gold_file_ids(row)) + for i, row in enumerate(dataset) + ] + existing_ws = await get_available_workspaces(client, args.partition) + missing = [ws for ws, _ in prepared if ws not in existing_ws] + if missing: + print(f"ERROR: {len(missing)} gold workspaces missing: {', '.join(missing[:10])}{'...' 
if len(missing) > 10 else ''}") + return + else: + ws_sem = asyncio.Semaphore(args.concurrency) + + async def _prep(i, row): + async with ws_sem: + return await prepare_gold_workspace_for_question(client, i, row, args.partition, workspace_prefix) + + prepared = await tqdm.gather( + *[_prep(i, row) for i, row in enumerate(dataset)], + desc="Preparing gold workspaces", + ) + + sem = asyncio.Semaphore(args.concurrency) + raw = await tqdm.gather( + *[query_openrag_answer(row["Prompt"], args.partition, sem, workspace=ws) for row, (ws, _) in zip(dataset, prepared)], + desc="Querying OpenRAG (gold workspaces)", + ) + results = [ + _build_result(row, ans, src, "gold_workspaces", workspace_id=ws, question_index=i, gold_file_ids=fids) + for i, (row, (ws, fids), (ans, src)) in enumerate(zip(dataset, prepared, raw)) + ] + annotate_exact_match(results) + await run_accuracy_judging(results, concurrency=args.judge_concurrency) + + else: + print(f"Evaluating {len(dataset)} questions on partition '{args.partition}' at {OPENRAG_BASE_URL}") + sem = asyncio.Semaphore(args.concurrency) + raw = await tqdm.gather( + *[query_openrag_answer(row["Prompt"], args.partition, sem) for row in dataset], + desc="Querying OpenRAG", + ) + results = [_build_result(row, ans, src, "rag") for row, (ans, src) in zip(dataset, raw)] + annotate_exact_match(results) + await run_accuracy_judging(results, concurrency=args.judge_concurrency) + + # ── Save & summary ─────────────────────────────────────────────────────── + output_path = Path(__file__).parent / args.output + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=2) + print(f"\nResults saved to {output_path}") + _print_summary(results, len(dataset)) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/automatic-evaluation-pipeline/setup_frames.py b/automatic-evaluation-pipeline/setup_frames.py new file mode 100644 index 00000000..606bbc3b --- /dev/null +++ 
b/automatic-evaluation-pipeline/setup_frames.py @@ -0,0 +1,413 @@ +""" +Download FRAMES benchmark Wikipedia articles and index them into OpenRAG. + +`pip install datasets` is needed if you don't want to manually download the dataset JSON from HuggingFace. + +The FRAMES benchmark (google/frames-benchmark) contains 824 multi-hop questions +referencing 2-15 Wikipedia articles each (2474 articles). This script: + 1. Loads the dataset (cached JSON or HuggingFace) + 2. Extracts every unique Wikipedia URL + 3. Fetches each article in the chosen format (md / html / pdf) (about 45 minutes with the default concurrency of 3; the Wikipedia API recommends <= 5) + 4. Creates the target partition and indexes the files + +Usage: + cd automatic-evaluation-pipeline + python setup_frames.py [--partition FRAMES] [--limit N] # pdf (default) + python setup_frames.py --format md # plain-text markdown + python setup_frames.py --format html # HTML (requires OpenRAG with text/html loader) + python setup_frames.py --index-only # skip download, index existing files + python setup_frames.py --retry-missing # fetch only missing articles + +Environment variables (from .env): + APP_URL, APP_PORT, AUTH_TOKEN +""" + +import argparse +import ast +import asyncio +import json +import os +import re +from pathlib import Path +from urllib.parse import quote, unquote, urlparse + +import httpx +from dotenv import load_dotenv +from loguru import logger +from tqdm.asyncio import tqdm + +load_dotenv() + +# ─── Env / config ──────────────────────────────────────────────────────────── + +APP_URL = os.environ.get("APP_URL", "localhost") +APP_PORT = os.environ.get("APP_PORT", "8080") +AUTH_TOKEN = os.environ.get("AUTH_TOKEN", "sk-1234") + +OPENRAG_BASE_URL = f"http://{APP_URL}:{APP_PORT}" + +WIKIPEDIA_API_URL = "https://en.wikipedia.org/w/api.php" +WIKIPEDIA_REST_HTML = "https://en.wikipedia.org/api/rest_v1/page/html" +WIKIPEDIA_REST_PDF = "https://en.wikipedia.org/api/rest_v1/page/pdf" +WIKIPEDIA_USER_AGENT = "OpenRAG-FRAMES-Benchmark/1.0
(https://github.com/linagora/openrag; eval pipeline)" + +TERMINAL_STATES = {"COMPLETED", "FAILED"} +POLL_INTERVAL = 5 + +DATASET_CACHE = Path(__file__).parent / "frames_dataset.json" + +FORMAT_DEFAULTS: dict[str, tuple[str, str]] = { + "md": ("./frames_docs", "text/markdown"), + "pdf": ("./frames_pdf", "application/pdf"), + "html": ("./frames_html", "text/html"), +} + + +# ─── Dataset ───────────────────────────────────────────────────────────────── + + +def load_dataset_cached(limit: int | None = None) -> list[dict]: + """Load dataset from local cache, or download from HuggingFace on first run.""" + if DATASET_CACHE.exists(): + logger.info(f"Loading dataset from cache ({DATASET_CACHE.name})...") + with open(DATASET_CACHE, encoding="utf-8") as f: + dataset = json.load(f) + else: + logger.info("Downloading FRAMES benchmark from HuggingFace...") + from datasets import load_dataset + hf_dataset = load_dataset("google/frames-benchmark", split="test") + dataset = [dict(row) for row in hf_dataset] + with open(DATASET_CACHE, "w", encoding="utf-8") as f: + json.dump(dataset, f, ensure_ascii=False, indent=2) + logger.info(f"Cached {len(dataset)} questions to {DATASET_CACHE.name}") + + if limit: + dataset = dataset[:limit] + logger.info(f"Loaded {len(dataset)} questions.") + return dataset + + +# ─── Wikipedia helpers ─────────────────────────────────────────────────────── + + +def extract_title_from_url(url: str) -> str | None: + parsed = urlparse(url) + if "wikipedia.org" not in parsed.netloc: + return None + match = re.match(r"/wiki/(.+)", parsed.path) + if match: + return unquote(match.group(1)).replace("_", " ") + return None + + +def extract_all_titles(dataset: list[dict]) -> list[str]: + """Extract unique Wikipedia article titles referenced by the dataset.""" + titles: set[str] = set() + for row in dataset: + wiki_links = row.get("wiki_links") + if not wiki_links: + continue + if isinstance(wiki_links, str): + try: + links = ast.literal_eval(wiki_links) + except 
(ValueError, SyntaxError): + try: + links = json.loads(wiki_links) + except json.JSONDecodeError: + links = [l.strip() for l in wiki_links.split(",") if l.strip()] + else: + links = wiki_links + for link in links: + if isinstance(link, str) and "wikipedia.org" in link: + title = extract_title_from_url(link.strip()) + if title: + titles.add(title) + return sorted(titles) + + +def safe_filename(title: str) -> str: + safe = re.sub(r'[^\w\s\-()]', '', title).strip() + safe = re.sub(r'\s+', '_', safe) + if not safe: + safe = f"article_{hash(title) % 10**8}" + return safe + + +def title_to_wiki_slug(title: str) -> str: + decoded = unquote(title) + return quote(decoded.replace(" ", "_"), safe="/:@!$&'()*+,;=-._~") + + +# ─── Fetch ─────────────────────────────────────────────────────────────────── + + +async def fetch_markdown( + client: httpx.AsyncClient, title: str, max_retries: int = 5, +) -> tuple[str, bytes | None]: + """Fetch an article as plain-text markdown via the `prop=extracts` API.""" + params = { + "action": "query", "titles": title, "prop": "extracts", + "explaintext": "true", "format": "json", + } + for attempt in range(max_retries): + try: + resp = await client.get(WIKIPEDIA_API_URL, params=params) + if resp.status_code == 429 or resp.status_code >= 500: + wait = float(resp.headers.get("retry-after") or min(2 ** attempt + 1, 60)) + if attempt < max_retries - 1: + await asyncio.sleep(wait) + continue + return title, None + resp.raise_for_status() + pages = resp.json().get("query", {}).get("pages", {}) + for page_id, page in pages.items(): + if page_id == "-1": + logger.warning(f"Wikipedia article not found: {title}") + return title, None + extract = page.get("extract", "") + if extract: + md = f"# {title}\n\n{extract}" + return title, md.encode("utf-8") + return title, None + except (httpx.TimeoutException, httpx.ConnectError): + if attempt < max_retries - 1: + await asyncio.sleep(min(2 ** attempt + 1, 60)) + continue + return title, None + except Exception 
as e: + logger.debug(f"Failed to fetch '{title}': {e}") + return title, None + return title, None + + +async def fetch_rest( + client: httpx.AsyncClient, title: str, fmt: str, max_retries: int = 5, +) -> tuple[str, bytes | None]: + """Fetch an article as HTML or PDF via the Wikipedia REST API.""" + base = WIKIPEDIA_REST_HTML if fmt == "html" else WIKIPEDIA_REST_PDF + url = f"{base}/{title_to_wiki_slug(title)}" + for attempt in range(max_retries): + try: + resp = await client.get(url) + if resp.status_code == 429 or resp.status_code >= 500: + wait = float(resp.headers.get("retry-after") or min(2 ** attempt + 1, 60)) + if attempt < max_retries - 1: + await asyncio.sleep(wait) + continue + return title, None + if resp.status_code == 404: + logger.warning(f"Wikipedia article not found: {title}") + return title, None + resp.raise_for_status() + return title, resp.content + except (httpx.TimeoutException, httpx.ConnectError): + if attempt < max_retries - 1: + await asyncio.sleep(min(2 ** attempt + 1, 60)) + continue + return title, None + except Exception as e: + logger.debug(f"Failed to fetch '{title}': {e}") + return title, None + return title, None + + +async def download_articles( + titles: list[str], output_dir: Path, fmt: str, concurrency: int, +) -> list[tuple[str, Path]]: + """Download articles for `titles` into `output_dir`, return (file_id, path) list.""" + output_dir.mkdir(parents=True, exist_ok=True) + existing_stems = {p.stem for p in output_dir.glob(f"*.{fmt}")} + + ready: list[tuple[str, Path]] = [] + to_fetch: list[str] = [] + for title in titles: + stem = safe_filename(title) + if stem in existing_stems: + ready.append((stem, output_dir / f"{stem}.{fmt}")) + else: + to_fetch.append(title) + + print(f"Articles ({fmt}): {len(ready)} on disk, {len(to_fetch)} to fetch.") + if not to_fetch: + return ready + + sem = asyncio.Semaphore(concurrency) + + async def _fetch_with_sem(title: str, client: httpx.AsyncClient): + async with sem: + if fmt == "md": + return 
await fetch_markdown(client, title) + return await fetch_rest(client, title, fmt) + + async with httpx.AsyncClient( + timeout=60, follow_redirects=True, + headers={"User-Agent": WIKIPEDIA_USER_AGENT}, + ) as client: + tasks = [_fetch_with_sem(t, client) for t in to_fetch] + results = await tqdm.gather(*tasks, desc=f"Fetching Wikipedia {fmt.upper()}") + + fetched = 0 + for title, content in results: + if content is None: + continue + stem = safe_filename(title) + path = output_dir / f"{stem}.{fmt}" + path.write_bytes(content) + ready.append((stem, path)) + fetched += 1 + print(f"Fetched {fetched} new articles ({len(to_fetch) - fetched} failed).") + return ready + + +# ─── OpenRAG upload ────────────────────────────────────────────────────────── + + +async def check_health(client: httpx.AsyncClient) -> bool: + try: + resp = await client.get(f"{OPENRAG_BASE_URL}/health_check") + resp.raise_for_status() + print("OpenRAG API is up.") + return True + except Exception as e: + print(f"ERROR: Cannot reach OpenRAG at {OPENRAG_BASE_URL}: {e}") + return False + + +async def create_partition(client: httpx.AsyncClient, partition: str) -> None: + headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} + resp = await client.post(f"{OPENRAG_BASE_URL}/partition/{partition}", headers=headers) + if resp.status_code == 201: + print(f"Partition '{partition}' created.") + elif resp.status_code == 409: + print(f"Partition '{partition}' already exists.") + else: + resp.raise_for_status() + + +async def upload_and_track( + client: httpx.AsyncClient, + partition: str, + file_id: str, + file_path: Path, + mime: str, + sem: asyncio.Semaphore, +) -> dict: + headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} + url = f"{OPENRAG_BASE_URL}/indexer/partition/{partition}/file/{file_id}" + async with sem: + try: + with open(file_path, "rb") as f: + files = {"file": (file_path.name, f, mime), "metadata": (None, "")} + resp = await client.post(url, files=files, headers=headers) + + if resp.status_code == 
409: + return {"file_id": file_id, "status": "skipped"} + if resp.status_code != 201: + logger.error(f"Upload failed for '{file_id}': {resp.status_code} - {resp.text}") + resp.raise_for_status() + + task_url = resp.json().get("task_status_url") + if task_url.startswith("/"): + task_url = f"{OPENRAG_BASE_URL}{task_url}" + + while True: + poll = await client.get(task_url, headers=headers) + if poll.status_code == 200: + state = poll.json().get("task_state", "UNKNOWN") + if state in TERMINAL_STATES: + return {"file_id": file_id, "status": state} + else: + logger.warning(f"Poll failed for '{file_id}': {poll.status_code}") + await asyncio.sleep(POLL_INTERVAL) + except Exception as e: + logger.error(f"Error processing '{file_id}': {e}") + return {"file_id": file_id, "status": "ERROR", "error": str(e)} + + +async def index_files( + doc_files: list[tuple[str, Path]], + partition: str, + mime: str, + concurrency: int, +) -> None: + async with httpx.AsyncClient(timeout=600) as client: + if not await check_health(client): + return + await create_partition(client, partition) + + print(f"Uploading {len(doc_files)} files to '{partition}' (concurrency={concurrency})...") + sem = asyncio.Semaphore(concurrency) + tasks = [upload_and_track(client, partition, fid, path, mime, sem) for fid, path in doc_files] + results = await tqdm.gather(*tasks, desc="Uploading & indexing") + + completed = sum(1 for r in results if r["status"] == "COMPLETED") + failed = sum(1 for r in results if r["status"] == "FAILED") + skipped = sum(1 for r in results if r["status"] == "skipped") + errors = sum(1 for r in results if r["status"] == "ERROR") + + print(f"\n{'='*60}") + print(f"Indexing complete for partition '{partition}'") + print(f" COMPLETED: {completed}") + print(f" FAILED: {failed}") + print(f" SKIPPED: {skipped} (already existed)") + print(f" ERRORS: {errors}") + print(f"{'='*60}") + + if failed or errors: + print("\nFailed/errored files:") + for r in results: + if r["status"] in ("FAILED", 
"ERROR"): + print(f" - {r['file_id']}: {r['status']} {r.get('error', '')}") + + +# ─── Main ──────────────────────────────────────────────────────────────────── + + +async def main() -> None: + parser = argparse.ArgumentParser(description="Download + index FRAMES Wikipedia articles into OpenRAG.") + parser.add_argument("--partition", default="FRAMES") + parser.add_argument("--format", choices=["md", "pdf", "html"], default="pdf") + parser.add_argument("--output-dir", default=None, + help="Override article directory (defaults per format).") + parser.add_argument("--limit", type=int, default=None) + parser.add_argument("--concurrency", type=int, default=4, + help="Max concurrent uploads (default: 4)") + parser.add_argument("--wiki-concurrency", type=int, default=3, + help="Max concurrent Wikipedia fetches (default: 3)") + parser.add_argument("--index-only", action="store_true", + help="Skip download, index existing files from output-dir") + parser.add_argument("--retry-missing", action="store_true", + help="Only fetch articles missing from output-dir, upload only the new ones") + args = parser.parse_args() + + default_dir, mime = FORMAT_DEFAULTS[args.format] + output_dir = Path(args.output_dir).resolve() if args.output_dir else ( + Path(__file__).parent / default_dir + ).resolve() + + if args.index_only: + if not output_dir.exists(): + print(f"ERROR: Output directory {output_dir} does not exist.") + return + doc_files = [(p.stem, p) for p in sorted(output_dir.glob(f"*.{args.format}"))] + print(f"Index-only: {len(doc_files)} .{args.format} files in {output_dir}") + else: + dataset = load_dataset_cached(limit=args.limit) + titles = extract_all_titles(dataset) + print(f"Found {len(titles)} unique Wikipedia articles referenced.") + + before = {p.stem for p in output_dir.glob(f"*.{args.format}")} if output_dir.exists() else set() + doc_files = await download_articles(titles, output_dir, args.format, args.wiki_concurrency) + if args.retry_missing: + doc_files = [(fid, 
path) for fid, path in doc_files if fid not in before] + print(f"Will only upload {len(doc_files)} newly-fetched files.") + + if not doc_files: + print("Nothing to upload.") + return + + await index_files(doc_files, args.partition, mime, args.concurrency) + + +if __name__ == "__main__": + asyncio.run(main())