Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions src/ess/livedata/dashboard/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import weakref
from collections.abc import Callable
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from typing import Any, ClassVar, cast

import holoviews as hv
Expand Down Expand Up @@ -215,6 +216,70 @@ def _compute_time_info(data: dict[str, sc.DataArray]) -> str | None:
return f'{end_str} (Lag: {lag_s:.1f}s)'


@dataclass(frozen=True)
class CanvasSpec:
"""Unit metadata extracted from a scipp DataArray for overlay validation.

Captures coordinate units (keyed by dimension name) and the value unit.
Two overlays are compatible when they share the same coordinate dimensions
with matching units and matching value units.
"""

coord_units: dict[str, sc.Unit] = dataclass_field(default_factory=dict)
value_unit: sc.Unit = dataclass_field(default=sc.units.dimensionless)

@classmethod
def from_data_array(cls, da: sc.DataArray) -> CanvasSpec:
"""Build from a scipp DataArray using its dimension coordinates."""
return cls(
coord_units={
dim: da.coords[dim].unit for dim in da.dims if dim in da.coords
},
value_unit=da.unit,
)


def validate_canvas_spec(
entries: list[tuple[str, CanvasSpec]],
) -> str | None:
"""Check that overlay entries share compatible units.

Parameters
----------
entries:
Pairs of (label, units) for each data-carrying element or layer.

Returns
-------
:
An error message if units are incompatible, or None if compatible.
"""
if len(entries) < 2:
return None

ref_label, ref = entries[0]
for label, other in entries[1:]:
if ref.coord_units.keys() != other.coord_units.keys():
return (
f"Cannot overlay '{ref_label}' and '{label}': "
f"dimension mismatch "
f"({set(ref.coord_units)} vs {set(other.coord_units)})"
)
for dim, ref_unit in ref.coord_units.items():
if ref_unit != other.coord_units[dim]:
return (
f"Cannot overlay '{ref_label}' and '{label}': "
f"unit mismatch for coordinate '{dim}' "
f"({ref_unit} vs {other.coord_units[dim]})"
)
if ref.value_unit != other.value_unit:
return (
f"Cannot overlay '{ref_label}' and '{label}': "
f"value unit mismatch ({ref.value_unit} vs {other.value_unit})"
)
return None


class Plotter:
"""
Base class for plots that support autoscaling.
Expand All @@ -223,6 +288,8 @@ class Plotter:
This enables efficient polling-based update detection.
"""

participates_in_overlay_validation: ClassVar[bool] = True

