diff --git a/openrag/api.py b/openrag/api.py
index 922a33ea..95e60bd7 100644
--- a/openrag/api.py
+++ b/openrag/api.py
@@ -1,9 +1,12 @@
+import asyncio
 import os
+import time
 import warnings
 from enum import Enum
 from importlib.metadata import version as get_package_version
 from pathlib import Path
 
+import httpx
 import ray
 import uvicorn
 from config import load_config
@@ -196,10 +199,85 @@
 app.mount("/static", StaticFiles(directory=DATA_DIR.resolve(), check_dir=True), name="static")
 
 
+async def check_service_health(base_url: str, service_name: str) -> dict:
+    """
+    Probe a service health endpoint with timeout.
+
+    Args:
+        base_url: Base URL of the service (e.g., "http://localhost:8000")
+        service_name: Human-readable name for logging
+
+    Returns:
+        dict with status (healthy/unhealthy/timeout/unreachable/error),
+        response_time_ms, and error message if applicable
+    """
+    try:
+        async with httpx.AsyncClient(timeout=httpx.Timeout(3.0)) as client:
+            response = await client.get(f"{base_url}/health")
+            elapsed_ms = response.elapsed.total_seconds() * 1000
+
+            if response.status_code == 200:
+                return {"status": "healthy", "response_time_ms": round(elapsed_ms, 2)}
+            else:
+                return {
+                    "status": "unhealthy",
+                    "error": f"HTTP {response.status_code}",
+                    "response_time_ms": round(elapsed_ms, 2),
+                }
+    except httpx.TimeoutException:
+        return {"status": "timeout", "error": "Service did not respond within 3s"}
+    except httpx.ConnectError:
+        return {"status": "unreachable", "error": "Connection refused"}
+    except Exception as e:
+        return {"status": "error", "error": str(e)}
+
+
 @app.get("/health_check", summary="Health check endpoint for API", dependencies=[])
 async def health_check(request: Request):
-    # TODO : Error reporting about llm and vlm
-    return "RAG API is up."
+    """
+    Health check endpoint with LLM and VLM service probes.
+
+    Returns HTTP 200 for healthy/degraded, HTTP 503 for unhealthy.
+    LLM is critical, VLM is non-critical (used only for image captioning).
+    """
+    config = request.app.state.app_state.config
+
+    # Probe LLM and VLM services concurrently
+    # Strip API path (e.g. /v1/) to get the service root for health probes
+    llm_base_url = config.llm.get("base_url", "").split("/v1")[0]
+    vlm_base_url = config.vlm.get("base_url", "").split("/v1")[0]
+
+    results = await asyncio.gather(
+        check_service_health(llm_base_url, "llm"), check_service_health(vlm_base_url, "vlm"), return_exceptions=True
+    )
+
+    # Handle gather results (defensive: check if any result is an Exception)
+    llm_result = results[0] if not isinstance(results[0], Exception) else {"status": "error", "error": str(results[0])}
+    vlm_result = results[1] if not isinstance(results[1], Exception) else {"status": "error", "error": str(results[1])}
+
+    # Determine overall status
+    llm_healthy = llm_result.get("status") == "healthy"
+    vlm_healthy = vlm_result.get("status") == "healthy"
+
+    if llm_healthy and vlm_healthy:
+        overall_status = "healthy"
+        status_code = 200
+    elif llm_healthy and not vlm_healthy:
+        # VLM is non-critical (only used for image captioning)
+        overall_status = "degraded"
+        status_code = 200
+    else:
+        # LLM is critical - any LLM failure is unhealthy
+        overall_status = "unhealthy"
+        status_code = 503
+
+    response_data = {
+        "status": overall_status,
+        "checks": {"api": {"status": "healthy"}, "llm": llm_result, "vlm": vlm_result},
+        "timestamp": time.time(),
+    }
+
+    return JSONResponse(status_code=status_code, content=response_data)
 
 
 @app.get("/version", summary="Get openRAG version", dependencies=[])
