diff --git a/app/src/api/cache_mgmt.ts b/app/src/api/cache_mgmt.ts new file mode 100644 index 000000000..31481e218 --- /dev/null +++ b/app/src/api/cache_mgmt.ts @@ -0,0 +1,31 @@ +import { api } from './client'; + +export type CacheStats = { + total_keys: number; + hits: number; + misses: number; + hit_rate: number; +}; + +export type CacheInvalidateResponse = { + deleted: number; +}; + +export type CacheKeysResponse = { + keys: string[]; +}; + +export async function getCacheStats(): Promise { + return api('/cache/stats'); +} + +export async function invalidateCache(keys?: string[]): Promise { + return api('/cache/invalidate', { + method: 'POST', + body: keys && keys.length > 0 ? { keys } : {}, + }); +} + +export async function listCacheKeys(): Promise { + return api('/cache/keys'); +} diff --git a/packages/backend/app/routes/__init__.py b/packages/backend/app/routes/__init__.py index f13b0f897..bc34d094a 100644 --- a/packages/backend/app/routes/__init__.py +++ b/packages/backend/app/routes/__init__.py @@ -7,6 +7,7 @@ from .categories import bp as categories_bp from .docs import bp as docs_bp from .dashboard import bp as dashboard_bp +from .cache_mgmt import bp as cache_mgmt_bp def register_routes(app: Flask): @@ -18,3 +19,4 @@ def register_routes(app: Flask): app.register_blueprint(categories_bp, url_prefix="/categories") app.register_blueprint(docs_bp, url_prefix="/docs") app.register_blueprint(dashboard_bp, url_prefix="/dashboard") + app.register_blueprint(cache_mgmt_bp, url_prefix="/cache") diff --git a/packages/backend/app/routes/cache_mgmt.py b/packages/backend/app/routes/cache_mgmt.py new file mode 100644 index 000000000..2bf333307 --- /dev/null +++ b/packages/backend/app/routes/cache_mgmt.py @@ -0,0 +1,93 @@ +from flask import Blueprint, jsonify, request +from flask_jwt_extended import jwt_required, get_jwt_identity +from ..extensions import redis_client +import logging + +bp = Blueprint("cache_mgmt", __name__) +logger = logging.getLogger("finmind.cache_mgmt") + +_USER_CACHE_PREFIX = "user_cache" + + +def _user_prefix(uid: int) -> str: + return f"{_USER_CACHE_PREFIX}:{uid}:" + + +@bp.get("/stats") +@jwt_required() +def cache_stats(): + """Return cache hit/miss statistics for the current user.""" + uid = int(get_jwt_identity()) + prefix = _user_prefix(uid) + + try: + keys = redis_client.keys(f"{prefix}*") + except Exception: + keys = [] + + total_keys = len(keys) + # Track hits/misses via a counter key + hit_key = f"{prefix}__hits" + miss_key = f"{prefix}__misses" + + try: + hits = int(redis_client.get(hit_key) or 0) + misses = int(redis_client.get(miss_key) or 0) + except Exception: + hits = 0 + misses = 0 + + total_lookups = hits + misses + hit_rate = round((hits / total_lookups) * 100, 2) if total_lookups > 0 else 0.0 + + logger.info("Cache stats user=%s keys=%s hits=%s misses=%s", uid, total_keys, hits, misses) + return jsonify( + total_keys=total_keys, + hits=hits, + misses=misses, + hit_rate=hit_rate, + ) + + +@bp.post("/invalidate") +@jwt_required() +def invalidate_cache(): + """Clear cache for specific keys or all keys for the current user.""" + uid = int(get_jwt_identity()) + prefix = _user_prefix(uid) + data = request.get_json(silent=True) or {} + specific_keys = data.get("keys") + + deleted = 0 + try: + if specific_keys and isinstance(specific_keys, list): + for k in specific_keys: + full_key = f"{prefix}{k}" + deleted += redis_client.delete(full_key) + else: + # Clear all user cache keys + all_keys = redis_client.keys(f"{prefix}*") + if all_keys: + deleted = redis_client.delete(*all_keys) + except Exception: + pass + + logger.info("Cache invalidated user=%s deleted=%s", uid, deleted) + return jsonify(deleted=deleted) + + +@bp.get("/keys") +@jwt_required() +def list_keys(): + """List cached keys for the current user.""" + uid = int(get_jwt_identity()) + prefix = _user_prefix(uid) + + try: + raw_keys = redis_client.keys(f"{prefix}*") + keys = [k.replace(prefix, "") if isinstance(k, str) else k.decode().replace(prefix, "") for k in raw_keys] + except Exception: + keys = [] + + logger.info("Cache keys listed user=%s count=%s", uid, len(keys)) + return jsonify(keys=keys) diff --git a/packages/backend/tests/test_cache_mgmt.py b/packages/backend/tests/test_cache_mgmt.py new file mode 100644 index 000000000..9adcb6bc8 --- /dev/null +++ b/packages/backend/tests/test_cache_mgmt.py @@ -0,0 +1,49 @@ +def test_cache_stats_requires_auth(client): + r = client.get("/cache/stats") + assert r.status_code == 401 + + +def test_cache_stats(client, auth_header): + r = client.get("/cache/stats", headers=auth_header) + assert r.status_code == 200 + data = r.get_json() + assert "total_keys" in data + assert "hits" in data + assert "misses" in data + assert "hit_rate" in data + + +def test_cache_invalidate_requires_auth(client): + r = client.post("/cache/invalidate") + assert r.status_code == 401 + + +def test_cache_invalidate_all(client, auth_header): + r = client.post("/cache/invalidate", headers=auth_header) + assert r.status_code == 200 + data = r.get_json() + assert "deleted" in data + + +def test_cache_invalidate_specific_keys(client, auth_header): + r = client.post( + "/cache/invalidate", + json={"keys": ["dashboard_summary", "insights"]}, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + assert "deleted" in data + + +def test_cache_keys_requires_auth(client): + r = client.get("/cache/keys") + assert r.status_code == 401 + + +def test_cache_keys_list(client, auth_header): + r = client.get("/cache/keys", headers=auth_header) + assert r.status_code == 200 + data = r.get_json() + assert "keys" in data + assert isinstance(data["keys"], list)