Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions bec_server/bec_server/scan_server/device_locking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import threading
from contextlib import contextmanager
from typing import Dict, Iterable

from bec_lib.logger import bec_logger

logger = bec_logger.logger


class DeviceLockManager:
"""
Manages locks for devices, identified simply as their name.
Allows acquiring multiple item locks atomically via a context manager.
"""

def __init__(self) -> None:
self._locks: Dict[str, threading.RLock] = {}
self._locks_guard = threading.RLock()

def _get_lock(self, key: str) -> threading.RLock:
"""
Get (or create) a lock for a given key.
"""
with self._locks_guard:
if key not in self._locks:
self._locks[key] = threading.RLock()
return self._locks[key]

@contextmanager
def lock(self, keys: Iterable[str], blocking: bool = True):
"""
Context manager to lock one or more items.
"""
keys = list(set(keys))
try:
if not self.acquire(*keys, blocking=blocking):
return
yield
finally:
self.release(*keys)

def acquire(self, *keys: str, blocking: bool = True):
logger.info(f"Locking devices: {keys}")
with self._locks_guard:
new_locks = []
for key in sorted(keys):
next_lock = self._get_lock(key)
if not next_lock.acquire(blocking=blocking):
[lock.release() for lock in new_locks]
return False
new_locks.append(next_lock)
return True

def release(self, *keys: str):
logger.info(f"Releasing devices: {keys}")
with self._locks_guard:
for key in reversed(sorted(keys)):
self._get_lock(key).release()
5 changes: 4 additions & 1 deletion bec_server/bec_server/scan_server/scan_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@

from typing import TYPE_CHECKING

from bec_lib import messages
from bec_lib.alarm_handler import Alarms
from bec_lib.bec_service import BECService
from bec_lib.devicemanager import DeviceManagerBase as DeviceManager
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from bec_lib.scan_number_container import ScanNumberContainer
from bec_lib.service_config import ServiceConfig

from bec_lib import messages
from bec_server.procedures.container_utils import podman_available
from bec_server.procedures.container_worker import ContainerProcedureWorker
from bec_server.procedures.manager import ProcedureManager
from bec_server.procedures.subprocess_worker import SubProcessWorker
from bec_server.scan_server.device_locking import DeviceLockManager

from .scan_assembler import ScanAssembler
from .scan_guard import ScanGuard
Expand All @@ -36,6 +38,7 @@ class ScanServer(BECService):

def __init__(self, config: ServiceConfig, connector_cls: type[RedisConnector]):
super().__init__(config, connector_cls, unique_service=True)
self.device_locks = DeviceLockManager()
self._start_scan_manager()
self._start_device_manager()
self._start_queue_manager()
Expand Down
97 changes: 53 additions & 44 deletions bec_server/bec_server/scan_server/scan_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from bec_lib.endpoints import MessageEndpoints
from bec_lib.file_utils import compile_file_components
from bec_lib.logger import bec_logger
from bec_server.scan_server.scans import RequestBase

from .errors import DeviceInstructionError, ScanAbortion
from .scan_queue import InstructionQueueItem, InstructionQueueStatus, RequestBlock
Expand Down Expand Up @@ -398,25 +399,57 @@ def _process_instructions(self, queue: InstructionQueueItem) -> None:
self._wait_for_device_server()