diff --git a/openrag/scripts/restore.py b/openrag/scripts/restore.py
index 62a97774..26282599 100644
--- a/openrag/scripts/restore.py
+++ b/openrag/scripts/restore.py
@@ -19,6 +19,7 @@ def read_rdb_section(
     added_documents: dict[str, set[str]],
     existing_partitions: dict[str, Any],
     logger: Any,
+    restore_state: dict[str, Any],
     user_id: int,
     verbose: bool = False,
     dry_run: bool = False,
@@ -69,19 +70,44 @@ def read_rdb_section(
             try:
                 res = pfm.add_file_to_partition(doc["file_id"], part["name"], doc, user_id)
             except Exception as e:
-                logger.exception(
-                    f"{type(e)} in add_file_to_partition({doc['file_id']}, {part['name']}, ...)\n" + str(e)
-                )
-                raise
+                # Non-critical failure: log and continue instead of raising
+                logger.bind(
+                    file_id=doc["file_id"],
+                    partition=part["name"],
+                    error_type=type(e).__name__,
+                ).error(f"Failed to add file to partition: {str(e)}")
+                restore_state["files_failed"] += 1
+                if len(restore_state["errors"]) < 100:
+                    restore_state["errors"].append(
+                        {
+                            "file_id": doc["file_id"],
+                            "partition": part["name"],
+                            "error": str(e),
+                        }
+                    )
+                res = False
             else:
                 res = True
             if res:
                 if part["name"] not in added_documents:
                     added_documents[part["name"]] = set()
+                    # Track partition creation (first file added successfully)
+                    restore_state["partitions_created"].append(part["name"])
                 added_documents[part["name"]].add(doc["file_id"])
+                restore_state["files_added"] += 1
             else:
-                logger.error(f"Can't add file {doc['file_id']} to partition {part['name']}")
+                if not dry_run:
+                    logger.error(f"Can't add file {doc['file_id']} to partition {part['name']}")
+
+            # Log progress every 100 files
+            total_files = restore_state["files_added"] + restore_state["files_failed"]
+            if total_files > 0 and total_files % 100 == 0:
+                logger.bind(
+                    files_added=restore_state["files_added"],
+                    files_failed=restore_state["files_failed"],
+                    total_processed=total_files,
+                ).info("Restore progress")
 
 
 def insert_into_vdb(
@@ -126,6 +152,7 @@ def read_vdb_section(
     client: MilvusClient,
     batch_size: int,
     logger: Any,
+    restore_state: dict[str, Any],
     verbose: bool = False,
     dry_run: bool = False,
 ) -> None:
@@ -155,6 +182,7 @@ def read_vdb_section(
             if len(batch) >= batch_size:
                 insert_into_vdb(client, collection_name, batch, logger, verbose, dry_run)
+                restore_state["chunks_inserted"] += len(batch)
                 batch = []
 
             chunk = json.loads(line)
@@ -165,6 +193,7 @@ def read_vdb_section(
     if len(batch) > 0:
         insert_into_vdb(client, collection_name, batch, logger, verbose, dry_run)
+        restore_state["chunks_inserted"] += len(batch)
 
 
 def open_backup_file(file_name: str, logger: Any) -> IO[str]:
@@ -254,15 +283,25 @@ def load_openrag_config(logger: Any) -> tuple[dict[str, Any], dict[str, Any]]:
 
     logger = get_logger()
 
+    restore_state = {
+        "partitions_created": [],  # List of partition names created in RDB
+        "files_added": 0,  # Count of files successfully added
+        "files_failed": 0,  # Count of files that failed
+        "chunks_inserted": 0,  # Count of VDB chunks inserted
+        "errors": [],  # List of error dicts: {"file_id", "partition", "error"}
+    }
+
     try:
         # It will create a the Milvus collection if it doesn't exist
         vdb_tmp = MilvusDB.options(name="Vectordb", namespace="openrag", lifetime="detached").remote()
-        await vdb_tmp.__ray_ready__.remote()  # ensure the actor is fully initialized and ready: collection and all created if nont existing
+        await (
+            vdb_tmp.__ray_ready__.remote()
+        )  # ensure the actor is fully initialized and ready: collection and all created if not existing
         print("VectorDB (Milvus) actor fully initialized")
     except Exception as e:
         logger.exception(f"Failed while trying to create Milvus collection: {e}")
-        # TODO: stop execution here
+        return 1
 
     rdb, vdb = load_openrag_config(logger)
 
@@ -306,6 +345,7 @@
             added_documents,
             existing_partitions,
             logger,
+            restore_state,
             args.user_id,
             args.verbose,
             args.dry_run,
@@ -319,11 +359,55 @@
             client,
             args.batch_size,
             logger,
+            restore_state,
             args.verbose,
             args.dry_run,
         )
+
+        # Log final summary
+        logger.bind(
+            partitions_restored=len(restore_state["partitions_created"]),
+            files_added=restore_state["files_added"],
+            files_failed=restore_state["files_failed"],
+            chunks_inserted=restore_state["chunks_inserted"],
+        ).info("Restore completed")
+
+        if restore_state["errors"]:
+            logger.bind(
+                total_errors=len(restore_state["errors"]),
+                first_10=restore_state["errors"][:10],
+            ).warning("Restore completed with file-level errors")
     except Exception as e:
-        logger.error("Error: " + str(e))
+        logger.bind(
+            error=str(e),
+            partitions_created=restore_state["partitions_created"],
+            files_added=restore_state["files_added"],
+            files_failed=restore_state["files_failed"],
+        ).error("Critical restore failure - initiating rollback")
+
+        # Rollback in reverse order: VDB first, then RDB
+        for partition_name in reversed(restore_state["partitions_created"]):
+            # 1. Delete from VDB first (no FK constraints, orphaned vectors are worse)
+            try:
+                client.delete(
+                    collection_name=vdb["collection_name"],
+                    filter=f'partition == "{partition_name}"',
+                )
+                logger.info(f"VDB rollback succeeded for partition: {partition_name}")
+            except Exception:
+                logger.exception(f"VDB rollback failed for partition {partition_name}")
+
+            # 2. Delete from RDB (cascades to files via FK)
+            try:
+                pfm.delete_partition(partition_name)
+                logger.info(f"RDB rollback succeeded for partition: {partition_name}")
+            except Exception:
+                logger.exception(f"RDB rollback failed for partition {partition_name}")
+
+        logger.bind(
+            partitions_rolled_back=len(restore_state["partitions_created"]),
+        ).error("Rollback complete")
+        raise
     finally:
         client.close()
diff --git a/tests/api_tests/test_health.py b/tests/api_tests/test_health.py
index 0c4a58ee..cc1959f8 100644
--- a/tests/api_tests/test_health.py
+++ b/tests/api_tests/test_health.py
@@ -5,7 +5,8 @@ def test_health_check(api_client):
     """Test health check endpoint returns OK."""
     response = api_client.get("/health_check")
     assert response.status_code == 200
-    assert "RAG API is up" in response.text
+    data = response.json()
+    assert data["status"] in ("healthy", "degraded")
 
 
 def test_openapi_docs_accessible(api_client):