From 435857a4d04eb3661b31e5d18c5d3b78e9b44566 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 14 Apr 2026 07:40:29 +0000 Subject: [PATCH 1/2] Add overlay unit validation at both intra-layer and cross-layer sites HoloViews silently renders overlaid elements against a single shared axis using the first element's units, producing incorrect plots when sources have mismatched units. This adds validation at two sites: Site 1 (Plotter.compute): validates that all DataArrays in an overlay share the same dimension names, coordinate units, and value units. Errors are caught by the existing try/except and displayed as text. Site 2 (_get_session_composed_plot): validates cross-layer unit compatibility before composing DynamicMaps. Each Plotter stores an OverlayUnitKey after compute(), which captures kdim and vdim unit metadata. ROI readback/request plotters provide kdim units from their coordinate metadata but skip vdim validation. Static plotters are naturally exempt (no Plotter instance). Co-Authored-By: Claude Sonnet 4.6 --- src/ess/livedata/dashboard/plots.py | 139 ++++++++++ .../livedata/dashboard/roi_readback_plots.py | 34 ++- .../livedata/dashboard/roi_request_plots.py | 12 +- .../dashboard/widgets/plot_grid_tabs.py | 15 ++ tests/dashboard/plots_test.py | 239 ++++++++++++++++++ tests/dashboard/roi_readback_plots_test.py | 32 +++ .../widgets/plot_grid_tabs_layout_test.py | 1 + 7 files changed, 470 insertions(+), 2 deletions(-) diff --git a/src/ess/livedata/dashboard/plots.py b/src/ess/livedata/dashboard/plots.py index 03350a41d..8462481be 100644 --- a/src/ess/livedata/dashboard/plots.py +++ b/src/ess/livedata/dashboard/plots.py @@ -215,6 +215,121 @@ def _compute_time_info(data: dict[str, sc.DataArray]) -> str | None: return f'{end_str} (Lag: {lag_s:.1f}s)' +@dataclass(frozen=True) +class OverlayUnitKey: + """Unit signature for cross-layer overlay validation. + + Captures the units of all plotting dimensions and (optionally) the value + dimension. Two keys are compatible when every shared kdim label has the + same unit string and, if both keys carry a vdim unit, those match too. + """ + + kdim_units: tuple[tuple[str, str | None], ...] = () + vdim_unit: str | None = None + + @classmethod + def from_data_array( + cls, + da: sc.DataArray, + *, + include_vdim: bool = True, + ) -> OverlayUnitKey: + """Build a key from a scipp DataArray using its dimension coordinates.""" + kdim_units = tuple( + sorted( + (dim, _unit_str(da.coords[dim].unit)) + for dim in da.dims + if dim in da.coords + ) + ) + vdim_unit = _unit_str(da.unit) if include_vdim else None + return cls(kdim_units=kdim_units, vdim_unit=vdim_unit) + + +def _unit_str(unit: sc.Unit | None) -> str | None: + """Convert a scipp unit to its string representation, or None.""" + return str(unit) if unit is not None else None + + +def _validate_overlay_units(data: dict[ResultKey, sc.DataArray]) -> None: + """Validate that all DataArrays in an overlay share compatible units. + + Raises + ------ + ValueError + If dimension names, coordinate units, or value units do not match. + """ + if len(data) < 2: + return + + items = iter(data.items()) + ref_key, ref = next(items) + ref_source = ref_key.job_id.source_name + + for key, da in items: + source = key.job_id.source_name + if ref.dims != da.dims: + raise ValueError( + f"Cannot overlay '{ref_source}' and '{source}': " + f"dimension mismatch ({ref.dims} vs {da.dims})" + ) + for dim in ref.dims: + if dim in ref.coords and dim in da.coords: + u_ref = ref.coords[dim].unit + u_da = da.coords[dim].unit + if u_ref != u_da: + raise ValueError( + f"Cannot overlay '{ref_source}' and '{source}': " + f"unit mismatch for coordinate '{dim}' " + f"({u_ref} vs {u_da})" + ) + if ref.unit != da.unit: + raise ValueError( + f"Cannot overlay '{ref_source}' and '{source}': " + f"value unit mismatch ({ref.unit} vs {da.unit})" + ) + + +def validate_cross_layer_units( + keys: list[tuple[str, OverlayUnitKey]], +) -> str | None: + """Check overlay unit keys from multiple layers for compatibility. + + Parameters + ---------- + keys: + Pairs of (layer label, unit key) for each data-carrying layer. + + Returns + ------- + : + An error message if units are incompatible, or None if compatible. + """ + if len(keys) < 2: + return None + + ref_label, ref_key = keys[0] + ref_kdims = dict(ref_key.kdim_units) + + for label, key in keys[1:]: + other_kdims = dict(key.kdim_units) + for dim_label, ref_unit in ref_kdims.items(): + if dim_label in other_kdims and ref_unit != other_kdims[dim_label]: + return ( + f"Cannot overlay layers '{ref_label}' and '{label}': " + f"unit mismatch for coordinate '{dim_label}' " + f"({ref_unit} vs {other_kdims[dim_label]})" + ) + if ref_key.vdim_unit is not None and key.vdim_unit is not None: + if ref_key.vdim_unit != key.vdim_unit: + return ( + f"Cannot overlay layers '{ref_label}' and '{label}': " + f"value unit mismatch " + f"({ref_key.vdim_unit} vs {key.vdim_unit})" + ) + return None + + class Plotter: """ Base class for plots that support autoscaling. @@ -246,6 +361,7 @@ def __init__( """ self._normalize_to_rate = normalize_to_rate self._cached_state: Any | None = None + self._overlay_unit_key: OverlayUnitKey | None = None self._presenters: weakref.WeakSet[PresenterBase] = weakref.WeakSet() self.autoscaler_kwargs = kwargs self.autoscalers: dict[ResultKey, Autoscaler] = {} @@ -258,6 +374,25 @@ def __init__( # with responsive mode in Panel containers (upstream bug). self._sizing_opts: dict[str, Any] = {'responsive': True} + @property + def overlay_unit_key(self) -> OverlayUnitKey | None: + """Unit signature from the last compute(), for cross-layer validation.""" + return self._overlay_unit_key + + def _extract_overlay_unit_key( + self, data: dict[ResultKey, sc.DataArray] + ) -> OverlayUnitKey | None: + """Extract an overlay unit key from input data. + + Default implementation uses dimension coordinates and value unit from the + first DataArray. Subclasses with non-standard coordinate structures (e.g., + ROI plotters) should override this method. + """ + if not data: + return None + da = next(iter(data.values())) + return OverlayUnitKey.from_data_array(da) + @staticmethod def _make_tick_opts(tick_params: TickParams | None) -> dict[str, Any]: """ @@ -431,9 +566,13 @@ def compute( if self._normalize_to_rate: data = {key: _normalize_to_rate(da) for key, da in data.items()} + self._overlay_unit_key = self._extract_overlay_unit_key(data) + resolver = title_resolver or TitleResolver() plots: list[hv.Element] = [] try: + if self.layout_params.combine_mode == 'overlay': + _validate_overlay_units(data) for data_key, da in data.items(): label = resolver.get_legend_label( data_key.job_id.source_name, data_key.output_name diff --git a/src/ess/livedata/dashboard/roi_readback_plots.py b/src/ess/livedata/dashboard/roi_readback_plots.py index aa0f1dc84..34b69ce55 100644 --- a/src/ess/livedata/dashboard/roi_readback_plots.py +++ b/src/ess/livedata/dashboard/roi_readback_plots.py @@ -21,7 +21,23 @@ ) from ess.livedata.config.workflow_spec import ResultKey -from .plots import Plotter +from .plots import OverlayUnitKey, Plotter, _unit_str + + +def _roi_overlay_unit_key(da: sc.DataArray) -> OverlayUnitKey: + """Build an overlay unit key from ROI coordinate metadata. + + ROI DataArrays use named coords ('x', 'y') rather than dimension + coordinates, and have no meaningful value unit. + """ + kdim_units = tuple( + sorted( + (name, _unit_str(da.coords[name].unit)) + for name in ('x', 'y') + if name in da.coords + ) + ) + return OverlayUnitKey(kdim_units=kdim_units, vdim_unit=None) class ROIReadbackStyle(pydantic.BaseModel): @@ -81,6 +97,14 @@ def from_params(cls, params: RectanglesReadbackParams) -> RectanglesReadbackPlot """Create plotter from params.""" return cls(params) + def _extract_overlay_unit_key( + self, data: dict[ResultKey, sc.DataArray] + ) -> OverlayUnitKey | None: + if not data: + return None + da = next(iter(data.values())) + return _roi_overlay_unit_key(da) + def plot( self, data: sc.DataArray, data_key: ResultKey, *, label: str = '', **kwargs ) -> hv.Rectangles: @@ -187,6 +211,14 @@ def from_params(cls, params: PolygonsReadbackParams) -> PolygonsReadbackPlotter: """Create plotter from params.""" return cls(params) + def _extract_overlay_unit_key( + self, data: dict[ResultKey, sc.DataArray] + ) -> OverlayUnitKey | None: + if not data: + return None + da = next(iter(data.values())) + return _roi_overlay_unit_key(da) + def plot( self, data: sc.DataArray, data_key: ResultKey, *, label: str = '', **kwargs ) -> hv.Polygons: diff --git a/src/ess/livedata/dashboard/roi_request_plots.py b/src/ess/livedata/dashboard/roi_request_plots.py index 28a0881fc..b45d31a9f 100644 --- a/src/ess/livedata/dashboard/roi_request_plots.py +++ b/src/ess/livedata/dashboard/roi_request_plots.py @@ -48,7 +48,7 @@ get_roi_mapper, ) -from .plots import Plotter, PresenterBase +from .plots import OverlayUnitKey, Plotter, PresenterBase from .static_plots import Color, LineDash, RectanglesCoordinates if TYPE_CHECKING: @@ -626,6 +626,16 @@ def compute( else None ) + # Build overlay unit key from ROI coordinate units + kdim_units = tuple( + sorted( + (name, unit) + for name, unit in [('x', self._x_unit), ('y', self._y_unit)] + if unit is not None + ) + ) + self._overlay_unit_key = OverlayUnitKey(kdim_units=kdim_units, vdim_unit=None) + # Forward data (presenter may use in future) self._set_cached_state(data) return data diff --git a/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py b/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py index c18c63d7a..1304b5824 100644 --- a/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py +++ b/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py @@ -35,6 +35,7 @@ SubscriptionId, ) from ..plot_params import PlotAspectType, StretchMode +from ..plots import OverlayUnitKey, validate_cross_layer_units from ..save_filename import build_save_filename_from_cell, make_save_filename_hook from ..session_layer import SessionLayer from ..session_updater import SessionUpdater @@ -940,6 +941,7 @@ def _get_session_composed_plot( Composed plot from session DMaps/elements, or None if none available. """ plots = [] + unit_keys: list[tuple[str, OverlayUnitKey]] = [] has_layout = False for layer in cell.layers: layer_id = layer.layer_id @@ -967,9 +969,22 @@ def _get_session_composed_plot( has_layout = True plots.append(dmap) + # Collect unit key for cross-layer validation + if state is not None and state.plotter is not None: + key = state.plotter.overlay_unit_key + if key is not None: + unit_keys.append((layer.config.plot_name, key)) + if not plots: return None + # Validate cross-layer unit compatibility before composing + error = validate_cross_layer_units(unit_keys) + if error is not None: + return hv.Text(0.5, 0.5, f"Error: {error}").opts( + text_align='center', text_baseline='middle' + ) + result: hv.DynamicMap | hv.Element if len(plots) == 1: result = plots[0] diff --git a/tests/dashboard/plots_test.py b/tests/dashboard/plots_test.py index 9fa9f6fa5..a0e927d77 100644 --- a/tests/dashboard/plots_test.py +++ b/tests/dashboard/plots_test.py @@ -2100,3 +2100,242 @@ def test_bars_plotter_normalizes_when_enabled(self, data_key): bars = next(iter(result.values())) assert isinstance(bars, hv.Bars) assert bars.vdims[0].unit == 'counts/s' + + +def _make_result_key(source_name: str) -> ResultKey: + return ResultKey( + workflow_id=WorkflowId(instrument='test', namespace='ns', name='wf', version=1), + job_id=JobId(source_name=source_name, job_number=uuid.uuid4()), + output_name='result', + ) + + +def _make_1d_da( + *, dim: str = 'tof', coord_unit: str = 'ms', value_unit: str = 'counts' +) -> sc.DataArray: + coord = sc.arange(dim, 5, dtype='float64', unit=coord_unit) + return sc.DataArray( + sc.ones(sizes={dim: 5}, unit=value_unit), + coords={dim: coord}, + ) + + +class TestValidateOverlayUnits: + """Tests for Site 1: intra-layer overlay unit validation.""" + + def test_single_entry_passes(self): + key = _make_result_key('src_a') + da = _make_1d_da() + plots._validate_overlay_units({key: da}) + + def test_matching_units_pass(self): + data = { + _make_result_key('src_a'): _make_1d_da(), + _make_result_key('src_b'): _make_1d_da(), + } + plots._validate_overlay_units(data) + + def test_mismatched_coord_units_raise(self): + data = { + _make_result_key('src_a'): _make_1d_da(coord_unit='ms'), + _make_result_key('src_b'): _make_1d_da(coord_unit='us'), + } + with pytest.raises(ValueError, match=r"coordinate 'tof'.*ms.*µs"): + plots._validate_overlay_units(data) + + def test_mismatched_value_units_raise(self): + data = { + _make_result_key('src_a'): _make_1d_da(value_unit='counts'), + _make_result_key('src_b'): _make_1d_da(value_unit='counts/us'), + } + with pytest.raises(ValueError, match="value unit mismatch"): + plots._validate_overlay_units(data) + + def test_mismatched_dims_raise(self): + data = { + _make_result_key('src_a'): _make_1d_da(dim='tof'), + _make_result_key('src_b'): _make_1d_da(dim='wavelength'), + } + with pytest.raises(ValueError, match="dimension mismatch"): + plots._validate_overlay_units(data) + + def test_error_message_includes_source_names(self): + data = { + _make_result_key('detector_1'): _make_1d_da(coord_unit='ms'), + _make_result_key('detector_2'): _make_1d_da(coord_unit='us'), + } + with pytest.raises(ValueError, match=r"detector_1.*detector_2"): + plots._validate_overlay_units(data) + + +class TestOverlayUnitKey: + """Tests for OverlayUnitKey construction and cross-layer validation.""" + + def test_from_data_array_1d(self): + da = _make_1d_da(dim='tof', coord_unit='ms', value_unit='counts') + key = plots.OverlayUnitKey.from_data_array(da) + assert key.kdim_units == (('tof', 'ms'),) + assert key.vdim_unit == 'counts' + + def test_from_data_array_2d(self): + da = sc.DataArray( + sc.ones(sizes={'y': 4, 'x': 5}, unit='counts'), + coords={ + 'x': sc.arange('x', 5, dtype='float64', unit='m'), + 'y': sc.arange('y', 4, dtype='float64', unit='m'), + }, + ) + key = plots.OverlayUnitKey.from_data_array(da) + assert key.kdim_units == (('x', 'm'), ('y', 'm')) + assert key.vdim_unit == 'counts' + + def test_from_data_array_no_vdim(self): + da = _make_1d_da() + key = plots.OverlayUnitKey.from_data_array(da, include_vdim=False) + assert key.vdim_unit is None + + def test_cross_layer_matching_units_pass(self): + keys = [ + ( + 'layer_a', + plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), + ), + ( + 'layer_b', + plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), + ), + ] + assert plots.validate_cross_layer_units(keys) is None + + def test_cross_layer_mismatched_kdim_units(self): + keys = [ + ( + 'layer_a', + plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), + ), + ( + 'layer_b', + plots.OverlayUnitKey(kdim_units=(('tof', 'us'),), vdim_unit='counts'), + ), + ] + error = plots.validate_cross_layer_units(keys) + assert error is not None + assert 'tof' in error + assert 'ms' in error + assert 'us' in error + + def test_cross_layer_mismatched_vdim_units(self): + keys = [ + ( + 'layer_a', + plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), + ), + ( + 'layer_b', + plots.OverlayUnitKey( + kdim_units=(('tof', 'ms'),), vdim_unit='counts/us' + ), + ), + ] + error = plots.validate_cross_layer_units(keys) + assert error is not None + assert 'value unit mismatch' in error + + def test_cross_layer_vdim_none_skips_check(self): + keys = [ + ( + 'image', + plots.OverlayUnitKey( + kdim_units=(('x', 'm'), ('y', 'm')), vdim_unit='counts' + ), + ), + ( + 'roi', + plots.OverlayUnitKey( + kdim_units=(('x', 'm'), ('y', 'm')), vdim_unit=None + ), + ), + ] + assert plots.validate_cross_layer_units(keys) is None + + def test_cross_layer_vdim_none_still_checks_kdims(self): + keys = [ + ( + 'image', + plots.OverlayUnitKey(kdim_units=(('x', 'm'),), vdim_unit='counts'), + ), + ('roi', plots.OverlayUnitKey(kdim_units=(('x', 'mm'),), vdim_unit=None)), + ] + error = plots.validate_cross_layer_units(keys) + assert error is not None + assert 'x' in error + + def test_cross_layer_single_key_passes(self): + keys = [ + ( + 'only_layer', + plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), + ), + ] + assert plots.validate_cross_layer_units(keys) is None + + def test_cross_layer_non_overlapping_kdims_pass(self): + keys = [ + ( + 'curve', + plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), + ), + ( + 'roi', + plots.OverlayUnitKey( + kdim_units=(('x', 'm'), ('y', 'm')), vdim_unit=None + ), + ), + ] + assert plots.validate_cross_layer_units(keys) is None + + +class TestPlotterOverlayUnitKey: + """Integration tests: plotter stores overlay unit key after compute().""" + + def test_line_plotter_stores_key(self): + params = PlotParams1d() + plotter = plots.LinePlotter.from_params(params) + key = _make_result_key('src') + da = _make_1d_da(dim='tof', coord_unit='ms', value_unit='counts') + plotter.compute({key: da}) + unit_key = plotter.overlay_unit_key + assert unit_key is not None + assert unit_key.kdim_units == (('tof', 'ms'),) + assert unit_key.vdim_unit == 'counts' + + def test_image_plotter_stores_key(self): + params = PlotParams2d() + plotter = plots.ImagePlotter.from_params(params) + key = _make_result_key('src') + da = sc.DataArray( + sc.ones(sizes={'y': 4, 'x': 5}, unit='counts'), + coords={ + 'x': sc.arange('x', 5, dtype='float64', unit='m'), + 'y': sc.arange('y', 4, dtype='float64', unit='m'), + }, + ) + plotter.compute({key: da}) + unit_key = plotter.overlay_unit_key + assert unit_key is not None + assert unit_key.kdim_units == (('x', 'm'), ('y', 'm')) + assert unit_key.vdim_unit == 'counts' + + def test_compute_with_mismatched_units_shows_error(self): + params = PlotParams1d() + plotter = plots.LinePlotter.from_params(params) + data = { + _make_result_key('src_a'): _make_1d_da(coord_unit='ms'), + _make_result_key('src_b'): _make_1d_da(coord_unit='us'), + } + plotter.compute(data) + result = plotter.get_cached_state() + # Validation error is caught and wrapped in an Overlay containing hv.Text + assert isinstance(result, hv.Overlay) + (element,) = result.values() + assert isinstance(element, hv.Text) diff --git a/tests/dashboard/roi_readback_plots_test.py b/tests/dashboard/roi_readback_plots_test.py index edcf72dce..9460f3578 100644 --- a/tests/dashboard/roi_readback_plots_test.py +++ b/tests/dashboard/roi_readback_plots_test.py @@ -422,3 +422,35 @@ def test_lines_plotter_compatible_with_regular_1d_data(self): ) compatible = plotter_registry.get_compatible_plotters({'key': data}) assert 'lines' in compatible + + +class TestROIReadbackOverlayUnitKey: + """Tests for ROI readback plotter overlay unit key extraction.""" + + def test_rectangles_readback_stores_kdim_units(self, result_key): + roi = RectangleROI( + x=Interval(min=0.0, max=1.0, unit='m'), + y=Interval(min=0.0, max=1.0, unit='m'), + ) + da = RectangleROI.to_concatenated_data_array({0: roi}) + plotter = RectanglesReadbackPlotter(RectanglesReadbackParams()) + plotter.compute({result_key: da}) + key = plotter.overlay_unit_key + assert key is not None + assert key.kdim_units == (('x', 'm'), ('y', 'm')) + assert key.vdim_unit is None + + def test_polygons_readback_stores_kdim_units(self, result_key): + roi = PolygonROI( + x=[0.0, 1.0, 1.0, 0.0], + y=[0.0, 0.0, 1.0, 1.0], + x_unit='mm', + y_unit='mm', + ) + da = PolygonROI.to_concatenated_data_array({0: roi}) + plotter = PolygonsReadbackPlotter(PolygonsReadbackParams()) + plotter.compute({result_key: da}) + key = plotter.overlay_unit_key + assert key is not None + assert key.kdim_units == (('x', 'mm'), ('y', 'mm')) + assert key.vdim_unit is None diff --git a/tests/dashboard/widgets/plot_grid_tabs_layout_test.py b/tests/dashboard/widgets/plot_grid_tabs_layout_test.py index 5bd7dbfa5..a2bd3887d 100644 --- a/tests/dashboard/widgets/plot_grid_tabs_layout_test.py +++ b/tests/dashboard/widgets/plot_grid_tabs_layout_test.py @@ -48,6 +48,7 @@ class FakePlotter: def __init__(self, cached_state=None): self._cached_state = cached_state self._presenters: list[FakePresenter] = [] + self.overlay_unit_key = None def get_cached_state(self): return self._cached_state From af4b5255dd5c18fe3a6b80290ed5483e9ad67571 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 15 Apr 2026 06:04:13 +0000 Subject: [PATCH 2/2] Simplify overlay unit validation with CanvasSpec Replace OverlayUnitKey (lossy string-based, weak cross-layer check) with CanvasSpec - a frozen dataclass holding coord_units and value_unit as native sc.Unit values. One validation function (validate_overlay_units) now serves both intra-layer and cross-layer sites with full consistency checks. ROI and annotation plotters opt out via participates_in_overlay_validation class flag, eliminating three _extract_overlay_unit_key overrides. Co-Authored-By: Claude Opus 4.6 --- src/ess/livedata/dashboard/plots.py | 165 ++++--------- .../livedata/dashboard/roi_readback_plots.py | 38 +-- .../livedata/dashboard/roi_request_plots.py | 14 +- .../dashboard/widgets/plot_grid_tabs.py | 15 +- tests/dashboard/plots_test.py | 223 +++++------------- tests/dashboard/roi_readback_plots_test.py | 18 +- .../widgets/plot_grid_tabs_layout_test.py | 2 +- 7 files changed, 133 insertions(+), 342 deletions(-) diff --git a/src/ess/livedata/dashboard/plots.py b/src/ess/livedata/dashboard/plots.py index 8462481be..eca38a32c 100644 --- a/src/ess/livedata/dashboard/plots.py +++ b/src/ess/livedata/dashboard/plots.py @@ -7,6 +7,7 @@ import weakref from collections.abc import Callable from dataclasses import dataclass +from dataclasses import field as dataclass_field from typing import Any, ClassVar, cast import holoviews as hv @@ -216,117 +217,66 @@ def _compute_time_info(data: dict[str, sc.DataArray]) -> str | None: @dataclass(frozen=True) -class OverlayUnitKey: - """Unit signature for cross-layer overlay validation. +class CanvasSpec: + """Unit metadata extracted from a scipp DataArray for overlay validation. - Captures the units of all plotting dimensions and (optionally) the value - dimension. Two keys are compatible when every shared kdim label has the - same unit string and, if both keys carry a vdim unit, those match too. + Captures coordinate units (keyed by dimension name) and the value unit. + Two overlays are compatible when they share the same coordinate dimensions + with matching units and matching value units. """ - kdim_units: tuple[tuple[str, str | None], ...] = () - vdim_unit: str | None = None + coord_units: dict[str, sc.Unit] = dataclass_field(default_factory=dict) + value_unit: sc.Unit = dataclass_field(default=sc.units.dimensionless) @classmethod - def from_data_array( - cls, - da: sc.DataArray, - *, - include_vdim: bool = True, - ) -> OverlayUnitKey: - """Build a key from a scipp DataArray using its dimension coordinates.""" - kdim_units = tuple( - sorted( - (dim, _unit_str(da.coords[dim].unit)) - for dim in da.dims - if dim in da.coords - ) + def from_data_array(cls, da: sc.DataArray) -> CanvasSpec: + """Build from a scipp DataArray using its dimension coordinates.""" + return cls( + coord_units={ + dim: da.coords[dim].unit for dim in da.dims if dim in da.coords + }, + value_unit=da.unit, ) - vdim_unit = _unit_str(da.unit) if include_vdim else None - return cls(kdim_units=kdim_units, vdim_unit=vdim_unit) - - -def _unit_str(unit: sc.Unit | None) -> str | None: - """Convert a scipp unit to its string representation, or None.""" - return str(unit) if unit is not None else None - - -def _validate_overlay_units(data: dict[ResultKey, sc.DataArray]) -> None: - """Validate that all DataArrays in an overlay share compatible units. - - Raises - ------ - ValueError - If dimension names, coordinate units, or value units do not match. - """ - if len(data) < 2: - return - - items = iter(data.items()) - ref_key, ref = next(items) - ref_source = ref_key.job_id.source_name - - for key, da in items: - source = key.job_id.source_name - if ref.dims != da.dims: - raise ValueError( - f"Cannot overlay '{ref_source}' and '{source}': " - f"dimension mismatch ({ref.dims} vs {da.dims})" - ) - for dim in ref.dims: - if dim in ref.coords and dim in da.coords: - u_ref = ref.coords[dim].unit - u_da = da.coords[dim].unit - if u_ref != u_da: - raise ValueError( - f"Cannot overlay '{ref_source}' and '{source}': " - f"unit mismatch for coordinate '{dim}' " - f"({u_ref} vs {u_da})" - ) - if ref.unit != da.unit: - raise ValueError( - f"Cannot overlay '{ref_source}' and '{source}': " - f"value unit mismatch ({ref.unit} vs {da.unit})" - ) -def validate_cross_layer_units( - keys: list[tuple[str, OverlayUnitKey]], +def validate_canvas_spec( + entries: list[tuple[str, CanvasSpec]], ) -> str | None: - """Check overlay unit keys from multiple layers for compatibility. + """Check that overlay entries share compatible units. Parameters ---------- - keys: - Pairs of (layer label, unit key) for each data-carrying layer. + entries: + Pairs of (label, units) for each data-carrying element or layer. Returns ------- : An error message if units are incompatible, or None if compatible. """ - if len(keys) < 2: + if len(entries) < 2: return None - ref_label, ref_key = keys[0] - ref_kdims = dict(ref_key.kdim_units) - - for label, key in keys[1:]: - other_kdims = dict(key.kdim_units) - for dim_label, ref_unit in ref_kdims.items(): - if dim_label in other_kdims and ref_unit != other_kdims[dim_label]: - return ( - f"Cannot overlay layers '{ref_label}' and '{label}': " - f"unit mismatch for coordinate '{dim_label}' " - f"({ref_unit} vs {other_kdims[dim_label]})" - ) - if ref_key.vdim_unit is not None and key.vdim_unit is not None: - if ref_key.vdim_unit != key.vdim_unit: + ref_label, ref = entries[0] + for label, other in entries[1:]: + if ref.coord_units.keys() != other.coord_units.keys(): + return ( + f"Cannot overlay '{ref_label}' and '{label}': " + f"dimension mismatch " + f"({set(ref.coord_units)} vs {set(other.coord_units)})" + ) + for dim, ref_unit in ref.coord_units.items(): + if ref_unit != other.coord_units[dim]: return ( - f"Cannot overlay layers '{ref_label}' and '{label}': " - f"value unit mismatch " - f"({ref_key.vdim_unit} vs {key.vdim_unit})" + f"Cannot overlay '{ref_label}' and '{label}': " + f"unit mismatch for coordinate '{dim}' " + f"({ref_unit} vs {other.coord_units[dim]})" ) + if ref.value_unit != other.value_unit: + return ( + f"Cannot overlay '{ref_label}' and '{label}': " + f"value unit mismatch ({ref.value_unit} vs {other.value_unit})" + ) return None @@ -338,6 +288,8 @@ class Plotter: This enables efficient polling-based update detection. """ + participates_in_overlay_validation: ClassVar[bool] = True + def __init__( self, *, @@ -361,7 +313,7 @@ def __init__( """ self._normalize_to_rate = normalize_to_rate self._cached_state: Any | None = None - self._overlay_unit_key: OverlayUnitKey | None = None + self.canvas_spec: CanvasSpec | None = None self._presenters: weakref.WeakSet[PresenterBase] = weakref.WeakSet() self.autoscaler_kwargs = kwargs self.autoscalers: dict[ResultKey, Autoscaler] = {} @@ -374,25 +326,6 @@ def __init__( # with responsive mode in Panel containers (upstream bug). self._sizing_opts: dict[str, Any] = {'responsive': True} - @property - def overlay_unit_key(self) -> OverlayUnitKey | None: - """Unit signature from the last compute(), for cross-layer validation.""" - return self._overlay_unit_key - - def _extract_overlay_unit_key( - self, data: dict[ResultKey, sc.DataArray] - ) -> OverlayUnitKey | None: - """Extract an overlay unit key from input data. - - Default implementation uses dimension coordinates and value unit from the - first DataArray. Subclasses with non-standard coordinate structures (e.g., - ROI plotters) should override this method. - """ - if not data: - return None - da = next(iter(data.values())) - return OverlayUnitKey.from_data_array(da) - @staticmethod def _make_tick_opts(tick_params: TickParams | None) -> dict[str, Any]: """ @@ -566,13 +499,19 @@ def compute( if self._normalize_to_rate: data = {key: _normalize_to_rate(da) for key, da in data.items()} - self._overlay_unit_key = self._extract_overlay_unit_key(data) - resolver = title_resolver or TitleResolver() plots: list[hv.Element] = [] try: - if self.layout_params.combine_mode == 'overlay': - _validate_overlay_units(data) + if self.participates_in_overlay_validation and data: + entries = [ + (key.job_id.source_name, CanvasSpec.from_data_array(da)) + for key, da in data.items() + ] + if self.layout_params.combine_mode == 'overlay': + error = validate_canvas_spec(entries) + if error is not None: + raise ValueError(error) + self.canvas_spec = entries[0][1] for data_key, da in data.items(): label = resolver.get_legend_label( data_key.job_id.source_name, data_key.output_name diff --git a/src/ess/livedata/dashboard/roi_readback_plots.py b/src/ess/livedata/dashboard/roi_readback_plots.py index 34b69ce55..1680abdaf 100644 --- a/src/ess/livedata/dashboard/roi_readback_plots.py +++ b/src/ess/livedata/dashboard/roi_readback_plots.py @@ -21,23 +21,7 @@ ) from ess.livedata.config.workflow_spec import ResultKey -from .plots import OverlayUnitKey, Plotter, _unit_str - - -def _roi_overlay_unit_key(da: sc.DataArray) -> OverlayUnitKey: - """Build an overlay unit key from ROI coordinate metadata. - - ROI DataArrays use named coords ('x', 'y') rather than dimension - coordinates, and have no meaningful value unit. - """ - kdim_units = tuple( - sorted( - (name, _unit_str(da.coords[name].unit)) - for name in ('x', 'y') - if name in da.coords - ) - ) - return OverlayUnitKey(kdim_units=kdim_units, vdim_unit=None) +from .plots import Plotter class ROIReadbackStyle(pydantic.BaseModel): @@ -87,6 +71,8 @@ class RectanglesReadbackPlotter(Plotter): rectangles with per-shape colors based on the ROI index. """ + participates_in_overlay_validation = False + def __init__(self, params: RectanglesReadbackParams) -> None: super().__init__() self._params = params @@ -97,14 +83,6 @@ def from_params(cls, params: RectanglesReadbackParams) -> RectanglesReadbackPlot """Create plotter from params.""" return cls(params) - def _extract_overlay_unit_key( - self, data: dict[ResultKey, sc.DataArray] - ) -> OverlayUnitKey | None: - if not data: - return None - da = next(iter(data.values())) - return _roi_overlay_unit_key(da) - def plot( self, data: sc.DataArray, data_key: ResultKey, *, label: str = '', **kwargs ) -> hv.Rectangles: @@ -201,6 +179,8 @@ class PolygonsReadbackPlotter(Plotter): polygons with per-shape colors based on the ROI index. """ + participates_in_overlay_validation = False + def __init__(self, params: PolygonsReadbackParams) -> None: super().__init__() self._params = params @@ -211,14 +191,6 @@ def from_params(cls, params: PolygonsReadbackParams) -> PolygonsReadbackPlotter: """Create plotter from params.""" return cls(params) - def _extract_overlay_unit_key( - self, data: dict[ResultKey, sc.DataArray] - ) -> OverlayUnitKey | None: - if not data: - return None - da = next(iter(data.values())) - return _roi_overlay_unit_key(da) - def plot( self, data: sc.DataArray, data_key: ResultKey, *, label: str = '', **kwargs ) -> hv.Polygons: diff --git a/src/ess/livedata/dashboard/roi_request_plots.py b/src/ess/livedata/dashboard/roi_request_plots.py index b45d31a9f..3f9a03902 100644 --- a/src/ess/livedata/dashboard/roi_request_plots.py +++ b/src/ess/livedata/dashboard/roi_request_plots.py @@ -48,7 +48,7 @@ get_roi_mapper, ) -from .plots import OverlayUnitKey, Plotter, PresenterBase +from .plots import Plotter, PresenterBase from .static_plots import Color, LineDash, RectanglesCoordinates if TYPE_CHECKING: @@ -533,6 +533,8 @@ class BaseROIRequestPlotter(Plotter, ABC, Generic[ROIType, ParamsType, Converter skip logic, publishing) is handled via a closure-based edit callback. """ + participates_in_overlay_validation = False + def __init__( self, params: ParamsType, @@ -626,16 +628,6 @@ def compute( else None ) - # Build overlay unit key from ROI coordinate units - kdim_units = tuple( - sorted( - (name, unit) - for name, unit in [('x', self._x_unit), ('y', self._y_unit)] - if unit is not None - ) - ) - self._overlay_unit_key = OverlayUnitKey(kdim_units=kdim_units, vdim_unit=None) - # Forward data (presenter may use in future) self._set_cached_state(data) return data diff --git a/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py b/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py index 1304b5824..49e3adbb6 100644 --- a/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py +++ b/src/ess/livedata/dashboard/widgets/plot_grid_tabs.py @@ -35,7 +35,7 @@ SubscriptionId, ) from ..plot_params import PlotAspectType, StretchMode -from ..plots import OverlayUnitKey, validate_cross_layer_units +from ..plots import CanvasSpec, validate_canvas_spec from ..save_filename import build_save_filename_from_cell, make_save_filename_hook from ..session_layer import SessionLayer from ..session_updater import SessionUpdater @@ -941,7 +941,7 @@ def _get_session_composed_plot( Composed plot from session DMaps/elements, or None if none available. """ plots = [] - unit_keys: list[tuple[str, OverlayUnitKey]] = [] + overlay_entries: list[tuple[str, CanvasSpec]] = [] has_layout = False for layer in cell.layers: layer_id = layer.layer_id @@ -969,17 +969,18 @@ def _get_session_composed_plot( has_layout = True plots.append(dmap) - # Collect unit key for cross-layer validation + # Collect overlay units for cross-layer validation + # (only data-carrying plotters set canvas_spec) if state is not None and state.plotter is not None: - key = state.plotter.overlay_unit_key - if key is not None: - unit_keys.append((layer.config.plot_name, key)) + units = state.plotter.canvas_spec + if units is not None: + overlay_entries.append((layer.config.plot_name, units)) if not plots: return None # Validate cross-layer unit compatibility before composing - error = validate_cross_layer_units(unit_keys) + error = validate_canvas_spec(overlay_entries) if error is not None: return hv.Text(0.5, 0.5, f"Error: {error}").opts( text_align='center', text_baseline='middle' diff --git a/tests/dashboard/plots_test.py b/tests/dashboard/plots_test.py index a0e927d77..435dbc31e 100644 --- a/tests/dashboard/plots_test.py +++ b/tests/dashboard/plots_test.py @@ -2120,62 +2120,21 @@ def _make_1d_da( ) -class TestValidateOverlayUnits: - """Tests for Site 1: intra-layer overlay unit validation.""" - - def test_single_entry_passes(self): - key = _make_result_key('src_a') - da = _make_1d_da() - plots._validate_overlay_units({key: da}) - - def test_matching_units_pass(self): - data = { - _make_result_key('src_a'): _make_1d_da(), - _make_result_key('src_b'): _make_1d_da(), - } - plots._validate_overlay_units(data) - - def test_mismatched_coord_units_raise(self): - data = { - _make_result_key('src_a'): _make_1d_da(coord_unit='ms'), - _make_result_key('src_b'): _make_1d_da(coord_unit='us'), - } - with pytest.raises(ValueError, match=r"coordinate 'tof'.*ms.*µs"): - plots._validate_overlay_units(data) - - def test_mismatched_value_units_raise(self): - data = { - _make_result_key('src_a'): _make_1d_da(value_unit='counts'), - _make_result_key('src_b'): _make_1d_da(value_unit='counts/us'), - } - with pytest.raises(ValueError, match="value unit mismatch"): - plots._validate_overlay_units(data) - - def test_mismatched_dims_raise(self): - data = { - _make_result_key('src_a'): _make_1d_da(dim='tof'), - _make_result_key('src_b'): _make_1d_da(dim='wavelength'), - } - with pytest.raises(ValueError, match="dimension mismatch"): - plots._validate_overlay_units(data) - - def test_error_message_includes_source_names(self): - data = { - _make_result_key('detector_1'): _make_1d_da(coord_unit='ms'), - _make_result_key('detector_2'): _make_1d_da(coord_unit='us'), - } - with pytest.raises(ValueError, match=r"detector_1.*detector_2"): - plots._validate_overlay_units(data) +def _units(coord_units: dict[str, str], value_unit: str = 'counts') -> plots.CanvasSpec: + return plots.CanvasSpec( + coord_units={k: sc.Unit(v) for k, v in coord_units.items()}, + value_unit=sc.Unit(value_unit), + ) -class TestOverlayUnitKey: - """Tests for OverlayUnitKey construction and cross-layer validation.""" +class TestCanvasSpec: + """Tests for CanvasSpec construction.""" def test_from_data_array_1d(self): da = _make_1d_da(dim='tof', coord_unit='ms', value_unit='counts') - key = plots.OverlayUnitKey.from_data_array(da) - assert key.kdim_units == (('tof', 'ms'),) - assert key.vdim_unit == 'counts' + units = plots.CanvasSpec.from_data_array(da) + assert units.coord_units == {'tof': sc.Unit('ms')} + assert units.value_unit == sc.Unit('counts') def test_from_data_array_2d(self): da = sc.DataArray( @@ -2185,134 +2144,73 @@ def test_from_data_array_2d(self): 'y': sc.arange('y', 4, dtype='float64', unit='m'), }, ) - key = plots.OverlayUnitKey.from_data_array(da) - assert key.kdim_units == (('x', 'm'), ('y', 'm')) - assert key.vdim_unit == 'counts' + units = plots.CanvasSpec.from_data_array(da) + assert units.coord_units == {'x': sc.Unit('m'), 'y': sc.Unit('m')} + assert units.value_unit == sc.Unit('counts') - def test_from_data_array_no_vdim(self): - da = _make_1d_da() - key = plots.OverlayUnitKey.from_data_array(da, include_vdim=False) - assert key.vdim_unit is None - def test_cross_layer_matching_units_pass(self): - keys = [ - ( - 'layer_a', - plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), - ), - ( - 'layer_b', - plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), - ), +class TestValidateCanvasSpec: + """Tests for validate_canvas_spec (used at both sites).""" + + def test_single_entry_passes(self): + assert plots.validate_canvas_spec([('a', _units({'tof': 'ms'}))]) is None + + def test_matching_units_pass(self): + entries = [ + ('a', _units({'tof': 'ms'})), + ('b', _units({'tof': 'ms'})), ] - assert plots.validate_cross_layer_units(keys) is None + assert plots.validate_canvas_spec(entries) is None - def test_cross_layer_mismatched_kdim_units(self): - keys = [ - ( - 'layer_a', - plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), - ), - ( - 'layer_b', - plots.OverlayUnitKey(kdim_units=(('tof', 'us'),), vdim_unit='counts'), - ), + def test_mismatched_coord_units(self): + entries = [ + ('src_a', _units({'tof': 'ms'})), + ('src_b', _units({'tof': 'us'})), ] - error = plots.validate_cross_layer_units(keys) + error = plots.validate_canvas_spec(entries) assert error is not None assert 'tof' in error assert 'ms' in error - assert 'us' in error - def test_cross_layer_mismatched_vdim_units(self): - keys = [ - ( - 'layer_a', - plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), - ), - ( - 'layer_b', - plots.OverlayUnitKey( - kdim_units=(('tof', 'ms'),), vdim_unit='counts/us' - ), - ), + def test_mismatched_value_units(self): + entries = [ + ('a', _units({'tof': 'ms'}, value_unit='counts')), + ('b', _units({'tof': 'ms'}, value_unit='counts/us')), ] - error = plots.validate_cross_layer_units(keys) + error = plots.validate_canvas_spec(entries) assert error is not None assert 'value unit mismatch' in error - def test_cross_layer_vdim_none_skips_check(self): - keys = [ - ( - 'image', - plots.OverlayUnitKey( - kdim_units=(('x', 'm'), ('y', 'm')), vdim_unit='counts' - ), - ), - ( - 'roi', - plots.OverlayUnitKey( - kdim_units=(('x', 'm'), ('y', 'm')), vdim_unit=None - ), - ), - ] - assert plots.validate_cross_layer_units(keys) is None - - def test_cross_layer_vdim_none_still_checks_kdims(self): - keys = [ - ( - 'image', - plots.OverlayUnitKey(kdim_units=(('x', 'm'),), vdim_unit='counts'), - ), - ('roi', plots.OverlayUnitKey(kdim_units=(('x', 'mm'),), vdim_unit=None)), + def test_mismatched_dims(self): + entries = [ + ('a', _units({'tof': 'ms'})), + ('b', _units({'wavelength': 'angstrom'})), ] - error = plots.validate_cross_layer_units(keys) + error = plots.validate_canvas_spec(entries) assert error is not None - assert 'x' in error + assert 'dimension mismatch' in error - def test_cross_layer_single_key_passes(self): - keys = [ - ( - 'only_layer', - plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), - ), + def test_error_includes_labels(self): + entries = [ + ('detector_1', _units({'tof': 'ms'})), + ('detector_2', _units({'tof': 'us'})), ] - assert plots.validate_cross_layer_units(keys) is None - - def test_cross_layer_non_overlapping_kdims_pass(self): - keys = [ - ( - 'curve', - plots.OverlayUnitKey(kdim_units=(('tof', 'ms'),), vdim_unit='counts'), - ), - ( - 'roi', - plots.OverlayUnitKey( - kdim_units=(('x', 'm'), ('y', 'm')), vdim_unit=None - ), - ), - ] - assert plots.validate_cross_layer_units(keys) is None + error = plots.validate_canvas_spec(entries) + assert 'detector_1' in error + assert 'detector_2' in error -class TestPlotterOverlayUnitKey: - """Integration tests: plotter stores overlay unit key after compute().""" +class TestPlotterCanvasSpec: + """Integration: plotter stores canvas_spec after compute().""" - def test_line_plotter_stores_key(self): - params = PlotParams1d() - plotter = plots.LinePlotter.from_params(params) - key = _make_result_key('src') + def test_line_plotter_stores_units(self): + plotter = plots.LinePlotter.from_params(PlotParams1d()) da = _make_1d_da(dim='tof', coord_unit='ms', value_unit='counts') - plotter.compute({key: da}) - unit_key = plotter.overlay_unit_key - assert unit_key is not None - assert unit_key.kdim_units == (('tof', 'ms'),) - assert unit_key.vdim_unit == 'counts' + plotter.compute({_make_result_key('src'): da}) + assert plotter.canvas_spec == _units({'tof': 'ms'}) - def test_image_plotter_stores_key(self): - params = PlotParams2d() - plotter = plots.ImagePlotter.from_params(params) - key = _make_result_key('src') + def test_image_plotter_stores_units(self): + plotter = plots.ImagePlotter.from_params(PlotParams2d()) da = sc.DataArray( sc.ones(sizes={'y': 4, 'x': 5}, unit='counts'), coords={ @@ -2320,22 +2218,17 @@ def test_image_plotter_stores_key(self): 'y': sc.arange('y', 4, dtype='float64', unit='m'), }, ) - plotter.compute({key: da}) - unit_key = plotter.overlay_unit_key - assert unit_key is not None - assert unit_key.kdim_units == (('x', 'm'), ('y', 'm')) - assert unit_key.vdim_unit == 'counts' + plotter.compute({_make_result_key('src'): da}) + assert plotter.canvas_spec == _units({'x': 'm', 'y': 'm'}) def test_compute_with_mismatched_units_shows_error(self): - params = PlotParams1d() - plotter = plots.LinePlotter.from_params(params) + plotter = plots.LinePlotter.from_params(PlotParams1d()) data = { _make_result_key('src_a'): _make_1d_da(coord_unit='ms'), _make_result_key('src_b'): _make_1d_da(coord_unit='us'), } plotter.compute(data) result = plotter.get_cached_state() - # Validation error is caught and wrapped in an Overlay containing hv.Text assert isinstance(result, hv.Overlay) (element,) = result.values() assert isinstance(element, hv.Text) diff --git a/tests/dashboard/roi_readback_plots_test.py b/tests/dashboard/roi_readback_plots_test.py index 9460f3578..b685f3273 100644 --- a/tests/dashboard/roi_readback_plots_test.py +++ b/tests/dashboard/roi_readback_plots_test.py @@ -424,10 +424,10 @@ def test_lines_plotter_compatible_with_regular_1d_data(self): assert 'lines' in compatible -class TestROIReadbackOverlayUnitKey: - """Tests for ROI readback plotter overlay unit key extraction.""" +class TestROIReadbackOverlayValidation: + """ROI readback plotters are exempt from overlay unit validation.""" - def test_rectangles_readback_stores_kdim_units(self, result_key): + def test_rectangles_readback_does_not_set_canvas_spec(self, result_key): roi = RectangleROI( x=Interval(min=0.0, max=1.0, unit='m'), y=Interval(min=0.0, max=1.0, unit='m'), @@ -435,12 +435,9 @@ def test_rectangles_readback_stores_kdim_units(self, result_key): da = RectangleROI.to_concatenated_data_array({0: roi}) plotter = RectanglesReadbackPlotter(RectanglesReadbackParams()) plotter.compute({result_key: da}) - key = plotter.overlay_unit_key - assert key is not None - assert key.kdim_units == (('x', 'm'), ('y', 'm')) - assert key.vdim_unit is None + assert plotter.canvas_spec is None - def test_polygons_readback_stores_kdim_units(self, result_key): + def test_polygons_readback_does_not_set_canvas_spec(self, result_key): roi = PolygonROI( x=[0.0, 1.0, 1.0, 0.0], y=[0.0, 0.0, 1.0, 1.0], @@ -450,7 +447,4 @@ def test_polygons_readback_stores_kdim_units(self, result_key): da = PolygonROI.to_concatenated_data_array({0: roi}) plotter = PolygonsReadbackPlotter(PolygonsReadbackParams()) plotter.compute({result_key: da}) - key = plotter.overlay_unit_key - assert key is not None - assert key.kdim_units == (('x', 'mm'), ('y', 'mm')) - assert key.vdim_unit is None + assert plotter.canvas_spec is None diff --git a/tests/dashboard/widgets/plot_grid_tabs_layout_test.py b/tests/dashboard/widgets/plot_grid_tabs_layout_test.py index a2bd3887d..952cf7b06 100644 --- a/tests/dashboard/widgets/plot_grid_tabs_layout_test.py +++ b/tests/dashboard/widgets/plot_grid_tabs_layout_test.py @@ -48,7 +48,7 @@ class FakePlotter: def __init__(self, cached_state=None): self._cached_state = cached_state self._presenters: list[FakePresenter] = [] - self.overlay_unit_key = None + self.canvas_spec = None def get_cached_state(self): return self._cached_state