diff --git a/src/ess/livedata/dashboard/plots.py b/src/ess/livedata/dashboard/plots.py index 03350a41d..eca38a32c 100644 --- a/src/ess/livedata/dashboard/plots.py +++ b/src/ess/livedata/dashboard/plots.py @@ -7,6 +7,7 @@ import weakref from collections.abc import Callable from dataclasses import dataclass +from dataclasses import field as dataclass_field from typing import Any, ClassVar, cast import holoviews as hv @@ -215,6 +216,70 @@ def _compute_time_info(data: dict[str, sc.DataArray]) -> str | None: return f'{end_str} (Lag: {lag_s:.1f}s)' +@dataclass(frozen=True) +class CanvasSpec: + """Unit metadata extracted from a scipp DataArray for overlay validation. + + Captures coordinate units (keyed by dimension name) and the value unit. + Two overlays are compatible when they share the same coordinate dimensions + with matching units and matching value units. + """ + + coord_units: dict[str, sc.Unit] = dataclass_field(default_factory=dict) + value_unit: sc.Unit = dataclass_field(default=sc.units.dimensionless) + + @classmethod + def from_data_array(cls, da: sc.DataArray) -> CanvasSpec: + """Build from a scipp DataArray using its dimension coordinates.""" + return cls( + coord_units={ + dim: da.coords[dim].unit for dim in da.dims if dim in da.coords + }, + value_unit=da.unit, + ) + + +def validate_canvas_spec( + entries: list[tuple[str, CanvasSpec]], +) -> str | None: + """Check that overlay entries share compatible units. + + Parameters + ---------- + entries: + Pairs of (label, units) for each data-carrying element or layer. + + Returns + ------- + : + An error message if units are incompatible, or None if compatible. + """ + if len(entries) < 2: + return None + + ref_label, ref = entries[0] + for label, other in entries[1:]: + if ref.coord_units.keys() != other.coord_units.keys(): + return ( + f"Cannot overlay '{ref_label}' and '{label}': " + f"dimension mismatch " + f"({set(ref.coord_units)} vs {set(other.coord_units)})" + ) + for dim, ref_unit in ref.coord_units.items(): + if ref_unit != other.coord_units[dim]: + return ( + f"Cannot overlay '{ref_label}' and '{label}': " + f"unit mismatch for coordinate '{dim}' " + f"({ref_unit} vs {other.coord_units[dim]})" + ) + if ref.value_unit != other.value_unit: + return ( + f"Cannot overlay '{ref_label}' and '{label}': " + f"value unit mismatch ({ref.value_unit} vs {other.value_unit})" + ) + return None + + class Plotter: """ Base class for plots that support autoscaling. @@ -223,6 +288,8 @@ class Plotter: This enables efficient polling-based update detection. """ + participates_in_overlay_validation: ClassVar[bool] = True + def __init__( self, *, @@ -246,6 +313,7 @@ def __init__( """ self._normalize_to_rate = normalize_to_rate self._cached_state: Any | None = None + self.canvas_spec: CanvasSpec | None = None self._presenters: weakref.WeakSet[PresenterBase] = weakref.WeakSet() self.autoscaler_kwargs = kwargs self.autoscalers: dict[ResultKey, Autoscaler] = {} @@ -434,6 +502,16 @@ def compute( resolver = title_resolver or TitleResolver() plots: list[hv.Element] = [] try: + if self.participates_in_overlay_validation and data: + entries = [ + (key.job_id.source_name, CanvasSpec.from_data_array(da)) + for key, da in data.items() + ] + if self.layout_params.combine_mode == 'overlay': + error = validate_canvas_spec(entries) + if error is not None: + raise ValueError(error) + self.canvas_spec = entries[0][1] for data_key, da in data.items(): label = resolver.get_legend_label( data_key.job_id.source_name, data_key.output_name diff --git a/src/ess/livedata/dashboard/roi_readback_plots.py b/src/ess/livedata/dashboard/roi_readback_plots.py index aa0f1dc84..1680abdaf 100644 --- a/src/ess/livedata/dashboard/roi_readback_plots.py +++ b/src/ess/livedata/dashboard/roi_readback_plots.py @@ -71,6 +71,8 @@ class RectanglesReadbackPlotter(Plotter): rectangles with per-shape colors based on the ROI index. """ + participates_in_overlay_validation = False + def __init__(self, params: RectanglesReadbackParams) -> None: super().__init__() self._params = params @@ -177,6 +179,8 @@ class PolygonsReadbackPlotter(Plotter): polygons with per-shape colors based on the ROI index. """ + participates_in_overlay_validation = False + def __init__(self, params: PolygonsReadbackParams) -> None: super().__init__() self._params = params diff --git a/src/ess/livedata/dashboard/roi_request_plots.py b/src/ess/livedata/dashboard/roi_request_plots.py index 28a0881fc..3f9a03902 100644 --- a/src/ess/livedata/dashboard/roi_request_plots.py +++ b/src/ess/livedata/dashboard/roi_request_plots.py @@ -533,6 +533,8 @@ class BaseROIRequestPlotter(Plotter, ABC, Generic[ROIType, ParamsType, Converter skip logic, publishing) is handled via a closure-based edit callback. """ + participates_in_overlay_validation = False + def __init__( self, params: ParamsType, diff --git a/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py b/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py index c18c63d7a..49e3adbb6 100644 --- a/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py +++ b/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py @@ -35,6 +35,7 @@ SubscriptionId, ) from ..plot_params import PlotAspectType, StretchMode +from ..plots import CanvasSpec, validate_canvas_spec from ..save_filename import build_save_filename_from_cell, make_save_filename_hook from ..session_layer import SessionLayer from ..session_updater import SessionUpdater @@ -940,6 +941,7 @@ def _get_session_composed_plot( Composed plot from session DMaps/elements, or None if none available. """ plots = [] + overlay_entries: list[tuple[str, CanvasSpec]] = [] has_layout = False for layer in cell.layers: layer_id = layer.layer_id @@ -967,9 +969,23 @@ def _get_session_composed_plot( has_layout = True plots.append(dmap) + # Collect overlay units for cross-layer validation + # (only data-carrying plotters set canvas_spec) + if state is not None and state.plotter is not None: + units = state.plotter.canvas_spec + if units is not None: + overlay_entries.append((layer.config.plot_name, units)) + if not plots: return None + # Validate cross-layer unit compatibility before composing + error = validate_canvas_spec(overlay_entries) + if error is not None: + return hv.Text(0.5, 0.5, f"Error: {error}").opts( + text_align='center', text_baseline='middle' + ) + result: hv.DynamicMap | hv.Element if len(plots) == 1: result = plots[0] diff --git a/tests/dashboard/plots_test.py b/tests/dashboard/plots_test.py index 9fa9f6fa5..435dbc31e 100644 --- a/tests/dashboard/plots_test.py +++ b/tests/dashboard/plots_test.py @@ -2100,3 +2100,135 @@ def test_bars_plotter_normalizes_when_enabled(self, data_key): bars = next(iter(result.values())) assert isinstance(bars, hv.Bars) assert bars.vdims[0].unit == 'counts/s' + + +def _make_result_key(source_name: str) -> ResultKey: + return ResultKey( + workflow_id=WorkflowId(instrument='test', namespace='ns', name='wf', version=1), + job_id=JobId(source_name=source_name, job_number=uuid.uuid4()), + output_name='result', + ) + + +def _make_1d_da( + *, dim: str = 'tof', coord_unit: str = 'ms', value_unit: str = 'counts' +) -> sc.DataArray: + coord = sc.arange(dim, 5, dtype='float64', unit=coord_unit) + return sc.DataArray( + sc.ones(sizes={dim: 5}, unit=value_unit), + coords={dim: coord}, + ) + + +def _units(coord_units: dict[str, str], value_unit: str = 'counts') -> plots.CanvasSpec: + return plots.CanvasSpec( + coord_units={k: sc.Unit(v) for k, v in coord_units.items()}, + value_unit=sc.Unit(value_unit), + ) + + +class TestCanvasSpec: + """Tests for CanvasSpec construction.""" + + def test_from_data_array_1d(self): + da = _make_1d_da(dim='tof', coord_unit='ms', value_unit='counts') + units = plots.CanvasSpec.from_data_array(da) + assert units.coord_units == {'tof': sc.Unit('ms')} + assert units.value_unit == sc.Unit('counts') + + def test_from_data_array_2d(self): + da = sc.DataArray( + sc.ones(sizes={'y': 4, 'x': 5}, unit='counts'), + coords={ + 'x': sc.arange('x', 5, dtype='float64', unit='m'), + 'y': sc.arange('y', 4, dtype='float64', unit='m'), + }, + ) + units = plots.CanvasSpec.from_data_array(da) + assert units.coord_units == {'x': sc.Unit('m'), 'y': sc.Unit('m')} + assert units.value_unit == sc.Unit('counts') + + +class TestValidateCanvasSpec: + """Tests for validate_canvas_spec (used at both sites).""" + + def test_single_entry_passes(self): + assert plots.validate_canvas_spec([('a', _units({'tof': 'ms'}))]) is None + + def test_matching_units_pass(self): + entries = [ + ('a', _units({'tof': 'ms'})), + ('b', _units({'tof': 'ms'})), + ] + assert plots.validate_canvas_spec(entries) is None + + def test_mismatched_coord_units(self): + entries = [ + ('src_a', _units({'tof': 'ms'})), + ('src_b', _units({'tof': 'us'})), + ] + error = plots.validate_canvas_spec(entries) + assert error is not None + assert 'tof' in error + assert 'ms' in error + + def test_mismatched_value_units(self): + entries = [ + ('a', _units({'tof': 'ms'}, value_unit='counts')), + ('b', _units({'tof': 'ms'}, value_unit='counts/us')), + ] + error = plots.validate_canvas_spec(entries) + assert error is not None + assert 'value unit mismatch' in error + + def test_mismatched_dims(self): + entries = [ + ('a', _units({'tof': 'ms'})), + ('b', _units({'wavelength': 'angstrom'})), + ] + error = plots.validate_canvas_spec(entries) + assert error is not None + assert 'dimension mismatch' in error + + def test_error_includes_labels(self): + entries = [ + ('detector_1', _units({'tof': 'ms'})), + ('detector_2', _units({'tof': 'us'})), + ] + error = plots.validate_canvas_spec(entries) + assert 'detector_1' in error + assert 'detector_2' in error + + +class TestPlotterCanvasSpec: + """Integration: plotter stores canvas_spec after compute().""" + + def test_line_plotter_stores_units(self): + plotter = plots.LinePlotter.from_params(PlotParams1d()) + da = _make_1d_da(dim='tof', coord_unit='ms', value_unit='counts') + plotter.compute({_make_result_key('src'): da}) + assert plotter.canvas_spec == _units({'tof': 'ms'}) + + def test_image_plotter_stores_units(self): + plotter = plots.ImagePlotter.from_params(PlotParams2d()) + da = sc.DataArray( + sc.ones(sizes={'y': 4, 'x': 5}, unit='counts'), + coords={ + 'x': sc.arange('x', 5, dtype='float64', unit='m'), + 'y': sc.arange('y', 4, dtype='float64', unit='m'), + }, + ) + plotter.compute({_make_result_key('src'): da}) + assert plotter.canvas_spec == _units({'x': 'm', 'y': 'm'}) + + def test_compute_with_mismatched_units_shows_error(self): + plotter = plots.LinePlotter.from_params(PlotParams1d()) + data = { + _make_result_key('src_a'): _make_1d_da(coord_unit='ms'), + _make_result_key('src_b'): _make_1d_da(coord_unit='us'), + } + plotter.compute(data) + result = plotter.get_cached_state() + assert isinstance(result, hv.Overlay) + (element,) = result.values() + assert isinstance(element, hv.Text) diff --git a/tests/dashboard/roi_readback_plots_test.py b/tests/dashboard/roi_readback_plots_test.py index edcf72dce..b685f3273 100644 --- a/tests/dashboard/roi_readback_plots_test.py +++ b/tests/dashboard/roi_readback_plots_test.py @@ -422,3 +422,29 @@ def test_lines_plotter_compatible_with_regular_1d_data(self): ) compatible = plotter_registry.get_compatible_plotters({'key': data}) assert 'lines' in compatible + + +class TestROIReadbackOverlayValidation: + """ROI readback plotters are exempt from overlay unit validation.""" + + def test_rectangles_readback_does_not_set_canvas_spec(self, result_key): + roi = RectangleROI( + x=Interval(min=0.0, max=1.0, unit='m'), + y=Interval(min=0.0, max=1.0, unit='m'), + ) + da = RectangleROI.to_concatenated_data_array({0: roi}) + plotter = RectanglesReadbackPlotter(RectanglesReadbackParams()) + plotter.compute({result_key: da}) + assert plotter.canvas_spec is None + + def test_polygons_readback_does_not_set_canvas_spec(self, result_key): + roi = PolygonROI( + x=[0.0, 1.0, 1.0, 0.0], + y=[0.0, 0.0, 1.0, 1.0], + x_unit='mm', + y_unit='mm', + ) + da = PolygonROI.to_concatenated_data_array({0: roi}) + plotter = PolygonsReadbackPlotter(PolygonsReadbackParams()) + plotter.compute({result_key: da}) + assert plotter.canvas_spec is None diff --git a/tests/dashboard/widgets/plot_grid_tabs_layout_test.py b/tests/dashboard/widgets/plot_grid_tabs_layout_test.py index 5bd7dbfa5..952cf7b06 100644 --- a/tests/dashboard/widgets/plot_grid_tabs_layout_test.py +++ b/tests/dashboard/widgets/plot_grid_tabs_layout_test.py @@ -48,6 +48,7 @@ class FakePlotter: def __init__(self, cached_state=None): self._cached_state = cached_state self._presenters: list[FakePresenter] = [] + self.canvas_spec = None def get_cached_state(self): return self._cached_state