diff --git a/docs/config.rst b/docs/config.rst
index f613bb942..c74c50c79 100644
--- a/docs/config.rst
+++ b/docs/config.rst
@@ -540,3 +540,14 @@ Default: None
 Sets the URI to which an OAuth 2.0 server redirects the user after successful
 authentication and authorization. `oauth2_redirect_uri` option should be used
 with :ref:`auth`, :ref:`auth_provider`, :ref:`oauth2_key` and :ref:`oauth2_secret` options.
+
+.. _queue_cache_ttl:
+
+queue_cache_ttl
+~~~~~~~~~~~~~~~
+
+Default: 5.0
+
+TTL in seconds for caching broker queue stats. Set to 0 to disable caching.
+When many queues are configured (e.g. 10,000+), caching avoids re-fetching
+queue lengths from the broker on every page load or API call.
diff --git a/flower/api/tasks.py b/flower/api/tasks.py
index 730c290e4..bbde7b109 100644
--- a/flower/api/tasks.py
+++ b/flower/api/tasks.py
@@ -391,17 +391,44 @@ async def get(self):
         :statuscode 503: result backend is not configured
         """
         app = self.application
+        limit = self.get_argument('limit', default=None, type=int)
+        offset = self.get_argument('offset', default=0, type=int)
+        offset = max(offset, 0)
 
         http_api = None
         if app.transport == 'amqp' and app.options.broker_api:
             http_api = app.options.broker_api
 
-        broker = Broker(app.capp.connection().as_uri(include_password=True),
-                        http_api=http_api, broker_options=self.capp.conf.broker_transport_options,
-                        broker_use_ssl=self.capp.conf.broker_use_ssl)
+        queue_names = self.get_active_queue_names()
+        names_key = frozenset(queue_names)
 
-        queues = await broker.queues(self.get_active_queue_names())
-        self.write({'active_queues': queues})
+        # Check cache first
+        queues = app.get_cached_queue_stats(names_key)
+        if queues is None:
+            with app.capp.connection() as conn:
+                broker_uri = conn.as_uri(include_password=True)
+            broker = Broker(broker_uri,
+                            http_api=http_api, broker_options=self.capp.conf.broker_transport_options,
+                            broker_use_ssl=self.capp.conf.broker_use_ssl)
+
+            try:
+                queues = await broker.queues(queue_names)
+                app.set_queue_cache(names_key, queues)
+            finally:
+                if hasattr(broker, 'close'):
+                    broker.close()
+
+        total = len(queues)
+
+        # Apply pagination
+        if offset:
+            queues = queues[offset:]
+        if limit is not None:
+            if limit < 0:
+                raise HTTPError(400, "Query argument 'limit' must be a non-negative integer")
+            queues = queues[:limit]
+
+        self.write({'active_queues': queues, 'total': total})
 
 
 class ListTasks(BaseTaskHandler):
diff --git a/flower/app.py b/flower/app.py
index 3427e098a..a28ce11f8 100644
--- a/flower/app.py
+++ b/flower/app.py
@@ -1,5 +1,6 @@
 import sys
 import logging
+import time
 
 from concurrent.futures import ThreadPoolExecutor
 
@@ -8,6 +9,7 @@
 
 from tornado import ioloop
 from tornado.httpserver import HTTPServer
+from tornado.ioloop import PeriodicCallback
 from tornado.web import url
 
 from .urls import handlers as default_handlers
@@ -64,6 +66,10 @@ def __init__(self, options=None, capp=None, events=None,
                                   max_workers_in_memory=self.options.max_workers,
                                   max_tasks_in_memory=self.options.max_tasks)
         self.started = False
+        self._transport = None
+        self._purge_timer = None
+        self._queue_cache = None  # (timestamp, frozenset(names), result)
+        self._queue_cache_ttl = getattr(self.options, 'queue_cache_ttl', 5.0)
 
     def start(self):
         self.events.start()
@@ -80,11 +86,26 @@ def start(self):
         self.started = True
 
         self.update_workers()
+
+        if self.options.purge_offline_workers is not None:
+            interval_ms = max(self.options.purge_offline_workers * 1000, 10000)
+            self._purge_timer = PeriodicCallback(self._purge_offline_workers,
+                                                 interval_ms)
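+            # A 10-second floor on the purge interval: even with a very small
+            # purge_offline_workers threshold the callback fires at most once
+            # every 10s, keeping the periodic check cheap on the IOLoop.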
+            self._purge_timer.start()
+
         self.io_loop.start()
 
     def stop(self):
         if self.started:
-            self.events.stop()
+            try:
+                self.events.stop()
+            except Exception:
+                logger.debug("Error stopping events", exc_info=True)
+            if self._purge_timer:
+                try:
+                    self._purge_timer.stop()
+                except Exception:
+                    logger.debug("Error stopping purge timer", exc_info=True)
             logging.debug("Stopping executors...")
             self.executor.shutdown(wait=False)
             logging.debug("Stopping event loop...")
@@ -93,7 +114,10 @@ def stop(self):
 
     @property
     def transport(self):
-        return getattr(self.capp.connection().transport, 'driver_type', None)
+        if self._transport is None:
+            with self.capp.connection() as conn:
+                self._transport = getattr(conn.transport, 'driver_type', None)
+        return self._transport
 
     @property
     def workers(self):
@@ -101,3 +125,62 @@ def workers(self):
 
     def update_workers(self, workername=None):
         return self.inspector.inspect(workername)
+
+    def get_cached_queue_stats(self, names_key):
+        """Return cached queue stats if still valid, else None.
+
+        Returns a shallow copy to prevent callers from mutating the cache."""
+        if self._queue_cache_ttl <= 0 or self._queue_cache is None:
+            return None
+        ts, cached_key, result = self._queue_cache
+        if cached_key == names_key and (time.time() - ts) < self._queue_cache_ttl:
+            return list(result)
+        return None
+
+    def set_queue_cache(self, names_key, result):
+        """Store queue stats in the cache."""
+        if self._queue_cache_ttl > 0:
+            self._queue_cache = (time.time(), names_key, result)
+
+    def _purge_offline_workers(self):
+        """Purge workers that have been offline beyond the threshold.
+
+        Handles two cases:
+        - Workers present in state.workers: check alive status + heartbeat age
+        - Orphaned entries (in counter/inspector but not state.workers): always purge
+        """
+        threshold = self.options.purge_offline_workers
+        if threshold is None:
+            return
+
+        now = time.time()
+        state = self.events.state
+
+        # Collect all known worker names from state.counter and inspector.workers
+        all_worker_names = set(state.counter.keys()) | set(self.inspector.workers.keys())
+
+        for worker_name in all_worker_names:
+            worker = state.workers.get(worker_name)
+            if worker is not None:
+                # Skip workers that are still alive
+                if worker.alive:
+                    continue
+
+                # Check if the worker has been offline beyond the threshold
+                heartbeats = getattr(worker, 'heartbeats', [])
+                if heartbeats:
+                    last_heartbeat = max(heartbeats)
+                    if now - last_heartbeat <= threshold:
+                        continue
+            # else: worker not in state.workers — orphaned entry, always purge
+
+            # Purge from state.counter
+            state.counter.pop(worker_name, None)
+
+            # Purge Prometheus metrics for this worker
+            state.metrics.remove_worker_metrics(worker_name)
+
+            # Purge from inspector
+            self.inspector.purge_worker(worker_name)
+
+            logger.debug("Purged offline worker: %s", worker_name)
diff --git a/flower/command.py b/flower/command.py
index 94ed6c7b6..31a1631e7 100644
--- a/flower/command.py
+++ b/flower/command.py
@@ -173,7 +173,8 @@ def print_banner(app, ssl):
     else:
         logger.info("Visit me via unix socket file: %s", options.unix_socket)
 
-    logger.info('Broker: %s', app.connection().as_uri())
+    with app.connection() as conn:
+        logger.info('Broker: %s', conn.as_uri())
     logger.info(
         'Registered tasks: \n%s',
         pformat(sorted(app.tasks.keys()))
diff --git a/flower/events.py b/flower/events.py
index cd15d7a2e..4bff6fa75 100644
--- a/flower/events.py
+++ b/flower/events.py
@@ -1,5 +1,6 @@
 import collections
 import logging
+import queue
 import shelve
 import threading
 import time
@@ -17,6 +18,8 @@
 PROMETHEUS_METRICS = None
 
+MAX_RETRY_INTERVAL = 60
+
 
 def get_prometheus_metrics():
     global PROMETHEUS_METRICS  # pylint: disable=global-statement
@@ -53,6 +56,30 @@ def __init__(self):
             ['worker']
         )
 
+    def remove_worker_metrics(self, worker_name):
+        """Remove all Prometheus metric label series for a given worker."""
+        metrics = [
+            self.events, self.runtime, self.prefetch_time,
+            self.number_of_prefetched_tasks, self.worker_online,
+            self.worker_number_of_currently_executing_tasks,
+        ]
+        for metric in metrics:
+            # _metrics is the internal dict of label-value tuples -> child metrics.
+            # Guard access since it's a private attr that may vary across versions.
+            storage = getattr(metric, '_metrics', None)
+            if storage is None:
+                continue
+            try:
+                keys_to_remove = [
+                    key for key in storage
+                    if key and key[0] == worker_name
+                ]
+                for key in keys_to_remove:
+                    metric.remove(*key)
+            except Exception:
+                logger.debug("Failed to remove metrics for worker %s from %s",
+                             worker_name, metric, exc_info=True)
+
 
 class EventsState(State):
     # EventsState object is created and accessed only from ioloop thread
 
@@ -112,6 +139,9 @@ def event(self, event):
 
 class Events(threading.Thread):
     events_enable_interval = 5000
+    _BACKPRESSURE_MAXSIZE = 10000
+    _DRAIN_INTERVAL_MS = 100
+    _DRAIN_BATCH_SIZE = 500
 
     # pylint: disable=too-many-arguments
     def __init__(self, capp, io_loop, db=None, persistent=False,
@@ -128,13 +158,21 @@ def __init__(self, capp, io_loop, db=None, persistent=False,
         self.enable_events = enable_events
         self.state = None
         self.state_save_timer = None
+        self._drain_timer = None
+        self._event_queue = queue.Queue(maxsize=self._BACKPRESSURE_MAXSIZE)
+        self._drop_count = 0
+        self._last_drop_log_time = 0.0
 
         if self.persistent:
             logger.debug("Loading state from '%s'...", self.db)
-            state = shelve.open(self.db)
-            if state:
-                self.state = state['events']
-            state.close()
+            try:
+                with shelve.open(self.db) as state:
+                    if state:
+                        self.state = state['events']
+            except KeyError:
+                logger.debug("No existing state found in '%s'", self.db)
+            except Exception:
+                logger.error("Failed to load state from '%s'", self.db, exc_info=True)
 
         if state_save_interval:
             self.state_save_timer = PeriodicCallback(self.save_state,
@@ -156,23 +194,42 @@ def start(self):
             logger.debug("Starting state save timer...")
             self.state_save_timer.start()
 
+        self._drain_timer = PeriodicCallback(self._drain_events,
+                                             self._DRAIN_INTERVAL_MS)
+        self._drain_timer.start()
+
     def stop(self):
-        if self.enable_events:
-            logger.debug("Stopping enable events timer...")
-            self.timer.stop()
+        try:
+            if self.enable_events:
+                logger.debug("Stopping enable events timer...")
+                try:
+                    self.timer.stop()
+                except Exception:
+                    logger.debug("Error stopping enable events timer", exc_info=True)
 
-        if self.state_save_timer:
-            logger.debug("Stopping state save timer...")
-            self.state_save_timer.stop()
+            if self.state_save_timer:
+                logger.debug("Stopping state save timer...")
+                try:
+                    self.state_save_timer.stop()
+                except Exception:
+                    logger.debug("Error stopping state save timer", exc_info=True)
 
-        if self.persistent:
-            self.save_state()
+            if self._drain_timer:
+                try:
+                    self._drain_timer.stop()
+                except Exception:
+                    logger.debug("Error stopping drain timer", exc_info=True)
+        finally:
+            if self.persistent:
+                self.save_state()
 
     def run(self):
         try_interval = 1
         while True:
             try:
                 try_interval *= 2
+                if try_interval > MAX_RETRY_INTERVAL:
+                    try_interval = MAX_RETRY_INTERVAL
 
                 with self.capp.connection() as conn:
                     recv = EventReceiver(conn,
@@ -196,9 +253,11 @@ def run(self):
 
     def save_state(self):
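+        # shelve's flag='n' always creates a new, empty file, so every save
+        # rewrites the snapshot from scratch rather than mutating the old one.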
logger.debug("Saving state to '%s'...", self.db) - state = shelve.open(self.db, flag='n') - state['events'] = self.state - state.close() + try: + with shelve.open(self.db, flag='n') as state: + state['events'] = self.state + except Exception: + logger.error("Failed to save state to '%s'", self.db, exc_info=True) def on_enable_events(self): # Periodically enable events for workers @@ -206,5 +265,30 @@ def on_enable_events(self): self.io_loop.run_in_executor(None, self.capp.control.enable_events) def on_event(self, event): - # Call EventsState.event in ioloop thread to avoid synchronization - self.io_loop.add_callback(partial(self.state.event, event)) + # Enqueue event with backpressure — drop if queue is full. + # Rate-limit drop warnings to avoid flooding logs under sustained load. + try: + self._event_queue.put_nowait(event) + except queue.Full: + self._drop_count += 1 + now = time.monotonic() + if now - self._last_drop_log_time >= 5.0: + window_start = self._last_drop_log_time or now + duration = now - window_start + logger.warning( + "Event queue full (%d), dropped %d event(s) in last %.0fs", + self._BACKPRESSURE_MAXSIZE, self._drop_count, duration) + self._drop_count = 0 + self._last_drop_log_time = now + + def _drain_events(self): + """Process up to _DRAIN_BATCH_SIZE events from the backpressure queue.""" + for _ in range(self._DRAIN_BATCH_SIZE): + try: + event = self._event_queue.get_nowait() + except queue.Empty: + break + try: + self.state.event(event) + except Exception: + logger.error("Error processing event", exc_info=True) diff --git a/flower/inspector.py b/flower/inspector.py index 3b1a64d48..b14041c3b 100644 --- a/flower/inspector.py +++ b/flower/inspector.py @@ -16,6 +16,10 @@ def __init__(self, io_loop, capp, timeout): self.timeout = timeout self.workers = collections.defaultdict(dict) + def purge_worker(self, worker_name): + """Remove a worker from the inspector's cached data.""" + self.workers.pop(worker_name, None) + def inspect(self, workername=None): feutures = [] for method in self.methods: diff --git a/flower/options.py b/flower/options.py index 083d4b5b6..a19be80be 100644 --- a/flower/options.py +++ b/flower/options.py @@ -68,6 +68,8 @@ define("url_prefix", type=str, help="base url prefix") define("task_runtime_metric_buckets", type=float, default=Histogram.DEFAULT_BUCKETS, multiple=True, help="histogram latency bucket value") +define("queue_cache_ttl", type=float, default=5.0, + help="TTL in seconds for caching broker queue stats (0 to disable)") default_options = options diff --git a/flower/utils/broker.py b/flower/utils/broker.py index f04208ac6..e042de98d 100644 --- a/flower/utils/broker.py +++ b/flower/utils/broker.py @@ -65,18 +65,18 @@ async def queues(self, names): try: response = await http_client.fetch( url, auth_username=username, auth_password=password, - connect_timeout=1.0, request_timeout=2.0, + connect_timeout=5.0, request_timeout=30.0, validate_cert=False) except (socket.error, httpclient.HTTPError) as e: logger.error("RabbitMQ management API call failed: %s", e) return [] - finally: - http_client.close() if response.code == 200: info = json.loads(response.body.decode()) - return [x for x in info if x['name'] in names] + names_set = frozenset(names) + return [x for x in info if x['name'] in names_set] response.rethrow() + return [] @classmethod def validate_http_api(cls, http_api): @@ -102,21 +102,53 @@ def __init__(self, broker_url, *_, **kwargs): self.sep = broker_options.get('sep', self.DEFAULT_SEP) self.broker_prefix = 
         self.broker_prefix = broker_options.get('global_keyprefix', '')
 
+    def close(self):
+        """Close the Redis connection and release resources."""
+        if self.redis is not None:
+            try:
+                if hasattr(self.redis, 'close'):
+                    self.redis.close()
+                elif hasattr(self.redis, 'connection_pool'):
+                    self.redis.connection_pool.disconnect()
+            except Exception:
+                logger.debug("Error closing Redis connection", exc_info=True)
+            self.redis = None
+
     def _q_for_pri(self, queue, pri):
         if pri not in self.priority_steps:
             raise ValueError('Priority not in priority steps')
         # pylint: disable=consider-using-f-string
         return '{0}{1}{2}'.format(*((queue, self.sep, pri) if pri else (queue, '', '')))
 
+    _PIPELINE_CHUNK_SIZE = 5000
+
     async def queues(self, names):
-        queue_stats = []
+        if not names:
+            return []
+
+        steps = len(self.priority_steps)
+
+        # Build all Redis key names upfront
+        all_keys = []
         for name in names:
-            priority_names = [self.broker_prefix + self._q_for_pri(
-                name, pri) for pri in self.priority_steps]
-            queue_stats.append({
-                'name': name,
-                'messages': sum((self.redis.llen(x) for x in priority_names))
-            })
+            for pri in self.priority_steps:
+                all_keys.append(self.broker_prefix + self._q_for_pri(name, pri))
+
+        # Execute pipelined LLEN in chunks to avoid overwhelming Redis
+        # with a single 400k-command pipeline for very large queue counts.
+        all_results = []
+        chunk_size = self._PIPELINE_CHUNK_SIZE
+        for start in range(0, len(all_keys), chunk_size):
+            pipe = self.redis.pipeline(transaction=False)
+            for key in all_keys[start:start + chunk_size]:
+                pipe.llen(key)
+            all_results.extend(pipe.execute())
+
+        queue_stats = []
+        for i, name in enumerate(names):
+            offset = i * steps
+            total = sum(all_results[offset:offset + steps])
+            queue_stats.append({'name': name, 'messages': total})
         return queue_stats
diff --git a/flower/views/__init__.py b/flower/views/__init__.py
index fbd80b016..29fd7b85d 100644
--- a/flower/views/__init__.py
+++ b/flower/views/__init__.py
@@ -13,6 +13,8 @@
 
 logger = logging.getLogger(__name__)
 
+_UNSET = object()
+
 
 class BaseHandler(tornado.web.RequestHandler):
     def set_default_headers(self):
@@ -91,8 +93,9 @@ def get_current_user(self):
                 return user
         return None
 
-    # pylint: disable=dangerous-default-value
-    def get_argument(self, name, default=[], strip=True, type=None):
+    def get_argument(self, name, default=_UNSET, strip=True, type=None):
+        if default is _UNSET:
+            default = []
         arg = super().get_argument(name, default, strip)
         if arg and isinstance(arg, str):
             arg = tornado.escape.xhtml_escape(arg)
diff --git a/flower/views/broker.py b/flower/views/broker.py
index 75f6c9b3f..33c970bdb 100644
--- a/flower/views/broker.py
+++ b/flower/views/broker.py
@@ -17,19 +17,35 @@ async def get(self):
         if app.transport == 'amqp' and app.options.broker_api:
             http_api = app.options.broker_api
 
-        try:
-            broker = Broker(app.capp.connection(connect_timeout=1.0).as_uri(include_password=True),
-                            http_api=http_api, broker_options=self.capp.conf.broker_transport_options,
-                            broker_use_ssl=self.capp.conf.broker_use_ssl)
-        except NotImplementedError as exc:
-            raise web.HTTPError(
-                404, f"'{app.transport}' broker is not supported") from exc
-
-        try:
-            queues = await broker.queues(self.get_active_queue_names())
-        except Exception as e:
-            logger.error("Unable to get queues: '%s'", e)
+        queue_names = self.get_active_queue_names()
+        names_key = frozenset(queue_names)
+
+        # Get broker URI once — reuse for both Broker creation and display
+        with app.capp.connection(connect_timeout=1.0) as conn:
+            broker_uri = conn.as_uri(include_password=True)
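+            # Two URIs on purpose: the password-bearing one is only handed to
+            # the Broker client, while the sanitized one below is what the
+            # template renders.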
+            broker_url = conn.as_uri()
+
+        # Check cache first
+        queues = app.get_cached_queue_stats(names_key)
+        if queues is None:
+            try:
+                broker = Broker(broker_uri,
+                                http_api=http_api, broker_options=self.capp.conf.broker_transport_options,
+                                broker_use_ssl=self.capp.conf.broker_use_ssl)
+            except NotImplementedError as exc:
+                raise web.HTTPError(
+                    404, f"'{app.transport}' broker is not supported") from exc
+
+            queues = []
+            try:
+                queues = await broker.queues(queue_names)
+                app.set_queue_cache(names_key, queues)
+            except Exception as e:
+                logger.error("Unable to get queues: '%s'", e)
+            finally:
+                if hasattr(broker, 'close'):
+                    broker.close()
 
         self.render("broker.html",
-                    broker_url=app.capp.connection().as_uri(),
+                    broker_url=broker_url,
                     queues=queues)
diff --git a/flower/views/workers.py b/flower/views/workers.py
index defd0469a..0fdb9287b 100644
--- a/flower/views/workers.py
+++ b/flower/views/workers.py
@@ -69,9 +69,11 @@ async def get(self):
         if json:
             self.write(dict(data=list(workers.values())))
         else:
+            with self.application.capp.connection() as conn:
+                broker_url = conn.as_uri()
             self.render("workers.html",
                         workers=workers,
-                        broker=self.application.capp.connection().as_uri(),
+                        broker=broker_url,
                         autorefresh=1 if self.application.options.auto_refresh else 0)
 
     @classmethod
diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py
new file mode 100644
index 000000000..2bfce1005
--- /dev/null
+++ b/tests/unit/test_app.py
@@ -0,0 +1,220 @@
+import time
+import unittest
+from unittest.mock import MagicMock, patch, PropertyMock
+
+import celery
+from celery.events import Event
+from celery.events.state import Worker
+from tornado.ioloop import IOLoop
+from tornado.options import options
+
+from flower import command  # noqa: F401 side effect - define options
+from flower.app import Flower
+from flower.events import Events, EventsState, get_prometheus_metrics
+from flower.urls import handlers, settings
+
+
+class TestQueueCache(unittest.TestCase):
+    def setUp(self):
+        capp = celery.Celery()
+        events = Events(capp, IOLoop.current())
+        self.app = Flower(capp=capp, events=events,
+                          options=options, handlers=handlers, **settings)
+        self.app._queue_cache_ttl = 5.0
+
+    def test_cache_miss_returns_none(self):
+        result = self.app.get_cached_queue_stats(frozenset(['q1', 'q2']))
+        self.assertIsNone(result)
+
+    def test_cache_hit_returns_copy(self):
+        names_key = frozenset(['q1', 'q2'])
+        data = [{'name': 'q1', 'messages': 5}, {'name': 'q2', 'messages': 10}]
+        self.app.set_queue_cache(names_key, data)
+
+        result = self.app.get_cached_queue_stats(names_key)
+        self.assertEqual(result, data)
+        self.assertIsNot(result, data)
+
+    def test_cache_returns_copy_to_prevent_mutation(self):
+        names_key = frozenset(['q1'])
+        data = [{'name': 'q1', 'messages': 5}]
+        self.app.set_queue_cache(names_key, data)
+
+        result = self.app.get_cached_queue_stats(names_key)
+        result.append({'name': 'q2', 'messages': 99})
+
+        result2 = self.app.get_cached_queue_stats(names_key)
+        self.assertEqual(len(result2), 1)
+
+    def test_cache_expires_after_ttl(self):
+        names_key = frozenset(['q1'])
+        data = [{'name': 'q1', 'messages': 5}]
+        self.app.set_queue_cache(names_key, data)
+
+        ts, key, result = self.app._queue_cache
+        self.app._queue_cache = (ts - 10.0, key, result)
+
+        self.assertIsNone(self.app.get_cached_queue_stats(names_key))
+
+    def test_cache_miss_on_different_names(self):
+        names_key = frozenset(['q1'])
+        data = [{'name': 'q1', 'messages': 5}]
+        self.app.set_queue_cache(names_key, data)
+
+        different_key = frozenset(['q1', 'q2'])
+        self.assertIsNone(self.app.get_cached_queue_stats(different_key))
+
+    def test_cache_disabled_when_ttl_zero(self):
+        self.app._queue_cache_ttl = 0
+        names_key = frozenset(['q1'])
+        data = [{'name': 'q1', 'messages': 5}]
+        self.app.set_queue_cache(names_key, data)
+
+        self.assertIsNone(self.app.get_cached_queue_stats(names_key))
+
+
+class TestPurgeOfflineWorkers(unittest.TestCase):
+    def setUp(self):
+        capp = celery.Celery()
+        events = Events(capp, IOLoop.current())
+        self.app = Flower(capp=capp, events=events,
+                          options=options, handlers=handlers, **settings)
+        self._orig_purge = options.purge_offline_workers
+
+    def tearDown(self):
+        options.purge_offline_workers = self._orig_purge
+
+    def test_purge_removes_offline_workers(self):
+        state = EventsState()
+        w, _ = state.get_or_create_worker('w1')
+        state.counter['w1']['worker-online'] = 1
+        w.heartbeats = [time.time() - 3600]
+        self.app.events.state = state
+
+        self.app.options.purge_offline_workers = 60
+        with patch.object(Worker, 'alive', new_callable=PropertyMock, return_value=False):
+            self.app._purge_offline_workers()
+
+        self.assertNotIn('w1', state.counter)
+
+    def test_purge_keeps_alive_workers(self):
+        state = EventsState()
+        w, _ = state.get_or_create_worker('w1')
+        state.counter['w1']['worker-online'] = 1
+        w.heartbeats = [time.time()]
+        self.app.events.state = state
+
+        self.app.options.purge_offline_workers = 60
+        with patch.object(Worker, 'alive', new_callable=PropertyMock, return_value=True):
+            self.app._purge_offline_workers()
+
+        self.assertIn('w1', state.counter)
+
+    def test_purge_keeps_recently_offline_workers(self):
+        state = EventsState()
+        w, _ = state.get_or_create_worker('w1')
+        state.counter['w1']['worker-online'] = 1
+        w.heartbeats = [time.time() - 10]  # 10 seconds ago
+        self.app.events.state = state
+
+        self.app.options.purge_offline_workers = 60  # threshold 60s
+        with patch.object(Worker, 'alive', new_callable=PropertyMock, return_value=False):
+            self.app._purge_offline_workers()
+
+        self.assertIn('w1', state.counter)
+
+    def test_purge_removes_orphaned_counter_entries(self):
+        state = EventsState()
+        state.counter['orphan_worker']['worker-online'] = 1
+        self.app.events.state = state
+
+        self.app.options.purge_offline_workers = 60
+        self.app._purge_offline_workers()
+
+        self.assertNotIn('orphan_worker', state.counter)
+
+    def test_purge_removes_orphaned_inspector_entries(self):
+        state = EventsState()
+        self.app.events.state = state
+        self.app.inspector.workers['orphan_worker'] = {'stats': {}}
+
+        self.app.options.purge_offline_workers = 60
+        self.app._purge_offline_workers()
+
+        self.assertNotIn('orphan_worker', self.app.inspector.workers)
+
+    def test_purge_noop_when_threshold_is_none(self):
+        state = EventsState()
+        state.counter['w1']['worker-online'] = 1
+        self.app.events.state = state
+
+        self.app.options.purge_offline_workers = None
+        self.app._purge_offline_workers()
+
+        self.assertIn('w1', state.counter)
+
+    def test_purge_cleans_prometheus_metrics(self):
+        state = EventsState()
+        w, _ = state.get_or_create_worker('test_purge_prom_w1')
+        state.counter['test_purge_prom_w1']['worker-online'] = 1
+        w.heartbeats = [time.time() - 3600]
+        metrics = get_prometheus_metrics()
+        metrics.worker_online.labels('test_purge_prom_w1').set(1)
+        self.app.events.state = state
+
+        self.app.options.purge_offline_workers = 60
+        with patch.object(Worker, 'alive', new_callable=PropertyMock, return_value=False):
+            self.app._purge_offline_workers()
+
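+        # prometheus_client keys each label series by its tuple of label
+        # values, hence the one-tuple ('test_purge_prom_w1',) checked below.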
+        self.assertNotIn(('test_purge_prom_w1',), metrics.worker_online._metrics)
+
+
+class TestFlowerStopSafety(unittest.TestCase):
+    def test_stop_continues_if_purge_timer_fails(self):
+        capp = celery.Celery()
+        events = Events(capp, IOLoop.current())
+        app = Flower(capp=capp, events=events,
+                     options=options, handlers=handlers, **settings)
+        app.started = True
+        app._purge_timer = MagicMock()
+        app._purge_timer.stop.side_effect = RuntimeError("timer error")
+        app.events = MagicMock()
+        app.executor = MagicMock()
+        app.io_loop = MagicMock()
+
+        app.stop()
+
+        app.executor.shutdown.assert_called_once()
+        app.io_loop.stop.assert_called_once()
+        self.assertFalse(app.started)
+
+    def test_stop_continues_if_events_stop_fails(self):
+        capp = celery.Celery()
+        events = Events(capp, IOLoop.current())
+        app = Flower(capp=capp, events=events,
+                     options=options, handlers=handlers, **settings)
+        app.started = True
+        app.events = MagicMock()
+        app.events.stop.side_effect = RuntimeError("events error")
+        app.executor = MagicMock()
+        app.io_loop = MagicMock()
+
+        app.stop()
+
+        app.executor.shutdown.assert_called_once()
+        app.io_loop.stop.assert_called_once()
+
+
+class TestTransportCaching(unittest.TestCase):
+    def test_transport_is_cached(self):
+        capp = celery.Celery()
+        events = Events(capp, IOLoop.current())
+        app = Flower(capp=capp, events=events,
+                     options=options, handlers=handlers, **settings)
+
+        app._transport = 'amqp'
+        self.assertEqual(app.transport, 'amqp')
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py
new file mode 100644
index 000000000..7a6775561
--- /dev/null
+++ b/tests/unit/test_events.py
@@ -0,0 +1,202 @@
+import queue
+import time
+import unittest
+from unittest.mock import MagicMock, patch
+
+from celery.events import Event
+from tornado.ioloop import IOLoop
+
+from flower.events import Events, EventsState, get_prometheus_metrics
+
+import celery
+
+
+class TestEventsState(unittest.TestCase):
+    def test_counter_tracks_events_by_worker(self):
+        state = EventsState()
+        state.get_or_create_worker('w1')
+        e = Event('worker-online', hostname='w1')
+        e['clock'] = 0
+        e['local_received'] = time.time()
+        state.event(e)
+
+        self.assertIn('w1', state.counter)
+        self.assertEqual(state.counter['w1']['worker-online'], 1)
+
+    def test_counter_increments(self):
+        state = EventsState()
+        state.get_or_create_worker('w1')
+        for i in range(5):
+            e = Event('worker-heartbeat', hostname='w1', active=0)
+            e['clock'] = i
+            e['local_received'] = time.time()
+            state.event(e)
+
+        self.assertEqual(state.counter['w1']['worker-heartbeat'], 5)
+
+
+class TestPrometheusMetricsRemoval(unittest.TestCase):
+    """Test remove_worker_metrics using the global singleton to avoid
+    duplicate registry errors from prometheus_client."""
+
+    def test_remove_worker_metrics_clears_labels(self):
+        metrics = get_prometheus_metrics()
+        metrics.worker_online.labels('test_remove_w1').set(1)
+        metrics.worker_online.labels('test_remove_w2').set(1)
+
+        self.assertIn(('test_remove_w1',), metrics.worker_online._metrics)
+
+        metrics.remove_worker_metrics('test_remove_w1')
+
+        self.assertNotIn(('test_remove_w1',), metrics.worker_online._metrics)
+        self.assertIn(('test_remove_w2',), metrics.worker_online._metrics)
+
+    def test_remove_nonexistent_worker_is_noop(self):
+        metrics = get_prometheus_metrics()
+        # Should not raise
+        metrics.remove_worker_metrics('test_remove_nonexistent_worker_xyz')
+
+    def test_remove_multi_label_metrics(self):
+        metrics = get_prometheus_metrics()
+        metrics.runtime.labels('test_remove_mw1', 'task1').observe(1.0)
+        metrics.runtime.labels('test_remove_mw1', 'task2').observe(2.0)
+        metrics.runtime.labels('test_remove_mw2', 'task1').observe(3.0)
+
+        metrics.remove_worker_metrics('test_remove_mw1')
+
+        remaining_keys = list(metrics.runtime._metrics.keys())
+        for key in remaining_keys:
+            self.assertNotEqual(key[0], 'test_remove_mw1')
+        self.assertIn(('test_remove_mw2', 'task1'), metrics.runtime._metrics)
+
+    def test_remove_handles_missing_private_attr(self):
+        metrics = get_prometheus_metrics()
+        # Temporarily remove _metrics to simulate missing attr
+        original = metrics.worker_online._metrics
+        try:
+            del metrics.worker_online._metrics
+            # Should not raise — getattr guard should catch it
+            metrics.remove_worker_metrics('w1')
+        finally:
+            metrics.worker_online._metrics = original
+
+
+class TestEventsBackpressure(unittest.TestCase):
+    def test_on_event_drops_when_queue_full(self):
+        capp = celery.Celery()
+        io_loop = MagicMock()
+        events = Events(capp, io_loop)
+        # Fill the queue
+        for i in range(events._BACKPRESSURE_MAXSIZE):
+            events.on_event({'hostname': 'w1', 'type': 'worker-heartbeat'})
+
+        # Next event should be dropped without raising
+        events.on_event({'hostname': 'w1', 'type': 'worker-heartbeat'})
+        self.assertEqual(events._event_queue.qsize(), events._BACKPRESSURE_MAXSIZE)
+
+    def test_drop_logging_is_rate_limited(self):
+        capp = celery.Celery()
+        io_loop = MagicMock()
+        events = Events(capp, io_loop)
+        # Fill the queue
+        for i in range(events._BACKPRESSURE_MAXSIZE):
+            events.on_event({'hostname': 'w1', 'type': 'worker-heartbeat'})
+
+        # Reset drop state so we control it entirely within the patch.
+        # Set _last_drop_log_time far enough in the past to guarantee the
+        # 5-second cooldown has elapsed (time.monotonic() can be small on
+        # short-lived processes).
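+        # (on_event only logs when now - _last_drop_log_time >= 5.0, so a
+        # value 10 seconds in the past guarantees the first drop below logs.)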
+        events._drop_count = 0
+        events._last_drop_log_time = time.monotonic() - 10.0
+
+        with patch('flower.events.logger') as mock_logger:
+            # First drop should trigger a log (cooldown elapsed)
+            events.on_event({'hostname': 'w1', 'type': 'worker-heartbeat'})
+            self.assertEqual(mock_logger.warning.call_count, 1)
+
+            # Subsequent drops within 5s should NOT trigger more logs
+            for _ in range(99):
+                events.on_event({'hostname': 'w1', 'type': 'worker-heartbeat'})
+            self.assertEqual(mock_logger.warning.call_count, 1)
+
+    def test_drain_events_processes_batch(self):
+        capp = celery.Celery()
+        io_loop = MagicMock()
+        events = Events(capp, io_loop)
+        events.state = MagicMock()
+
+        for i in range(10):
+            events._event_queue.put({'hostname': 'w1', 'type': 'worker-heartbeat',
+                                     'clock': i, 'local_received': time.time()})
+
+        events._drain_events()
+
+        self.assertEqual(events.state.event.call_count, 10)
+        self.assertTrue(events._event_queue.empty())
+
+    def test_drain_events_handles_errors_gracefully(self):
+        capp = celery.Celery()
+        io_loop = MagicMock()
+        events = Events(capp, io_loop)
+        events.state = MagicMock()
+        events.state.event.side_effect = [RuntimeError("test"), None]
+
+        events._event_queue.put({'hostname': 'w1', 'type': 'a'})
+        events._event_queue.put({'hostname': 'w1', 'type': 'b'})
+
+        events._drain_events()
+
+        # Both events should be consumed despite the error on the first one
+        self.assertEqual(events.state.event.call_count, 2)
+        self.assertTrue(events._event_queue.empty())
+
+    def test_drain_respects_batch_size(self):
+        capp = celery.Celery()
+        io_loop = MagicMock()
+        events = Events(capp, io_loop)
+        events.state = MagicMock()
+
+        count = events._DRAIN_BATCH_SIZE + 100
+        for i in range(count):
+            events._event_queue.put({'hostname': 'w1', 'type': 'hb'})
+
+        events._drain_events()
+
+        # Should process exactly _DRAIN_BATCH_SIZE, leaving 100
+        self.assertEqual(events.state.event.call_count, events._DRAIN_BATCH_SIZE)
+        self.assertEqual(events._event_queue.qsize(), 100)
+
+
+class TestEventsRetryBackoff(unittest.TestCase):
+    def test_retry_interval_caps_at_max(self):
+        from flower.events import MAX_RETRY_INTERVAL
+        try_interval = 1
+        for _ in range(100):
+            try_interval *= 2
+            if try_interval > MAX_RETRY_INTERVAL:
+                try_interval = MAX_RETRY_INTERVAL
+
+        self.assertEqual(try_interval, MAX_RETRY_INTERVAL)
+        self.assertEqual(MAX_RETRY_INTERVAL, 60)
+
+
+class TestEventsStopSafety(unittest.TestCase):
+    def test_stop_calls_save_state_even_if_timer_fails(self):
+        capp = celery.Celery()
+        io_loop = MagicMock()
+        events = Events(capp, io_loop, persistent=True, db='test_db')
+
+        events.timer = MagicMock()
+        events.timer.stop.side_effect = RuntimeError("timer error")
+        events.state_save_timer = MagicMock()
+        events.state_save_timer.stop.side_effect = RuntimeError("save timer error")
+        events._drain_timer = MagicMock()
+        events._drain_timer.stop.side_effect = RuntimeError("drain timer error")
+
+        with patch.object(events, 'save_state') as mock_save:
+            events.stop()
+            mock_save.assert_called_once()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/unit/test_inspector.py b/tests/unit/test_inspector.py
new file mode 100644
index 000000000..f929c95d5
--- /dev/null
+++ b/tests/unit/test_inspector.py
@@ -0,0 +1,32 @@
+import unittest
+from unittest.mock import MagicMock
+
+from flower.inspector import Inspector
+
+
+class TestInspectorPurgeWorker(unittest.TestCase):
+    def test_purge_existing_worker(self):
+        io_loop = MagicMock()
+        capp = MagicMock()
+        inspector = Inspector(io_loop, capp, timeout=1.0)
+
+        inspector.workers['w1'] = {'stats': {}, 'timestamp': 1000}
+        inspector.workers['w2'] = {'stats': {}, 'timestamp': 1000}
+
+        inspector.purge_worker('w1')
+
+        self.assertNotIn('w1', inspector.workers)
+        self.assertIn('w2', inspector.workers)
+
+    def test_purge_nonexistent_worker_is_noop(self):
+        io_loop = MagicMock()
+        capp = MagicMock()
+        inspector = Inspector(io_loop, capp, timeout=1.0)
+
+        # Should not raise
+        inspector.purge_worker('nonexistent')
+        self.assertEqual(len(inspector.workers), 0)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/unit/utils/test_broker_queues.py b/tests/unit/utils/test_broker_queues.py
new file mode 100644
index 000000000..5f43a5a3a
--- /dev/null
+++ b/tests/unit/utils/test_broker_queues.py
@@ -0,0 +1,142 @@
+import asyncio
+import unittest
+from unittest.mock import MagicMock, patch, call
+
+from flower.utils import broker
+from flower.utils.broker import RedisBase, Redis, RedisSentinel, RedisSocket
+
+# Ensure redis is mocked at module level like existing tests
+broker.redis = MagicMock()
+
+
+class TestRedisPipeline(unittest.TestCase):
+    """Test that Redis queue fetching uses pipelining."""
+
+    def _make_broker(self):
+        b = Redis('redis://localhost:6379/0')
+        b.redis = MagicMock()
+        return b
+
+    def test_queues_uses_pipeline(self):
+        b = self._make_broker()
+        mock_pipe = MagicMock()
+        mock_pipe.execute.return_value = [10, 0, 0, 0, 5, 0, 0, 0]
+        b.redis.pipeline.return_value = mock_pipe
+
+        result = asyncio.get_event_loop().run_until_complete(
+            b.queues(['q1', 'q2']))
+
+        # Should have created a pipeline
+        b.redis.pipeline.assert_called_once_with(transaction=False)
+        # Should have queued LLEN calls (4 priority steps × 2 queues = 8)
+        self.assertEqual(mock_pipe.llen.call_count, 8)
+        mock_pipe.execute.assert_called_once()
+
+        # Results should be summed per queue
+        self.assertEqual(result[0], {'name': 'q1', 'messages': 10})
+        self.assertEqual(result[1], {'name': 'q2', 'messages': 5})
+
+    def test_queues_empty_names(self):
+        b = self._make_broker()
+        result = asyncio.get_event_loop().run_until_complete(b.queues([]))
+        self.assertEqual(result, [])
+        b.redis.pipeline.assert_not_called()
+
+    def test_queues_sums_priority_steps(self):
+        b = self._make_broker()
+        mock_pipe = MagicMock()
+        # 4 priority steps for one queue, all with values
+        mock_pipe.execute.return_value = [10, 20, 30, 40]
+        b.redis.pipeline.return_value = mock_pipe
+
+        result = asyncio.get_event_loop().run_until_complete(b.queues(['q1']))
+        self.assertEqual(result[0]['messages'], 100)
+
+    def test_queues_chunked_for_large_counts(self):
+        """Verify pipeline batching kicks in for very large queue counts."""
+        b = self._make_broker()
+        b.priority_steps = [0]  # 1 step to simplify
+
+        chunk_size = b._PIPELINE_CHUNK_SIZE
+        num_queues = chunk_size + 10  # Just over one chunk
+        names = [f'q{i}' for i in range(num_queues)]
+
+        mock_pipe = MagicMock()
+        mock_pipe.execute.side_effect = [
+            [1] * chunk_size,  # First chunk
+            [2] * 10,  # Second chunk
+        ]
+        b.redis.pipeline.return_value = mock_pipe
+
+        result = asyncio.get_event_loop().run_until_complete(b.queues(names))
+
+        # Should have created 2 pipelines (one per chunk)
+        self.assertEqual(b.redis.pipeline.call_count, 2)
+        self.assertEqual(len(result), num_queues)
+        # First chunk queues have message count 1, second chunk have 2
+        self.assertEqual(result[0]['messages'], 1)
+        self.assertEqual(result[chunk_size]['messages'], 2)
+
+
+class TestRedisClose(unittest.TestCase):
+    def test_close_with_close_method(self):
+        b = Redis('redis://localhost:6379/0')
+        mock_redis = MagicMock()
+        mock_redis.close = MagicMock()
+        b.redis = mock_redis
+
+        b.close()
+
+        mock_redis.close.assert_called_once()
+        self.assertIsNone(b.redis)
+
+    def test_close_without_close_method(self):
+        """Test fallback to connection_pool.disconnect() for older redis-py."""
+        b = Redis('redis://localhost:6379/0')
+        # spec must list connection_pool: a spec'd MagicMock rejects setting
+        # (and reading) attributes outside its spec, and spec=[] would also
+        # raise on the assignment below. hasattr(mock, 'close') stays False.
+        mock_redis = MagicMock(spec=['connection_pool'])
+        mock_redis.connection_pool = MagicMock()
+        b.redis = mock_redis
+
+        b.close()
+        mock_redis.connection_pool.disconnect.assert_called_once()
+        self.assertIsNone(b.redis)
+
+    def test_close_already_none(self):
+        b = Redis('redis://localhost:6379/0')
+        b.redis = None
+        # Should not raise
+        b.close()
+        self.assertIsNone(b.redis)
+
+    def test_close_handles_exception(self):
+        b = Redis('redis://localhost:6379/0')
+        b.redis = MagicMock()
+        b.redis.close.side_effect = RuntimeError("connection error")
+
+        # Should not raise
+        b.close()
+        self.assertIsNone(b.redis)
+
+
+class TestRabbitMQOptimizations(unittest.TestCase):
+    def test_frozenset_filtering(self):
+        """Ensure set-based filtering works correctly."""
+        from flower.utils.broker import RabbitMQ
+        b = RabbitMQ('amqp://', '')
+
+        # Simulate what happens inside queues() after API response
+        info = [
+            {'name': 'q1', 'messages': 5},
+            {'name': 'q2', 'messages': 10},
+            {'name': 'q3', 'messages': 15},
+        ]
+        names = ['q1', 'q3']
+        names_set = frozenset(names)
+        result = [x for x in info if x['name'] in names_set]
+
+        self.assertEqual(len(result), 2)
+        self.assertEqual(result[0]['name'], 'q1')
+        self.assertEqual(result[1]['name'], 'q3')
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/unit/views/test_base_handler.py b/tests/unit/views/test_base_handler.py
new file mode 100644
index 000000000..27c5d4adc
--- /dev/null
+++ b/tests/unit/views/test_base_handler.py
@@ -0,0 +1,21 @@
+import unittest
+
+from flower.views import _UNSET
+
+
+class TestMutableDefaultFix(unittest.TestCase):
+    """Verify that the mutable default argument fix works correctly."""
+
+    def test_unset_sentinel_is_unique(self):
+        self.assertIsNotNone(_UNSET)
+        self.assertIsNot(_UNSET, [])
+        self.assertIsNot(_UNSET, None)
+
+    def test_sentinel_identity(self):
+        # Same object every time
+        from flower.views import _UNSET as unset2
+        self.assertIs(_UNSET, unset2)
+
+
+if __name__ == '__main__':
+    unittest.main()
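
A quick sketch of consuming the paginated queue-length endpoint added in
flower/api/tasks.py (the ``/api/queues/length`` route and port 5555 follow
Flower's defaults; ``requests`` is just an illustrative HTTP client here, not
a project dependency)::

    import requests

    def fetch_all_queue_lengths(base_url="http://localhost:5555", page_size=100):
        """Page through the queue-length API with the new limit/offset args."""
        queues, offset = [], 0
        while True:
            resp = requests.get(f"{base_url}/api/queues/length",
                                params={"limit": page_size, "offset": offset})
            resp.raise_for_status()
            payload = resp.json()  # {'active_queues': [...], 'total': N}
            queues.extend(payload["active_queues"])
            offset += page_size
            if offset >= payload["total"]:
                return queues

Because results are served from the per-process queue-stats cache for
``queue_cache_ttl`` seconds, successive pages of one listing are consistent
with each other as long as they are fetched within the TTL window.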