queue.is_active = True
try:
for instr in queue:
self._check_for_interruption()
if instr is None:
continue
self._exposure_time = getattr(queue.active_request_block.scan, "exp_time", None)
self._instruction_step(instr)
except ScanAbortion as exc:
if queue.stopped or not (queue.return_to_start and queue.active_request_block):
raise ScanAbortion from exc
queue.stopped = True
scan_instance: RequestBase | None
if (scan_instance := getattr(queue.active_request_block, "scan", None)) is None:
devices_to_lock = []
else:
devices_to_lock = scan_instance.instance_device_access().device_locking
with self.parent.device_locks.lock(devices_to_lock):
Comment thread
d-perl marked this conversation as resolved.
Outdated
try:
cleanup = queue.active_request_block.scan.move_to_start()
self.status = InstructionQueueStatus.RUNNING
for instr in cleanup:
for instr in queue:
self._check_for_interruption()
instr.metadata["scan_id"] = queue.queue.active_rb.scan_id
instr.metadata["queue_id"] = queue.queue_id
if instr is None:
continue
self._exposure_time = getattr(queue.active_request_block.scan, "exp_time", None)
self._instruction_step(instr)
except ScanAbortion as exc:
if queue.stopped or not (queue.return_to_start and queue.active_request_block):
raise ScanAbortion from exc
queue.stopped = True
try:
cleanup = queue.active_request_block.scan.move_to_start()
self.status = InstructionQueueStatus.RUNNING
for instr in cleanup:
self._check_for_interruption()
instr.metadata["scan_id"] = queue.queue.active_rb.scan_id
instr.metadata["queue_id"] = queue.queue_id
self._instruction_step(instr)
except DeviceInstructionError as exc_di:
content = traceback.format_exc()
logger.error(content)
self.connector.raise_alarm(
severity=Alarms.MAJOR,
info=exc_di.error_info,
metadata=self._get_metadata_for_alarm(),
)
raise ScanAbortion from exc_di
except Exception as exc_return_to_start:
# if the return_to_start fails, raise the original exception
content = traceback.format_exc()
logger.error(content)
error_info = messages.ErrorInfo(
error_message=content,
compact_error_message=traceback.format_exc(limit=0),
exception_type=exc_return_to_start.__class__.__name__,
device=None,
)
self.connector.raise_alarm(
severity=Alarms.MAJOR,
info=error_info,
metadata=self._get_metadata_for_alarm(),
)
raise ScanAbortion from exc
raise ScanAbortion from exc
except DeviceInstructionError as exc_di:
content = traceback.format_exc()
logger.error(content)
Expand All @@ -425,46 +458,22 @@ def _process_instructions(self, queue: InstructionQueueItem) -> None:
info=exc_di.error_info,
metadata=self._get_metadata_for_alarm(),
)

raise ScanAbortion from exc_di
except Exception as exc_return_to_start:
# if the return_to_start fails, raise the original exception
except Exception as exc:
content = traceback.format_exc()
logger.error(content)
error_info = messages.ErrorInfo(
error_message=content,
compact_error_message=traceback.format_exc(limit=0),
exception_type=exc_return_to_start.__class__.__name__,
exception_type=exc.__class__.__name__,
device=None,
)
self.connector.raise_alarm(
severity=Alarms.MAJOR, info=error_info, metadata=self._get_metadata_for_alarm()
)
raise ScanAbortion from exc
raise ScanAbortion from exc
except DeviceInstructionError as exc_di:
content = traceback.format_exc()
logger.error(content)
self.connector.raise_alarm(
severity=Alarms.MAJOR,
info=exc_di.error_info,
metadata=self._get_metadata_for_alarm(),
)

raise ScanAbortion from exc_di
except Exception as exc:
content = traceback.format_exc()
logger.error(content)
error_info = messages.ErrorInfo(
error_message=content,
compact_error_message=traceback.format_exc(limit=0),
exception_type=exc.__class__.__name__,
device=None,
)
self.connector.raise_alarm(
severity=Alarms.MAJOR, info=error_info, metadata=self._get_metadata_for_alarm()
)

raise ScanAbortion from exc
raise ScanAbortion from exc
queue.is_active = False
queue.status = InstructionQueueStatus.COMPLETED
self.current_instruction_queue_item = None
Expand Down
31 changes: 27 additions & 4 deletions bec_server/bec_server/scan_server/scans.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from typing import Any, Literal

import numpy as np

from bec_lib import messages
from bec_lib.alarm_handler import Alarms
from bec_lib.device import DeviceBase
from bec_lib.devicemanager import DeviceManagerBase
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from pydantic import BaseModel

from bec_lib import messages
from bec_server.scan_server.instruction_handler import InstructionHandler