def __init__(
self,
*,
Expand All @@ -246,6 +313,7 @@ def __init__(
"""
self._normalize_to_rate = normalize_to_rate
self._cached_state: Any | None = None
self.canvas_spec: CanvasSpec | None = None
self._presenters: weakref.WeakSet[PresenterBase] = weakref.WeakSet()
self.autoscaler_kwargs = kwargs
self.autoscalers: dict[ResultKey, Autoscaler] = {}
Expand Down Expand Up @@ -434,6 +502,16 @@ def compute(
resolver = title_resolver or TitleResolver()
plots: list[hv.Element] = []
try:
if self.participates_in_overlay_validation and data:
entries = [
(key.job_id.source_name, CanvasSpec.from_data_array(da))
for key, da in data.items()
]
if self.layout_params.combine_mode == 'overlay':
error = validate_canvas_spec(entries)
if error is not None:
raise ValueError(error)
self.canvas_spec = entries[0][1]
for data_key, da in data.items():
label = resolver.get_legend_label(
data_key.job_id.source_name, data_key.output_name
Expand Down
4 changes: 4 additions & 0 deletions src/ess/livedata/dashboard/roi_readback_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class RectanglesReadbackPlotter(Plotter):
rectangles with per-shape colors based on the ROI index.
"""

participates_in_overlay_validation = False

def __init__(self, params: RectanglesReadbackParams) -> None:
super().__init__()
self._params = params
Expand Down Expand Up @@ -177,6 +179,8 @@ class PolygonsReadbackPlotter(Plotter):
polygons with per-shape colors based on the ROI index.
"""

participates_in_overlay_validation = False

def __init__(self, params: PolygonsReadbackParams) -> None:
super().__init__()
self._params = params
Expand Down
2 changes: 2 additions & 0 deletions src/ess/livedata/dashboard/roi_request_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ class BaseROIRequestPlotter(Plotter, ABC, Generic[ROIType, ParamsType, Converter
skip logic, publishing) is handled via a closure-based edit callback.
"""

participates_in_overlay_validation = False

def __init__(
self,
params: ParamsType,
Expand Down
16 changes: 16 additions & 0 deletions src/ess/livedata/dashboard/widgets/plot_grid_tabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
SubscriptionId,
)
from ..plot_params import PlotAspectType, StretchMode
from ..plots import CanvasSpec, validate_canvas_spec
from ..save_filename import build_save_filename_from_cell, make_save_filename_hook
from ..session_layer import SessionLayer
from ..session_updater import SessionUpdater
Expand Down Expand Up @@ -940,6 +941,7 @@ def _get_session_composed_plot(
Composed plot from session DMaps/elements, or None if none available.
"""
plots = []
overlay_entries: list[tuple[str, CanvasSpec]] = []
has_layout = False
for layer in cell.layers:
layer_id = layer.layer_id
Expand Down Expand Up @@ -967,9 +969,23 @@ def _get_session_composed_plot(
has_layout = True
plots.append(dmap)

# Collect overlay units for cross-layer validation
# (only data-carrying plotters set canvas_spec)
if state is not None and state.plotter is not None:
units = state.plotter.canvas_spec
if units is not None:
overlay_entries.append((layer.config.plot_name, units))

if not plots:
return None

# Validate cross-layer unit compatibility before composing
error = validate_canvas_spec(overlay_entries)
if error is not None:
return hv.Text(0.5, 0.5, f"Error: {error}").opts(
text_align='center', text_baseline='middle'
)

result: hv.DynamicMap | hv.Element
if len(plots) == 1:
result = plots[0]
Expand Down
132 changes: 132 additions & 0 deletions tests/dashboard/plots_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2100,3 +2100,135 @@ def test_bars_plotter_normalizes_when_enabled(self, data_key):
bars = next(iter(result.values()))
assert isinstance(bars, hv.Bars)
assert bars.vdims[0].unit == 'counts/s'


def _make_result_key(source_name: str) -> ResultKey:
return ResultKey(
workflow_id=WorkflowId(instrument='test', namespace='ns', name='wf', version=1),
job_id=JobId(source_name=source_name, job_number=uuid.uuid4()),
output_name='result',
)


def _make_1d_da(
*, dim: str = 'tof', coord_unit: str = 'ms', value_unit: str = 'counts'
) -> sc.DataArray:
coord = sc.arange(dim, 5, dtype='float64', unit=coord_unit)
return sc.DataArray(
sc.ones(sizes={dim: 5}, unit=value_unit),
coords={dim: coord},
)


def _units(coord_units: dict[str, str], value_unit: str = 'counts') -> plots.CanvasSpec:
return plots.CanvasSpec(
coord_units={k: sc.Unit(v) for k, v in coord_units.items()},
value_unit=sc.Unit(value_unit),
)


class TestCanvasSpec:
"""Tests for CanvasSpec construction."""

def test_from_data_array_1d(self):
da = _make_1d_da(dim='tof', coord_unit='ms', value_unit='counts')
units = plots.CanvasSpec.from_data_array(da)
assert units.coord_units == {'tof': sc.Unit('ms')}
assert units.value_unit == sc.Unit('counts')

def test_from_data_array_2d(self):
da = sc.DataArray(
sc.ones(sizes={'y': 4, 'x': 5}, unit='counts'),
coords={
'x': sc.arange('x', 5, dtype='float64', unit='m'),
'y': sc.arange('y', 4, dtype='float64', unit='m'),
},
)
units = plots.CanvasSpec.from_data_array(da)
assert units.coord_units == {'x': sc.Unit('m'), 'y': sc.Unit('m')}
assert units.value_unit == sc.Unit('counts')


class TestValidateCanvasSpec:
"""Tests for validate_canvas_spec (used at both sites)."""

def test_single_entry_passes(self):
assert plots.validate_canvas_spec([('a', _units({'tof': 'ms'}))]) is None

def test_matching_units_pass(self):
entries = [
('a', _units({'tof': 'ms'})),
('b', _units({'tof': 'ms'})),
]
assert plots.validate_canvas_spec(entries) is None

def test_mismatched_coord_units(self):
entries = [
('src_a', _units({'tof': 'ms'})),
('src_b', _units({'tof': 'us'})),
]
error = plots.validate_canvas_spec(entries)
assert error is not None
assert 'tof' in error
assert 'ms' in error

def test_mismatched_value_units(self):
entries = [
('a', _units({'tof': 'ms'}, value_unit='counts')),
('b', _units({'tof': 'ms'}, value_unit='counts/us')),
]
error = plots.validate_canvas_spec(entries)
assert error is not None
assert 'value unit mismatch' in error

def test_mismatched_dims(self):
entries = [
('a', _units({'tof': 'ms'})),
('b', _units({'wavelength': 'angstrom'})),
]
error = plots.validate_canvas_spec(entries)
assert error is not None
assert 'dimension mismatch' in error

def test_error_includes_labels(self):
entries = [
('detector_1', _units({'tof': 'ms'})),
('detector_2', _units({'tof': 'us'})),
]
error = plots.validate_canvas_spec(entries)
assert 'detector_1' in error
assert 'detector_2' in error


class TestPlotterCanvasSpec:
"""Integration: plotter stores canvas_spec after compute()."""

def test_line_plotter_stores_units(self):
plotter = plots.LinePlotter.from_params(PlotParams1d())
da = _make_1d_da(dim='tof', coord_unit='ms', value_unit='counts')
plotter.compute({_make_result_key('src'): da})
assert plotter.canvas_spec == _units({'tof': 'ms'})

def test_image_plotter_stores_units(self):
plotter = plots.ImagePlotter.from_params(PlotParams2d())
da = sc.DataArray(
sc.ones(sizes={'y': 4, 'x': 5}, unit='counts'),
coords={
'x': sc.arange('x', 5, dtype='float64', unit='m'),
'y': sc.arange('y', 4, dtype='float64', unit='m'),
},
)
plotter.compute({_make_result_key('src'): da})
assert plotter.canvas_spec == _units({'x': 'm', 'y': 'm'})

def test_compute_with_mismatched_units_shows_error(self):
plotter = plots.LinePlotter.from_params(PlotParams1d())
data = {
_make_result_key('src_a'): _make_1d_da(coord_unit='ms'),
_make_result_key('src_b'): _make_1d_da(coord_unit='us'),
}
plotter.compute(data)
result = plotter.get_cached_state()
assert isinstance(result, hv.Overlay)
(element,) = result.values()
assert isinstance(element, hv.Text)
26 changes: 26 additions & 0 deletions tests/dashboard/roi_readback_plots_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,29 @@ def test_lines_plotter_compatible_with_regular_1d_data(self):
)
compatible = plotter_registry.get_compatible_plotters({'key': data})
assert 'lines' in compatible


class TestROIReadbackOverlayValidation:
"""ROI readback plotters are exempt from overlay unit validation."""

def test_rectangles_readback_does_not_set_canvas_spec(self, result_key):
roi = RectangleROI(
x=Interval(min=0.0, max=1.0, unit='m'),
y=Interval(min=0.0, max=1.0, unit='m'),
)
da = RectangleROI.to_concatenated_data_array({0: roi})
plotter = RectanglesReadbackPlotter(RectanglesReadbackParams())
plotter.compute({result_key: da})
assert plotter.canvas_spec is None

def test_polygons_readback_does_not_set_canvas_spec(self, result_key):
roi = PolygonROI(
x=[0.0, 1.0, 1.0, 0.0],
y=[0.0, 0.0, 1.0, 1.0],
x_unit='mm',
y_unit='mm',
)
da = PolygonROI.to_concatenated_data_array({0: roi})
plotter = PolygonsReadbackPlotter(PolygonsReadbackParams())
plotter.compute({result_key: da})
assert plotter.canvas_spec is None
1 change: 1 addition & 0 deletions tests/dashboard/widgets/plot_grid_tabs_layout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class FakePlotter:
def __init__(self, cached_state=None):
self._cached_state = cached_state
self._presenters: list[FakePresenter] = []
self.canvas_spec = None

def get_cached_state(self):
return self._cached_state
Expand Down
Loading