diff --git a/docs/developer/design/message-flow-and-transformation.md b/docs/developer/design/message-flow-and-transformation.md index 271d07c8d..90fafa64d 100644 --- a/docs/developer/design/message-flow-and-transformation.md +++ b/docs/developer/design/message-flow-and-transformation.md @@ -111,7 +111,7 @@ Isolates ESSlivedata from Kafka topic structure: Kafka uses `(topic, source_name ### Adapter Pattern -`MessageAdapter` protocol: `adapt(message: T) -> U`. Benefits: composable, type-safe, testable, reusable. +`MessageAdapter` protocol: `adapt(message: T) -> Sequence[U]`. Each adapter returns a sequence, enabling 1:N message expansion (e.g., splitting multi-pulse ev44 messages into one message per pulse). Most adapters return a single-element tuple. Benefits: composable, type-safe, testable, reusable. ### Core Adapters @@ -127,7 +127,7 @@ Isolates ESSlivedata from Kafka topic structure: Kafka uses `(topic, source_name ### Adapter Composition -**ChainedAdapter**: Chains two adapters sequentially (`second.adapt(first.adapt(message))`). +**ChainedAdapter**: Chains two adapters with flatmap semantics — for each intermediate result from the first adapter, all results from the second adapter are collected into a flat sequence. **RouteByTopicAdapter**: Routes by Kafka topic to different adapters. Provides `.topics` list for subscription. diff --git a/src/ess/livedata/handlers/to_nxevent_data.py b/src/ess/livedata/handlers/to_nxevent_data.py index 7d9c4110f..98d0490db 100644 --- a/src/ess/livedata/handlers/to_nxevent_data.py +++ b/src/ess/livedata/handlers/to_nxevent_data.py @@ -13,10 +13,33 @@ from ess.livedata.core.handler import Accumulator -def _require_single_pulse(ev44: eventdata_ev44.EventData) -> None: +def split_ev44_pulses( + ev44: eventdata_ev44.EventData, +) -> list[eventdata_ev44.EventData]: + """Split a multi-pulse ev44 message into one EventData per pulse. + + For single-pulse messages the input is returned as-is (no copy). + """ + n_pulses = len(ev44.reference_time) + if n_pulses <= 1: + return [ev44] index = ev44.reference_time_index - if len(index) > 1 or index[0] != 0 or len(ev44.reference_time) > 1: - raise NotImplementedError("Processing multi-pulse messages is not supported.") + n_events = len(ev44.time_of_flight) + pulses: list[eventdata_ev44.EventData] = [] + for i in range(n_pulses): + start = index[i] + end = index[i + 1] if i + 1 < len(index) else n_events + pulses.append( + eventdata_ev44.EventData( + source_name=ev44.source_name, + message_id=ev44.message_id, + reference_time=ev44.reference_time[i : i + 1], + reference_time_index=np.array([0]), + time_of_flight=ev44.time_of_flight[start:end], + pixel_id=ev44.pixel_id[start:end] if ev44.pixel_id is not None else [], + ) + ) + return pulses @dataclass @@ -36,7 +59,6 @@ class MonitorEvents: @staticmethod def from_ev44(ev44: eventdata_ev44.EventData) -> MonitorEvents: - _require_single_pulse(ev44) return MonitorEvents(time_of_arrival=ev44.time_of_flight, unit='ns') @@ -63,7 +85,6 @@ def __post_init__(self) -> None: @staticmethod def from_ev44(ev44: eventdata_ev44.EventData) -> DetectorEvents: - _require_single_pulse(ev44) return DetectorEvents( pixel_id=ev44.pixel_id, time_of_arrival=ev44.time_of_flight, unit='ns' ) diff --git a/src/ess/livedata/kafka/message_adapter.py b/src/ess/livedata/kafka/message_adapter.py index bb5aed03e..831231a57 100644 --- a/src/ess/livedata/kafka/message_adapter.py +++ b/src/ess/livedata/kafka/message_adapter.py @@ -88,7 +88,7 @@ def __eq__(self, other: object) -> bool: class MessageAdapter(Protocol, Generic[T, U]): - def adapt(self, message: T) -> U: ... + def adapt(self, message: T) -> Sequence[U]: ... class KafkaAdapter(MessageAdapter[KafkaMessage, Message[T]]): @@ -114,37 +114,53 @@ def get_stream_id(self, topic: str, source_name: str) -> StreamId: class KafkaToEv44Adapter(KafkaAdapter[eventdata_ev44.EventData]): - def adapt(self, message: KafkaMessage) -> Message[eventdata_ev44.EventData]: + def adapt( + self, message: KafkaMessage + ) -> Sequence[Message[eventdata_ev44.EventData]]: + from ..handlers.to_nxevent_data import split_ev44_pulses + ev44 = eventdata_ev44.deserialise_ev44(message.value()) stream = self.get_stream_id(topic=message.topic(), source_name=ev44.source_name) - # A fallback, useful in particular for testing so serialized data can be reused. - if ev44.reference_time.size > 0: - timestamp = ev44.reference_time[-1] - else: - timestamp = message.timestamp()[1] - return Message(timestamp=timestamp, stream=stream, value=ev44) + fallback_ts = message.timestamp()[1] + pulses = split_ev44_pulses(ev44) + return tuple( + Message( + timestamp=( + int(pulse.reference_time[0]) + if len(pulse.reference_time) > 0 + else fallback_ts + ), + stream=stream, + value=pulse, + ) + for pulse in pulses + ) class KafkaToDa00Adapter(KafkaAdapter[list[dataarray_da00.Variable]]): - def adapt(self, message: KafkaMessage) -> Message[list[dataarray_da00.Variable]]: + def adapt( + self, message: KafkaMessage + ) -> Sequence[Message[list[dataarray_da00.Variable]]]: da00: dataarray_da00.da00_DataArray_t da00 = dataarray_da00.deserialise_da00(message.value()) # type: ignore[reportAssignmentType] key = self.get_stream_id(topic=message.topic(), source_name=da00.source_name) timestamp = da00.timestamp_ns - return Message(timestamp=timestamp, stream=key, value=da00.data) + return (Message(timestamp=timestamp, stream=key, value=da00.data),) class KafkaToF144Adapter(KafkaAdapter[logdata_f144.ExtractedLogData]): def __init__(self, *, stream_lut: StreamLUT | None = None): super().__init__(stream_lut=stream_lut, stream_kind=StreamKind.LOG) - def adapt(self, message: KafkaMessage) -> Message[logdata_f144.ExtractedLogData]: + def adapt( + self, message: KafkaMessage + ) -> Sequence[Message[logdata_f144.ExtractedLogData]]: log_data = logdata_f144.deserialise_f144(message.value()) key = self.get_stream_id( topic=message.topic(), source_name=log_data.source_name ) timestamp = log_data.timestamp_unix_ns - return Message(timestamp=timestamp, stream=key, value=log_data) + return (Message(timestamp=timestamp, stream=key, value=log_data),) class F144ToLogDataAdapter( @@ -152,11 +168,13 @@ class F144ToLogDataAdapter( ): def adapt( self, message: Message[logdata_f144.ExtractedLogData] - ) -> Message[LogData]: - return Message( - timestamp=message.timestamp, - stream=message.stream, - value=LogData.from_f144(message.value), + ) -> Sequence[Message[LogData]]: + return ( + Message( + timestamp=message.timestamp, + stream=message.stream, + value=LogData.from_f144(message.value), + ), ) @@ -165,11 +183,13 @@ class Ev44ToMonitorEventsAdapter( ): def adapt( self, message: Message[eventdata_ev44.EventData] - ) -> Message[MonitorEvents]: - return Message( - timestamp=message.timestamp, - stream=message.stream, - value=MonitorEvents.from_ev44(message.value), + ) -> Sequence[Message[MonitorEvents]]: + return ( + Message( + timestamp=message.timestamp, + stream=message.stream, + value=MonitorEvents.from_ev44(message.value), + ), ) @@ -182,11 +202,15 @@ class X5f2ToStatusAdapter( Discriminates based on the `message_type` field in the x5f2 status_json. """ - def adapt(self, message: KafkaMessage) -> Message[JobStatus | ServiceStatus]: - return Message( - timestamp=message.timestamp()[1], - stream=STATUS_STREAM_ID, - value=x5f2_to_status(message.value()), + def adapt( + self, message: KafkaMessage + ) -> Sequence[Message[JobStatus | ServiceStatus]]: + return ( + Message( + timestamp=message.timestamp()[1], + stream=STATUS_STREAM_ID, + value=x5f2_to_status(message.value()), + ), ) @@ -202,7 +226,7 @@ class KafkaToMonitorEventsAdapter(KafkaAdapter[MonitorEvents]): def __init__(self, stream_lut: StreamLUT): super().__init__(stream_lut=stream_lut, stream_kind=StreamKind.MONITOR_EVENTS) - def adapt(self, message: KafkaMessage) -> Message[MonitorEvents]: + def adapt(self, message: KafkaMessage) -> Sequence[Message[MonitorEvents]]: buffer = message.value() eventdata_ev44.check_schema_identifier(buffer, eventdata_ev44.FILE_IDENTIFIER) event = Event44Message.Event44Message.GetRootAs(buffer, 0) @@ -212,16 +236,32 @@ def adapt(self, message: KafkaMessage) -> Message[MonitorEvents]: reference_time = event.ReferenceTimeAsNumpy() time_of_arrival = event.TimeOfFlightAsNumpy() - # A fallback, useful in particular for testing so serialized data can be reused. - if reference_time.size > 0: - timestamp = reference_time[-1] - else: - timestamp = message.timestamp()[1] - return Message( - timestamp=timestamp, - stream=stream, - value=MonitorEvents(time_of_arrival=time_of_arrival, unit='ns'), - ) + n_pulses = reference_time.size + if n_pulses <= 1: + timestamp = reference_time[0] if n_pulses == 1 else message.timestamp()[1] + return ( + Message( + timestamp=timestamp, + stream=stream, + value=MonitorEvents(time_of_arrival=time_of_arrival, unit='ns'), + ), + ) + ref_index = event.ReferenceTimeIndexAsNumpy() + n_events = len(time_of_arrival) + results: list[Message[MonitorEvents]] = [] + for i in range(n_pulses): + start = ref_index[i] + end = ref_index[i + 1] if i + 1 < len(ref_index) else n_events + results.append( + Message( + timestamp=int(reference_time[i]), + stream=stream, + value=MonitorEvents( + time_of_arrival=time_of_arrival[start:end], unit='ns' + ), + ) + ) + return results class Ev44ToDetectorEventsAdapter( @@ -241,14 +281,16 @@ def __init__(self, *, merge_detectors: bool = False): def adapt( self, message: Message[eventdata_ev44.EventData] - ) -> Message[DetectorEvents]: + ) -> Sequence[Message[DetectorEvents]]: stream = message.stream if self._merge_detectors: stream = replace(stream, name='unified_detector') - return Message( - timestamp=message.timestamp, - stream=stream, - value=DetectorEvents.from_ev44(message.value), + return ( + Message( + timestamp=message.timestamp, + stream=stream, + value=DetectorEvents.from_ev44(message.value), + ), ) @@ -257,20 +299,24 @@ class Da00ToScippAdapter( ): def adapt( self, message: Message[list[dataarray_da00.Variable]] - ) -> Message[sc.DataArray]: - return Message( - timestamp=message.timestamp, - stream=message.stream, - value=da00_to_scipp(message.value), + ) -> Sequence[Message[sc.DataArray]]: + return ( + Message( + timestamp=message.timestamp, + stream=message.stream, + value=da00_to_scipp(message.value), + ), ) class KafkaToAd00Adapter(KafkaAdapter[area_detector_ad00.ADArray]): - def adapt(self, message: KafkaMessage) -> Message[area_detector_ad00.ADArray]: + def adapt( + self, message: KafkaMessage + ) -> Sequence[Message[area_detector_ad00.ADArray]]: ad00 = area_detector_ad00.deserialise_ad00(message.value()) key = self.get_stream_id(topic=message.topic(), source_name=ad00.source_name) timestamp = ad00.timestamp_ns - return Message(timestamp=timestamp, stream=key, value=ad00) + return (Message(timestamp=timestamp, stream=key, value=ad00),) class Ad00ToScippAdapter( @@ -278,11 +324,13 @@ class Ad00ToScippAdapter( ): def adapt( self, message: Message[area_detector_ad00.ADArray] - ) -> Message[sc.DataArray]: - return Message( - timestamp=message.timestamp, - stream=message.stream, - value=ad00_to_scipp(message.value), + ) -> Sequence[Message[sc.DataArray]]: + return ( + Message( + timestamp=message.timestamp, + stream=message.stream, + value=ad00_to_scipp(message.value), + ), ) @@ -295,21 +343,21 @@ class RawConfigItem: class CommandsAdapter(MessageAdapter[KafkaMessage, Message[RawConfigItem]]): """Adapts Kafka messages from the livedata commands topic.""" - def adapt(self, message: KafkaMessage) -> Message[RawConfigItem]: + def adapt(self, message: KafkaMessage) -> Sequence[Message[RawConfigItem]]: timestamp = message.timestamp()[1] # Livedata configuration uses a compacted Kafka topic. The Kafka message key # is the encoded string representation of a :py:class:`ConfigKey` object. item = RawConfigItem(key=message.key(), value=message.value()) - return Message(stream=COMMANDS_STREAM_ID, timestamp=timestamp, value=item) + return (Message(stream=COMMANDS_STREAM_ID, timestamp=timestamp, value=item),) class ResponsesAdapter(MessageAdapter[KafkaMessage, Message[CommandAcknowledgement]]): """Adapts Kafka messages from the livedata responses topic.""" - def adapt(self, message: KafkaMessage) -> Message[CommandAcknowledgement]: + def adapt(self, message: KafkaMessage) -> Sequence[Message[CommandAcknowledgement]]: timestamp = message.timestamp()[1] ack = CommandAcknowledgement.model_validate_json(message.value()) - return Message(stream=RESPONSES_STREAM_ID, timestamp=timestamp, value=ack) + return (Message(stream=RESPONSES_STREAM_ID, timestamp=timestamp, value=ack),) class ChainedAdapter(MessageAdapter[T, V]): @@ -321,9 +369,12 @@ def __init__(self, first: MessageAdapter[T, U], second: MessageAdapter[U, V]): self._first = first self._second = second - def adapt(self, message: T) -> V: - intermediate = self._first.adapt(message) - return self._second.adapt(intermediate) + def adapt(self, message: T) -> Sequence[V]: + return [ + result + for intermediate in self._first.adapt(message) + for result in self._second.adapt(intermediate) + ] class RouteBySchemaAdapter(MessageAdapter[KafkaMessage, T]): @@ -334,7 +385,7 @@ class RouteBySchemaAdapter(MessageAdapter[KafkaMessage, T]): def __init__(self, routes: dict[str, MessageAdapter[KafkaMessage, T]]): self._routes = routes - def adapt(self, message: KafkaMessage) -> T: + def adapt(self, message: KafkaMessage) -> Sequence[T]: schema = streaming_data_types.utils.get_schema(message.value()) if schema is None: raise streaming_data_types.exceptions.WrongSchemaException( @@ -362,7 +413,7 @@ def topics(self) -> list[str]: """Returns the list of topics to subscribe to.""" return list(self._routes.keys()) - def adapt(self, message: KafkaMessage) -> T: + def adapt(self, message: KafkaMessage) -> Sequence[T]: topic = message.topic() if topic not in self._routes: raise KeyError( @@ -403,7 +454,7 @@ def get_messages(self) -> Sequence[U]: adapted = [] for msg in raw_messages: try: - adapted.append(self._adapter.adapt(msg)) + adapted.extend(self._adapter.adapt(msg)) except streaming_data_types.exceptions.WrongSchemaException: logger.warning('Message %s has an unknown schema. Skipping.', msg) if self._raise_on_error: diff --git a/src/ess/livedata/services/fake_monitors.py b/src/ess/livedata/services/fake_monitors.py index c866156de..96cff3293 100644 --- a/src/ess/livedata/services/fake_monitors.py +++ b/src/ess/livedata/services/fake_monitors.py @@ -90,11 +90,13 @@ class EventsToHistogramAdapter( def __init__(self, toa: sc.Variable): self._toa = toa - def adapt(self, message: Message[sc.Variable]) -> Message[sc.DataArray]: - return replace( - message, - stream=replace(message.stream, kind=StreamKind.MONITOR_COUNTS), - value=message.value.hist({self._toa.dim: self._toa}), + def adapt(self, message: Message[sc.Variable]) -> tuple[Message[sc.DataArray], ...]: + return ( + replace( + message, + stream=replace(message.stream, kind=StreamKind.MONITOR_COUNTS), + value=message.value.hist({self._toa.dim: self._toa}), + ), ) diff --git a/tests/handlers/to_nxevent_data_test.py b/tests/handlers/to_nxevent_data_test.py index 0739ca2bb..865b74c88 100644 --- a/tests/handlers/to_nxevent_data_test.py +++ b/tests/handlers/to_nxevent_data_test.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +import numpy as np import pytest import scipp as sc from scipp.testing import assert_identical @@ -9,6 +10,7 @@ DetectorEvents, MonitorEvents, ToNXevent_data, + split_ev44_pulses, ) @@ -26,20 +28,57 @@ def test_MonitorEvents_from_ev44() -> None: assert monitor_events.unit == 'ns' -@pytest.mark.parametrize('events_cls', [MonitorEvents, DetectorEvents]) -def test_MonitorEvents_from_ev44_raises_with_multi_pulse_message( - events_cls: MonitorEvents, -) -> None: +def test_split_ev44_pulses_single_pulse_returns_input_as_is() -> None: ev44 = eventdata_ev44.EventData( - source_name='ignored', + source_name='src', message_id=0, - reference_time=[1, 2], - reference_time_index=[0, 1], - pixel_id=[1, 1, 1], - time_of_flight=[1, 2, 3], + reference_time=[100], + reference_time_index=[0], + time_of_flight=[10, 20, 30], + pixel_id=[1, 2, 3], + ) + result = split_ev44_pulses(ev44) + assert len(result) == 1 + assert result[0] is ev44 + + +def test_split_ev44_pulses_empty_reference_time() -> None: + ev44 = eventdata_ev44.EventData( + source_name='src', + message_id=0, + reference_time=[], + reference_time_index=[0], + time_of_flight=[10, 20], + pixel_id=[1, 2], ) - with pytest.raises(NotImplementedError): - events_cls.from_ev44(ev44) + result = split_ev44_pulses(ev44) + assert len(result) == 1 + assert result[0] is ev44 + + +def test_split_ev44_pulses_multi_pulse() -> None: + ev44 = eventdata_ev44.EventData( + source_name='src', + message_id=7, + reference_time=np.array([100, 200, 300]), + reference_time_index=np.array([0, 2, 5]), + time_of_flight=np.array([10, 20, 30, 40, 50, 60, 70]), + pixel_id=np.array([1, 2, 3, 4, 5, 6, 7]), + ) + pulses = split_ev44_pulses(ev44) + assert len(pulses) == 3 + + np.testing.assert_array_equal(pulses[0].reference_time, [100]) + np.testing.assert_array_equal(pulses[0].time_of_flight, [10, 20]) + np.testing.assert_array_equal(pulses[0].pixel_id, [1, 2]) + + np.testing.assert_array_equal(pulses[1].reference_time, [200]) + np.testing.assert_array_equal(pulses[1].time_of_flight, [30, 40, 50]) + np.testing.assert_array_equal(pulses[1].pixel_id, [3, 4, 5]) + + np.testing.assert_array_equal(pulses[2].reference_time, [300]) + np.testing.assert_array_equal(pulses[2].time_of_flight, [60, 70]) + np.testing.assert_array_equal(pulses[2].pixel_id, [6, 7]) def test_MonitorEvents_ToNXevent_data() -> None: diff --git a/tests/kafka/message_adapter_test.py b/tests/kafka/message_adapter_test.py index 59da99081..afe719a1d 100644 --- a/tests/kafka/message_adapter_test.py +++ b/tests/kafka/message_adapter_test.py @@ -122,7 +122,7 @@ def test_adapter(self) -> None: InputStreamKey(topic="monitors", source_name="monitor1"): "monitor_0" } ) - result = adapter.adapt(message) + [result] = adapter.adapt(message) assert result.stream.kind == StreamKind.MONITOR_EVENTS assert result.stream.name == "monitor_0" @@ -149,7 +149,7 @@ def test_no_reference_time_uses_message_timestamp(self) -> None: InputStreamKey(topic="monitors", source_name="monitor1"): "monitor_0" } ) - result = adapter.adapt(message) + [result] = adapter.adapt(message) assert result.timestamp == 9999 @@ -176,7 +176,7 @@ class TestKafkaToF144Adapter: def test_adapter(self) -> None: message = FakeKafkaMessage(value=make_serialized_f144(), topic="sensors") adapter = KafkaToF144Adapter() - result = adapter.adapt(message) + [result] = adapter.adapt(message) assert result.stream.kind == StreamKind.LOG assert result.stream.name == "temperature1" @@ -192,7 +192,7 @@ def test_adapter_with_stream_mapping(self) -> None: ): "mapped_temperature" } ) - result = adapter.adapt(message) + [result] = adapter.adapt(message) assert result.stream.kind == StreamKind.LOG assert result.stream.name == "mapped_temperature" @@ -202,10 +202,10 @@ class TestF144ToLogDataAdapter: def test_adapter(self) -> None: f144_adapter = KafkaToF144Adapter() message = FakeKafkaMessage(value=make_serialized_f144(), topic="sensors") - adapted_f144 = f144_adapter.adapt(message) + [adapted_f144] = f144_adapter.adapt(message) log_data_adapter = F144ToLogDataAdapter() - result = log_data_adapter.adapt(adapted_f144) + [result] = log_data_adapter.adapt(adapted_f144) assert result.stream.kind == StreamKind.LOG assert result.stream.name == "temperature1" @@ -218,7 +218,7 @@ class TestKafkaToDa00Adapter: def test_adapter(self) -> None: message = FakeKafkaMessage(value=make_serialized_da00(), topic="instrument") adapter = KafkaToDa00Adapter(stream_kind=StreamKind.MONITOR_COUNTS) - result = adapter.adapt(message) + [result] = adapter.adapt(message) assert result.stream.kind == StreamKind.MONITOR_COUNTS assert result.stream.name == "instrument" @@ -236,7 +236,7 @@ def test_adapter_with_stream_mapping(self) -> None: ): "mapped_instrument" }, ) - result = adapter.adapt(message) + [result] = adapter.adapt(message) assert result.stream.kind == StreamKind.MONITOR_COUNTS assert result.stream.name == "mapped_instrument" @@ -246,10 +246,10 @@ class TestDa00ToScippAdapter: def test_adapter(self) -> None: da00_adapter = KafkaToDa00Adapter(stream_kind=StreamKind.MONITOR_COUNTS) message = FakeKafkaMessage(value=make_serialized_da00(), topic="instrument") - adapted_da00 = da00_adapter.adapt(message) + [adapted_da00] = da00_adapter.adapt(message) scipp_adapter = Da00ToScippAdapter() - result = scipp_adapter.adapt(adapted_da00) + [result] = scipp_adapter.adapt(adapted_da00) assert result.stream.kind == StreamKind.MONITOR_COUNTS assert result.stream.name == "instrument" @@ -263,7 +263,7 @@ class TestKafkaToAd00Adapter: def test_adapter(self) -> None: message = FakeKafkaMessage(value=make_serialized_ad00(), topic="detector") adapter = KafkaToAd00Adapter(stream_kind=StreamKind.AREA_DETECTOR) - result = adapter.adapt(message) + [result] = adapter.adapt(message) assert result.stream.kind == StreamKind.AREA_DETECTOR assert result.stream.name == "area_detector" @@ -284,7 +284,7 @@ def test_adapter_with_stream_mapping(self) -> None: ): "mapped_detector" }, ) - result = adapter.adapt(message) + [result] = adapter.adapt(message) assert result.stream.kind == StreamKind.AREA_DETECTOR assert result.stream.name == "mapped_detector" @@ -294,10 +294,10 @@ class TestAd00ToScippAdapter: def test_adapter(self) -> None: ad00_adapter = KafkaToAd00Adapter(stream_kind=StreamKind.AREA_DETECTOR) message = FakeKafkaMessage(value=make_serialized_ad00(), topic="detector") - adapted_ad00 = ad00_adapter.adapt(message) + [adapted_ad00] = ad00_adapter.adapt(message) scipp_adapter = Ad00ToScippAdapter() - result = scipp_adapter.adapt(adapted_ad00) + [result] = scipp_adapter.adapt(adapted_ad00) assert result.stream.kind == StreamKind.AREA_DETECTOR assert result.stream.name == "area_detector" @@ -322,7 +322,7 @@ def test_adapter(self) -> None: ), ) adapter = Ev44ToDetectorEventsAdapter() - result = adapter.adapt(ev44_message) + [result] = adapter.adapt(ev44_message) assert result.timestamp == 1234 assert result.stream.kind == StreamKind.DETECTOR_EVENTS @@ -345,7 +345,7 @@ def test_adapter_merge_detectors(self) -> None: ), ) adapter = Ev44ToDetectorEventsAdapter(merge_detectors=True) - result = adapter.adapt(ev44_message) + [result] = adapter.adapt(ev44_message) assert result.stream.name == "unified_detector" assert isinstance(result.value, DetectorEvents) @@ -371,14 +371,14 @@ class TestAdapter: def __init__(self, value: str): self._value = value - def adapt(self, message: KafkaMessage) -> Message[str]: - return fake_message_with_value(message, self._value) + def adapt(self, message: KafkaMessage) -> tuple[Message[str]]: + return (fake_message_with_value(message, self._value),) adapter = RouteBySchemaAdapter( routes={"ev44": TestAdapter('adapter1'), "da00": TestAdapter('adapter2')} ) - assert adapter.adapt(message_with_schema('ev44')).value == "adapter1" - assert adapter.adapt(message_with_schema('da00')).value == "adapter2" + assert adapter.adapt(message_with_schema('ev44'))[0].value == "adapter1" + assert adapter.adapt(message_with_schema('da00'))[0].value == "adapter2" class TestRouteByTopicAdapter: @@ -389,10 +389,10 @@ def __init__(self, return_value: str): self.last_message = None self.return_value = return_value - def adapt(self, message: KafkaMessage) -> Message[str]: + def adapt(self, message: KafkaMessage) -> tuple[Message[str]]: self.adapt_called = True self.last_message = message - return fake_message_with_value(message, self.return_value) + return (fake_message_with_value(message, self.return_value),) adapter1 = TestAdapter("adapter1") adapter2 = TestAdapter("adapter2") @@ -402,13 +402,13 @@ def adapt(self, message: KafkaMessage) -> Message[str]: assert router.topics == ["topic1", "topic2"] msg1 = FakeKafkaMessage(value=b"dummy", topic="topic1") - result1 = router.adapt(msg1) + [result1] = router.adapt(msg1) assert adapter1.adapt_called is True assert adapter1.last_message == msg1 assert result1.value == "adapter1" msg2 = FakeKafkaMessage(value=b"dummy", topic="topic2") - result2 = router.adapt(msg2) + [result2] = router.adapt(msg2) assert adapter2.adapt_called is True assert adapter2.last_message == msg2 assert result2.value == "adapter2" @@ -432,7 +432,7 @@ def test_adapter_with_stream_mapping(self) -> None: ): "mapped_monitor1" }, ) - result = adapter.adapt(message) + [result] = adapter.adapt(message) assert result.stream.kind == StreamKind.MONITOR_EVENTS assert result.stream.name == "mapped_monitor1" @@ -455,10 +455,84 @@ def test_no_reference_time_uses_message_timestamp(self) -> None: ) adapter = KafkaToEv44Adapter(stream_kind=StreamKind.MONITOR_EVENTS) - result = adapter.adapt(message) + [result] = adapter.adapt(message) assert result.timestamp == 9999 + def test_multi_pulse_splits_into_multiple_messages(self) -> None: + multi_pulse_ev44 = eventdata_ev44.serialise_ev44( + source_name="monitor1", + message_id=0, + reference_time=np.array([100, 200]), + reference_time_index=np.array([0, 2]), + time_of_flight=np.array([10, 20, 30]), + pixel_id=np.array([1, 2, 3]), + ) + message = FakeKafkaMessage(value=multi_pulse_ev44, topic="monitors") + adapter = KafkaToEv44Adapter(stream_kind=StreamKind.MONITOR_EVENTS) + results = adapter.adapt(message) + + assert len(results) == 2 + assert results[0].timestamp == 100 + np.testing.assert_array_equal(results[0].value.time_of_flight, [10, 20]) + np.testing.assert_array_equal(results[0].value.pixel_id, [1, 2]) + + assert results[1].timestamp == 200 + np.testing.assert_array_equal(results[1].value.time_of_flight, [30]) + np.testing.assert_array_equal(results[1].value.pixel_id, [3]) + + +class TestKafkaToMonitorEventsAdapterMultiPulse: + def test_multi_pulse_splits_into_multiple_messages(self) -> None: + multi_pulse_ev44 = eventdata_ev44.serialise_ev44( + source_name="monitor1", + message_id=0, + reference_time=np.array([100, 200]), + reference_time_index=np.array([0, 2]), + time_of_flight=np.array([10, 20, 30]), + pixel_id=np.array([1, 2, 3]), + ) + message = FakeKafkaMessage(value=multi_pulse_ev44, topic="monitors") + adapter = KafkaToMonitorEventsAdapter( + stream_lut={ + InputStreamKey(topic="monitors", source_name="monitor1"): "monitor_0" + } + ) + results = adapter.adapt(message) + + assert len(results) == 2 + assert results[0].timestamp == 100 + np.testing.assert_array_equal(results[0].value.time_of_arrival, [10, 20]) + assert results[1].timestamp == 200 + np.testing.assert_array_equal(results[1].value.time_of_arrival, [30]) + + +class TestMultiPulseChainIntegration: + def test_multi_pulse_ev44_through_chain_produces_multiple_detector_events( + self, + ) -> None: + multi_pulse_ev44 = eventdata_ev44.serialise_ev44( + source_name="det1", + message_id=0, + reference_time=np.array([100, 200]), + reference_time_index=np.array([0, 1]), + time_of_flight=np.array([10, 20, 30]), + pixel_id=np.array([1, 2, 3]), + ) + message = FakeKafkaMessage(value=multi_pulse_ev44, topic="detectors") + adapter = ChainedAdapter( + first=KafkaToEv44Adapter(stream_kind=StreamKind.DETECTOR_EVENTS), + second=Ev44ToDetectorEventsAdapter(), + ) + results = adapter.adapt(message) + assert len(results) == 2 + assert results[0].timestamp == 100 + np.testing.assert_array_equal(results[0].value.time_of_arrival, [10]) + np.testing.assert_array_equal(results[0].value.pixel_id, [1]) + assert results[1].timestamp == 200 + np.testing.assert_array_equal(results[1].value.time_of_arrival, [20, 30]) + np.testing.assert_array_equal(results[1].value.pixel_id, [2, 3]) + class TestAdaptingMessageSource: def test_source(self) -> None: @@ -537,7 +611,7 @@ def test_adapter(self) -> None: key=key, value=encoded, topic="dummy_livedata_commands" ) adapter = CommandsAdapter() - adapted_message = adapter.adapt(message) + [adapted_message] = adapter.adapt(message) assert adapted_message.stream == COMMANDS_STREAM_ID assert adapted_message.value == RawConfigItem(key=key, value=encoded) @@ -559,7 +633,7 @@ def test_adapter(self) -> None: key=b'', value=encoded, topic="dummy_livedata_responses" ) adapter = ResponsesAdapter() - adapted_message = adapter.adapt(message) + [adapted_message] = adapter.adapt(message) assert adapted_message.stream == RESPONSES_STREAM_ID assert adapted_message.value == ack diff --git a/tests/kafka/sink_test.py b/tests/kafka/sink_test.py index b9f271c7e..7b9c431ae 100644 --- a/tests/kafka/sink_test.py +++ b/tests/kafka/sink_test.py @@ -48,7 +48,7 @@ def test_serialize_dataarray_to_da00_roundtrip_preserves_data(self) -> None: ) # Deserialize back to scipp - result_msg = roundtrip_adapter.adapt(kafka_msg) + [result_msg] = roundtrip_adapter.adapt(kafka_msg) # Verify the roundtrip preserved the data and payload timestamp assert sc.identical(result_msg.value, original_data) @@ -85,7 +85,7 @@ def test_serialize_dataarray_to_da00_with_multidimensional_data(self) -> None: timestamp=kafka_timestamp, # Should be ignored ) - result_msg = roundtrip_adapter.adapt(kafka_msg) + [result_msg] = roundtrip_adapter.adapt(kafka_msg) # Verify roundtrip uses payload timestamp, not Kafka timestamp assert sc.identical(result_msg.value, original_data) @@ -124,7 +124,7 @@ def test_roundtrip_preserves_value_and_payload_timestamp( ) # Deserialize back to f144 log data - result_msg = f144_adapter.adapt(kafka_msg) + [result_msg] = f144_adapter.adapt(kafka_msg) # Verify the roundtrip preserved the value and used time coordinate as timestamp assert result_msg.value.value == 42.5 # Value preserved @@ -156,7 +156,7 @@ def test_serialize_dataarray_to_f144_with_array_data(self) -> None: timestamp=2222222222, # Should be ignored ) - result_msg = f144_adapter.adapt(kafka_msg) + [result_msg] = f144_adapter.adapt(kafka_msg) # Verify array values are preserved and timestamp comes from time coordinate np.testing.assert_array_equal(result_msg.value.value, [1.0, 2.0, 3.0]) @@ -180,7 +180,7 @@ def test_serialize_dataarray_to_f144_different_time_units(self) -> None: value=serialized_bytes, topic='log_topic', timestamp=0 ) - result_msg = f144_adapter.adapt(kafka_msg) + [result_msg] = f144_adapter.adapt(kafka_msg) # Time should be converted to nanoseconds expected_time_ns = time_us * 1000 # us to ns conversion