from .errors import LimitError, ScanAbortion
Expand Down Expand Up @@ -245,7 +246,7 @@ class RequestBase(ABC):
"""

scan_name = ""
arg_input = {}
arg_input: dict[str, ScanArgType] = {}
arg_bundle_size = {"bundle": len(arg_input), "min": None, "max": None}
gui_args = {}
required_kwargs = []
Expand Down Expand Up @@ -376,6 +377,28 @@ def update_readout_priority(self):
def run(self):
pass

@classmethod
def device_access(cls, scan_parameters: dict) -> ScanDeviceAccessList:
"""Provide the devices for which permissions and locking are needed for this scan, with the given parameter set."""
arg_devices = set(scan_parameters.get("args", {}).keys())
param_kwargs = scan_parameters.get("kwargs", {})
kwarg_devices = set()
for arg, T in cls.arg_input.items():
if T == ScanArgType.DEVICE and arg in param_kwargs:
kwarg_devices.add(str(param_kwargs[arg]))
devices_used_in_scan = arg_devices | kwarg_devices
return ScanDeviceAccessList(
device_permissions=devices_used_in_scan, device_locking=devices_used_in_scan
)

def instance_device_access(self) -> ScanDeviceAccessList:
return self.device_access(self.parameter)


class ScanDeviceAccessList(BaseModel):
device_permissions: set[str]
device_locking: set[str]


class ScanBase(RequestBase, PathOptimizerMixin):
"""
Expand Down Expand Up @@ -403,7 +426,7 @@ class ScanBase(RequestBase, PathOptimizerMixin):
Attributes:
scan_name (str): name of the scan
scan_type (str): scan type. Can be "step" or "fly"
arg_input (list): list of scan argument types
arg_input (dict[str, ScanArgType]): list of scan argument types
arg_bundle_size (dict):
- bundle: number of arguments that are bundled together
- min: minimum number of bundles
Expand Down
21 changes: 21 additions & 0 deletions bec_server/tests/tests_scan_server/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Callable, Generator

import fakeredis
import pytest

from bec_lib.logger import bec_logger
from bec_lib.redis_connector import RedisConnector
from bec_server.scan_server.scan_queue import QueueManager
from bec_server.scan_server.tests.fixtures import scan_server_mock

# overwrite threads_check fixture from bec_lib,
# to have it in autouse
Expand All @@ -27,3 +31,20 @@ def connected_connector():
yield connector
finally:
connector.shutdown()


@pytest.fixture
def queuemanager_mock(scan_server_mock):
def _get_queuemanager(queues=None) -> QueueManager:
scan_server = scan_server_mock
if queues is None:
queues = ["primary"]
if isinstance(queues, str):
queues = [queues]
for queue in queues:
scan_server.queue_manager.add_queue(queue)
return scan_server.queue_manager

yield _get_queuemanager

scan_server_mock.queue_manager.shutdown()
38 changes: 38 additions & 0 deletions bec_server/tests/tests_scan_server/test_device_locking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Callable

import pytest
from bec_server.scan_server.scan_queue import QueueManager

from bec_lib import messages


@pytest.fixture
def qm_with_3_qs_and_lock_man(queuemanager_mock: Callable[..., QueueManager]):
queue_manager = queuemanager_mock(["1", "2", "3"])
yield queue_manager, queue_manager.parent.device_locks


def _linescan_msg(dev: str, start: float, stop: float):
return messages.ScanQueueMessage(
scan_type="line_scan",
parameter={"args": {dev: (start, stop)}, "kwargs": {}},
queue="primary",
metadata={"RID": "something"},
)


def test_devices_from_instance(queuemanager_mock):
q_manager = queuemanager_mock()
assembler = q_manager.parent.scan_assembler
scan_instance = assembler.assemble_device_instructions(_linescan_msg("samx", -1, 1), "test")
device_access = scan_instance.instance_device_access()
assert device_access.device_locking == set(("samx",))


def test_queuemanager_add_to_queue_restarts_queue_if_worker_is_dead(qm_with_3_qs_and_lock_man):
queue_manager, locks = qm_with_3_qs_and_lock_man
msg = _linescan_msg("samx", -5, 5)

queue_manager.add_to_queue(scan_queue="1", msg=msg)

...
18 changes: 0 additions & 18 deletions bec_server/tests/tests_scan_server/test_scan_server_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,12 @@
ScanQueueStatus,
)
from bec_server.scan_server.scan_worker import ScanWorker
from bec_server.scan_server.tests.fixtures import scan_server_mock

# pylint: disable=missing-function-docstring
# pylint: disable=protected-access
ScanQueue.AUTO_SHUTDOWN_TIME = 1 # Reduce auto-shutdown time for testing


@pytest.fixture
def queuemanager_mock(scan_server_mock) -> QueueManager:
def _get_queuemanager(queues=None):
scan_server = scan_server_mock
if queues is None:
queues = ["primary"]
if isinstance(queues, str):
queues = [queues]
for queue in queues:
scan_server.queue_manager.add_queue(queue)
return scan_server.queue_manager

yield _get_queuemanager

scan_server_mock.queue_manager.shutdown()


class RequestBlockQueueMock(RequestBlockQueue):
request_blocks = []
_scan_id = []
Expand Down
Loading