diff --git a/examples/pygfx_backend/01_basic_curves.py b/examples/pygfx_backend/01_basic_curves.py new file mode 100644 index 0000000000..ae50c5054c --- /dev/null +++ b/examples/pygfx_backend/01_basic_curves.py @@ -0,0 +1,51 @@ +"""Basic curve plotting with pygfx backend. + +Demonstrates: multiple curves, colors, line widths, symbols, fill, legend. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot import Plot1D + + +def main(): + app = qt.QApplication([]) + + plot = Plot1D(backend="pygfx") + plot.setWindowTitle("pygfx - Basic Curves") + plot.setGraphTitle("Trigonometric Functions") + plot.setGraphXLabel("X") + plot.setGraphYLabel("Y") + + x = numpy.linspace(0, 4 * numpy.pi, 500) + + # Solid line + plot.addCurve(x, numpy.sin(x), legend="sin(x)", color="blue", linewidth=2) + # Dashed line with symbols + plot.addCurve( + x[::20], + numpy.cos(x[::20]), + legend="cos(x)", + color="red", + linewidth=1.5, + linestyle="--", + symbol="o", + ) + # Filled curve + plot.addCurve( + x, + 0.5 * numpy.sin(2 * x), + legend="0.5*sin(2x)", + color="green", + linewidth=1, + fill=True, + ) + + plot.setActiveCurveHandling(False) + plot.resetZoom() + plot.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/02_image_display.py b/examples/pygfx_backend/02_image_display.py new file mode 100644 index 0000000000..f60668c292 --- /dev/null +++ b/examples/pygfx_backend/02_image_display.py @@ -0,0 +1,52 @@ +"""Image display with pygfx backend. + +Demonstrates: 2D image with colormap, RGBA image, origin/scale, colorbar. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot import Plot2D + + +def main(): + app = qt.QApplication([]) + + # --- Plot2D with colormap image --- + plot = Plot2D(backend="pygfx") + plot.setWindowTitle("pygfx - Image Display") + plot.setGraphTitle("2D Gaussian + Noise") + plot.setGraphXLabel("X") + plot.setGraphYLabel("Y") + + # Generate a 2D Gaussian + size = 256 + x = numpy.linspace(-3, 3, size) + y = numpy.linspace(-3, 3, size) + xx, yy = numpy.meshgrid(x, y) + image = numpy.exp(-(xx**2 + yy**2)) + 0.1 * numpy.random.random((size, size)) + + plot.getDefaultColormap().setName("viridis") + plot.addImage(image, origin=(-3, -3), scale=(6 / size, 6 / size)) + plot.setKeepDataAspectRatio(True) + plot.resetZoom() + plot.show() + + # --- RGBA image window --- + plot2 = Plot2D(backend="pygfx") + plot2.setWindowTitle("pygfx - RGBA Image") + plot2.setGraphTitle("RGBA Gradient") + + rgba = numpy.zeros((200, 300, 4), dtype=numpy.uint8) + rgba[:, :, 0] = numpy.linspace(0, 255, 300)[numpy.newaxis, :] # R gradient + rgba[:, :, 1] = numpy.linspace(0, 255, 200)[:, numpy.newaxis] # G gradient + rgba[:, :, 2] = 128 + rgba[:, :, 3] = 255 + plot2.addImage(rgba) + plot2.resetZoom() + plot2.show() + + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/03_scatter_plot.py b/examples/pygfx_backend/03_scatter_plot.py new file mode 100644 index 0000000000..86886d3bb3 --- /dev/null +++ b/examples/pygfx_backend/03_scatter_plot.py @@ -0,0 +1,35 @@ +"""Scatter plot with pygfx backend. + +Demonstrates: scatter points with colormap values, symbol sizes. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot import Plot1D + + +def main(): + app = qt.QApplication([]) + + plot = Plot1D(backend="pygfx") + plot.setWindowTitle("pygfx - Scatter Plot") + plot.setGraphTitle("Random Scatter with Colormap") + plot.setGraphXLabel("X") + plot.setGraphYLabel("Y") + plot.getDefaultColormap().setName("plasma") + + numpy.random.seed(42) + n = 200 + x = numpy.random.randn(n) + y = numpy.random.randn(n) + value = numpy.sqrt(x**2 + y**2) # distance from origin + + plot.addScatter(x, y, value, legend="distance", symbol="o") + plot.setKeepDataAspectRatio(True) + plot.resetZoom() + plot.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/04_line_styles.py b/examples/pygfx_backend/04_line_styles.py new file mode 100644 index 0000000000..1b9055117a --- /dev/null +++ b/examples/pygfx_backend/04_line_styles.py @@ -0,0 +1,67 @@ +"""Line styles and symbols with pygfx backend. + +Demonstrates: all line styles (solid, dashed, dash-dot, dotted), +various symbols, line widths, gap colors. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot import PlotWidget + + +def main(): + app = qt.QApplication([]) + + plot = PlotWidget(backend="pygfx") + plot.setWindowTitle("pygfx - Line Styles & Symbols") + plot.setGraphTitle("Line Styles and Symbols") + plot.setGraphXLabel("X") + plot.setGraphYLabel("Y") + + x = numpy.linspace(0, 10, 100) + + # Line styles + styles = [ + ("-", "solid"), + ("--", "dashed"), + ("-.", "dash-dot"), + (":", "dotted"), + ] + for i, (style, name) in enumerate(styles): + y = numpy.sin(x) + i * 2.5 + plot.addCurve(x, y, legend=name, linestyle=style, linewidth=2, symbol="") + + # Gap color example + y = numpy.sin(x) + 10 + plot.addCurve( + x, + y, + legend="dashed+gapcolor", + linestyle="--", + linewidth=2, + symbol="", + color="blue", + ) + + # Symbols (only those supported by silx SymbolMixIn) + symbols = ["o", ".", "+", "x", "d", "s", ",", "|", "_"] + x_sym = numpy.linspace(0, 10, 30) + for i, sym in enumerate(symbols): + y_sym = numpy.cos(x_sym) + 15 + i * 1.5 + plot.addCurve( + x_sym, + y_sym, + legend=f"sym '{sym}'", + symbol=sym, + linestyle=" ", + color=f"C{i % 10}", + ) + + plot.setActiveCurveHandling(False) + plot.resetZoom() + plot.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/05_markers_shapes.py b/examples/pygfx_backend/05_markers_shapes.py new file mode 100644 index 0000000000..9890de47f3 --- /dev/null +++ b/examples/pygfx_backend/05_markers_shapes.py @@ -0,0 +1,57 @@ +"""Markers and shapes with pygfx backend. + +Demonstrates: point markers with text, x/y markers, shapes (rectangle, polyline). +""" + +import numpy +from silx.gui import qt +from silx.gui.plot import PlotWidget +from silx.gui.plot import items + + +def main(): + app = qt.QApplication([]) + + plot = PlotWidget(backend="pygfx") + plot.setWindowTitle("pygfx - Markers & Shapes") + plot.setGraphTitle("Markers, Text Labels and Shapes") + + # Background image for visual reference + size = 100 + xx, yy = numpy.meshgrid(numpy.linspace(0, 1, size), numpy.linspace(0, 1, size)) + image = numpy.sin(10 * xx) * numpy.cos(10 * yy) + plot.addImage(image, origin=(0, 0), scale=(100 / size, 100 / size)) + plot.getDefaultColormap().setName("gray") + + # Point markers with text + plot.addMarker(20, 80, legend="marker1", text="Point A", color="red", symbol="o") + plot.addMarker(50, 60, legend="marker2", text="Point B", color="blue", symbol="d") + plot.addMarker(80, 80, legend="marker3", text="Point C", color="green", symbol="s") + + # Horizontal and vertical markers + plot.addXMarker(30, legend="x_marker", text="X=30", color="yellow") + plot.addYMarker(40, legend="y_marker", text="Y=40", color="cyan") + + # Rectangle shape + rect = items.Shape("rectangle") + rect.setPoints(numpy.array([(10, 10), (45, 45)])) + rect.setColor("red") + rect.setLineWidth(2) + plot.addItem(rect) + + # Polyline shape + poly = items.Shape("polylines") + poly.setPoints(numpy.array([(55, 10), (70, 40), (85, 15), (95, 35)])) + poly.setColor("green") + poly.setLineWidth(2) + plot.addItem(poly) + + plot.setGraphXLimits(-5, 105) + plot.setGraphYLimits(-5, 105) + plot.resetZoom() + plot.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/06_error_bars.py b/examples/pygfx_backend/06_error_bars.py new file mode 100644 index 0000000000..aa3591457a --- /dev/null +++ b/examples/pygfx_backend/06_error_bars.py @@ -0,0 +1,63 @@ +"""Error bars with pygfx backend. + +Demonstrates: curves with x and y error bars. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot import Plot1D + + +def main(): + app = qt.QApplication([]) + + plot = Plot1D(backend="pygfx") + plot.setWindowTitle("pygfx - Error Bars") + plot.setGraphTitle("Curves with Error Bars") + plot.setGraphXLabel("X") + plot.setGraphYLabel("Y") + + x = numpy.linspace(0, 10, 30) + + # Symmetric Y errors + y1 = numpy.sin(x) + yerr1 = 0.1 + 0.1 * numpy.abs(numpy.sin(x)) + plot.addCurve( + x, + y1, + legend="sym Y error", + color="blue", + symbol="o", + yerror=yerr1, + linewidth=1.5, + ) + + # Asymmetric Y errors + y2 = numpy.cos(x) + 3 + yerr_low = 0.2 * numpy.ones_like(x) + yerr_high = 0.5 * numpy.abs(numpy.cos(x)) + plot.addCurve( + x, + y2, + legend="asym Y error", + color="red", + symbol="s", + yerror=numpy.array([yerr_low, yerr_high]), + linewidth=1.5, + ) + + # X errors + y3 = 0.5 * x - 1.5 + xerr = 0.3 * numpy.ones_like(x) + plot.addCurve( + x, y3, legend="X error", color="green", symbol="d", xerror=xerr, linewidth=1.5 + ) + + plot.setActiveCurveHandling(False) + plot.resetZoom() + plot.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/07_log_axes.py b/examples/pygfx_backend/07_log_axes.py new file mode 100644 index 0000000000..6970cb904f --- /dev/null +++ b/examples/pygfx_backend/07_log_axes.py @@ -0,0 +1,40 @@ +"""Logarithmic axes with pygfx backend. + +Demonstrates: log scale on X and Y axes, grid. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot import Plot1D + + +def main(): + app = qt.QApplication([]) + + plot = Plot1D(backend="pygfx") + plot.setWindowTitle("pygfx - Log Axes") + plot.setGraphTitle("Logarithmic Scale") + plot.setGraphXLabel("Frequency (Hz)") + plot.setGraphYLabel("Amplitude") + + x = numpy.logspace(0, 5, 200) + + # Power-law decay + y1 = 1e6 * x**-1.5 + plot.addCurve(x, y1, legend="f^-1.5", color="blue", linewidth=2) + + # Exponential decay + y2 = 1e4 * numpy.exp(-x / 1e4) + plot.addCurve(x, y2, legend="exp decay", color="red", linewidth=2) + + plot.getXAxis().setScale("log") + plot.getYAxis().setScale("log") + plot.setGraphGrid("both") + plot.setActiveCurveHandling(False) + plot.resetZoom() + plot.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/08_compare_three_backends.py b/examples/pygfx_backend/08_compare_three_backends.py new file mode 100644 index 0000000000..1dbe68453e --- /dev/null +++ b/examples/pygfx_backend/08_compare_three_backends.py @@ -0,0 +1,69 @@ +"""Compare all three backends: matplotlib, opengl, pygfx. + +Displays the same data in three side-by-side PlotWidgets. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot import PlotWidget +from silx.gui.plot.utils.axis import SyncAxes + + +def populate(plot): + """Add curves, image, scatter, and markers to a plot.""" + x = numpy.linspace(0, 10, 200) + + # Curves + plot.addCurve(x, numpy.sin(x), legend="sin", color="blue", linewidth=2) + plot.addCurve( + x, + numpy.cos(x), + legend="cos", + color="red", + linewidth=1.5, + linestyle="--", + symbol="o", + ) + + # Markers + plot.addMarker(5, 0, legend="center", text="center", color="green", symbol="d") + plot.addXMarker(numpy.pi, legend="pi", text="pi", color="orange") + + plot.setActiveCurveHandling(False) + plot.resetZoom() + + +def main(): + app = qt.QApplication([]) + + window = qt.QWidget() + window.setWindowTitle("Backend Comparison: mpl vs opengl vs pygfx") + layout = qt.QHBoxLayout(window) + layout.setContentsMargins(0, 0, 0, 0) + + backends = ["mpl", "opengl", "pygfx"] + plots = [] + + for backend in backends: + try: + p = PlotWidget(backend=backend) + p.setGraphTitle(backend) + populate(p) + plots.append(p) + layout.addWidget(p) + except Exception as e: + label = qt.QLabel(f"{backend}: {e}") + layout.addWidget(label) + + # Sync axes across all plots + if len(plots) > 1: + SyncAxes([p.getXAxis() for p in plots]) + SyncAxes([p.getYAxis() for p in plots]) + + window.resize(1500, 500) + window.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/09_live_benchmark.py b/examples/pygfx_backend/09_live_benchmark.py new file mode 100644 index 0000000000..b84a6c270c --- /dev/null +++ b/examples/pygfx_backend/09_live_benchmark.py @@ -0,0 +1,192 @@ +"""Live update FPS benchmark: matplotlib vs opengl vs pygfx. + +Measures actual draw FPS for each backend with identical workloads. +""" + +import time +import numpy +from silx.gui import qt +from silx.gui.plot import Plot1D + + +class FPSCounter: + def __init__(self): + self.reset() + + def reset(self): + self._count = 0 + self._start = time.perf_counter() + + def tick(self): + self._count += 1 + + @property + def fps(self): + elapsed = time.perf_counter() - self._start + return self._count / elapsed if elapsed > 0 else 0 + + @property + def count(self): + return self._count + + +class BenchmarkWidget(qt.QWidget): + def __init__(self, n_points=1000, duration=5.0): + super().__init__() + self.setWindowTitle("Live Update FPS Benchmark") + + self._n_points = n_points + self._duration = duration + self._x = numpy.linspace(0, 4 * numpy.pi, n_points) + self._phase = 0.0 + + layout = qt.QVBoxLayout(self) + + # Info label + self._label = qt.QLabel( + f"Points: {n_points} | Duration: {duration}s per backend | Starting..." + ) + self._label.setAlignment(qt.Qt.AlignCenter) + font = self._label.font() + font.setPointSize(14) + self._label.setFont(font) + layout.addWidget(self._label) + + # Plot area + plot_layout = qt.QHBoxLayout() + layout.addLayout(plot_layout) + + self._backends = ["mpl", "opengl", "pygfx"] + self._plots = {} + self._fps_labels = {} + + for backend in self._backends: + container = qt.QVBoxLayout() + try: + plot = Plot1D(backend=backend) + plot.setGraphTitle(backend) + plot.setGraphYLimits(-1.5, 1.5) + plot.setActiveCurveHandling(False) + self._plots[backend] = plot + container.addWidget(plot) + except Exception as e: + err = qt.QLabel(f"{backend}: {e}") + container.addWidget(err) + + fps_label = qt.QLabel("waiting...") + fps_label.setAlignment(qt.Qt.AlignCenter) + fps_font = fps_label.font() + fps_font.setPointSize(12) + fps_font.setBold(True) + fps_label.setFont(fps_font) + self._fps_labels[backend] = fps_label + container.addWidget(fps_label) + + plot_layout.addLayout(container) + + # Results label + self._result_label = qt.QLabel("") + self._result_label.setAlignment(qt.Qt.AlignCenter) + font2 = self._result_label.font() + font2.setPointSize(13) + self._result_label.setFont(font2) + layout.addWidget(self._result_label) + + # State + self._current_backend_idx = 0 + self._counter = FPSCounter() + self._results = {} + + # Timer for updates + self._timer = qt.QTimer(self) + self._timer.timeout.connect(self._update) + + # Start after a short delay + qt.QTimer.singleShot(500, self._startNextBackend) + + def _startNextBackend(self): + if self._current_backend_idx >= len(self._backends): + self._showResults() + return + + backend = self._backends[self._current_backend_idx] + if backend not in self._plots: + self._current_backend_idx += 1 + self._startNextBackend() + return + + self._label.setText( + f"Benchmarking: {backend} | {self._n_points} points | " f"{self._duration}s" + ) + self._fps_labels[backend].setText("running...") + self._fps_labels[backend].setStyleSheet("color: blue;") + self._phase = 0.0 + self._counter.reset() + self._timer.start(1) # as fast as possible + + def _update(self): + backend = self._backends[self._current_backend_idx] + plot = self._plots.get(backend) + if plot is None: + return + + self._phase += 0.1 + y = numpy.sin(self._x + self._phase) + plot.addCurve( + self._x, y, legend="bench", color="blue", linewidth=2, resetzoom=False + ) + self._counter.tick() + + fps = self._counter.fps + self._fps_labels[backend].setText(f"{fps:.1f} FPS") + + if time.perf_counter() - self._counter._start >= self._duration: + self._timer.stop() + final_fps = self._counter.fps + self._results[backend] = final_fps + self._fps_labels[backend].setText(f"{final_fps:.1f} FPS") + self._fps_labels[backend].setStyleSheet("color: green;") + self._current_backend_idx += 1 + qt.QTimer.singleShot(300, self._startNextBackend) + + def _showResults(self): + lines = ["Results:"] + for backend, fps in self._results.items(): + lines.append(f" {backend}: {fps:.1f} FPS") + self._label.setText( + " | ".join(f"{b}: {f:.1f} FPS" for b, f in self._results.items()) + ) + self._result_label.setText( + f"Points: {self._n_points} | Duration: {self._duration}s each" + ) + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Live update FPS benchmark") + parser.add_argument( + "-n", + "--points", + type=int, + default=1000, + help="Number of curve points (default: 1000)", + ) + parser.add_argument( + "-d", + "--duration", + type=float, + default=5.0, + help="Seconds per backend (default: 5)", + ) + args = parser.parse_args() + + app = qt.QApplication([]) + w = BenchmarkWidget(n_points=args.points, duration=args.duration) + w.resize(1500, 500) + w.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/09_live_update.py b/examples/pygfx_backend/09_live_update.py new file mode 100644 index 0000000000..8ec19d752e --- /dev/null +++ b/examples/pygfx_backend/09_live_update.py @@ -0,0 +1,82 @@ +"""Live data update with pygfx backend. + +Demonstrates: real-time curve and image updates from a timer. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot import Plot1D, Plot2D + + +class LiveCurveWindow(qt.QMainWindow): + def __init__(self): + super().__init__() + self.setWindowTitle("pygfx - Live Curve Update") + + self._plot = Plot1D(backend="pygfx") + self._plot.setGraphTitle("Live Sine Wave") + self._plot.setGraphXLabel("X") + self._plot.setGraphYLabel("Y") + self._plot.setGraphYLimits(-1.5, 1.5) + self.setCentralWidget(self._plot) + + self._phase = 0.0 + self._x = numpy.linspace(0, 4 * numpy.pi, 500) + + self._timer = qt.QTimer(self) + self._timer.timeout.connect(self._update) + self._timer.start(30) # ~33 fps + + def _update(self): + self._phase += 0.05 + y = numpy.sin(self._x + self._phase) + self._plot.addCurve( + self._x, y, legend="live", color="blue", linewidth=2, resetzoom=False + ) + + +class LiveImageWindow(qt.QMainWindow): + def __init__(self): + super().__init__() + self.setWindowTitle("pygfx - Live Image Update") + + self._plot = Plot2D(backend="pygfx") + self._plot.setGraphTitle("Live 2D Gaussian") + self._plot.getDefaultColormap().setName("inferno") + self.setCentralWidget(self._plot) + + self._size = 128 + self._x0 = 0.0 + self._y0 = 0.0 + x = numpy.linspace(-3, 3, self._size) + self._xx, self._yy = numpy.meshgrid(x, x) + + self._timer = qt.QTimer(self) + self._timer.timeout.connect(self._update) + self._timer.start(50) # ~20 fps + + def _update(self): + self._x0 += 0.05 * (numpy.random.random() - 0.5) + self._y0 += 0.05 * (numpy.random.random() - 0.5) + image = numpy.exp(-((self._xx - self._x0) ** 2 + (self._yy - self._y0) ** 2)) + image += 0.1 * numpy.random.random(image.shape) + self._plot.addImage(image, resetzoom=False) + + +def main(): + app = qt.QApplication([]) + + w1 = LiveCurveWindow() + w1.resize(700, 400) + w1.show() + + w2 = LiveImageWindow() + w2.resize(600, 500) + w2.move(720, 0) + w2.show() + + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/10_dual_yaxis.py b/examples/pygfx_backend/10_dual_yaxis.py new file mode 100644 index 0000000000..74a88ad03b --- /dev/null +++ b/examples/pygfx_backend/10_dual_yaxis.py @@ -0,0 +1,40 @@ +"""Dual Y-axis with pygfx backend. + +Demonstrates: left and right Y axes with different scales. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot import Plot1D + + +def main(): + app = qt.QApplication([]) + + plot = Plot1D(backend="pygfx") + plot.setWindowTitle("pygfx - Dual Y Axis") + plot.setGraphTitle("Temperature and Pressure") + plot.setGraphXLabel("Time (s)") + plot.getYAxis().setLabel("Temperature (K)") + plot.getYAxis(axis="right").setLabel("Pressure (mbar)") + + x = numpy.linspace(0, 100, 300) + + # Left Y axis: temperature + temp = 300 + 50 * numpy.sin(x / 10) + 5 * numpy.random.randn(len(x)) + plot.addCurve(x, temp, legend="Temperature", color="red", linewidth=2, yaxis="left") + + # Right Y axis: pressure + pressure = 1e-6 + 5e-7 * numpy.cos(x / 15) + 1e-7 * numpy.random.randn(len(x)) + plot.addCurve( + x, pressure, legend="Pressure", color="blue", linewidth=2, yaxis="right" + ) + + plot.setActiveCurveHandling(False) + plot.resetZoom() + plot.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/11_imageview.py b/examples/pygfx_backend/11_imageview.py new file mode 100644 index 0000000000..4bbe6c9baf --- /dev/null +++ b/examples/pygfx_backend/11_imageview.py @@ -0,0 +1,38 @@ +"""ImageView widget with pygfx backend. + +Demonstrates: ImageView with side histograms, colormap, aspect ratio. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot.ImageView import ImageView + + +def main(): + app = qt.QApplication([]) + + view = ImageView(backend="pygfx") + view.setWindowTitle("pygfx - ImageView") + view.setKeepDataAspectRatio(True) + + # Generate a multi-peak image + size = 256 + x = numpy.linspace(-5, 5, size) + y = numpy.linspace(-5, 5, size) + xx, yy = numpy.meshgrid(x, y) + + image = ( + numpy.exp(-((xx - 1) ** 2 + (yy - 1) ** 2)) + + 0.7 * numpy.exp(-((xx + 2) ** 2 + (yy + 1) ** 2) / 0.5) + + 0.3 * numpy.exp(-((xx - 2) ** 2 + (yy + 2) ** 2) / 2) + + 0.05 * numpy.random.random((size, size)) + ) + + view.setImage(image, origin=(-5, -5), scale=(10 / size, 10 / size)) + view.setColormap("viridis") + view.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/12_large_data.py b/examples/pygfx_backend/12_large_data.py new file mode 100644 index 0000000000..89267637d6 --- /dev/null +++ b/examples/pygfx_backend/12_large_data.py @@ -0,0 +1,49 @@ +"""Large dataset performance test with pygfx backend. + +Demonstrates: rendering performance with large curve and image data. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot import Plot1D, Plot2D + + +def main(): + app = qt.QApplication([]) + + # --- Large curve: 1M points --- + plot1 = Plot1D(backend="pygfx") + plot1.setWindowTitle("pygfx - 1M Points Curve") + plot1.setGraphTitle("1,000,000 Points") + + n = 1_000_000 + x = numpy.linspace(0, 100, n) + y = numpy.sin(x * 10) * numpy.exp(-x / 30) + 0.1 * numpy.random.randn(n) + plot1.addCurve(x, y, legend="1M pts", color="blue", linewidth=1) + plot1.resetZoom() + plot1.resize(800, 400) + plot1.show() + + # --- Large image: 2048x2048 --- + plot2 = Plot2D(backend="pygfx") + plot2.setWindowTitle("pygfx - 2048x2048 Image") + plot2.setGraphTitle("2048 x 2048 Image") + plot2.getDefaultColormap().setName("magma") + + size = 2048 + xx, yy = numpy.meshgrid( + numpy.linspace(-10, 10, size), numpy.linspace(-10, 10, size) + ) + image = numpy.sin(xx) * numpy.cos(yy) + 0.05 * numpy.random.random((size, size)) + plot2.addImage(image) + plot2.setKeepDataAspectRatio(True) + plot2.resetZoom() + plot2.resize(600, 600) + plot2.move(820, 0) + plot2.show() + + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/13_3d_scatter.py b/examples/pygfx_backend/13_3d_scatter.py new file mode 100644 index 0000000000..34cd8e9a76 --- /dev/null +++ b/examples/pygfx_backend/13_3d_scatter.py @@ -0,0 +1,53 @@ +"""3D scatter plot with SceneWidget. + +Demonstrates: 3D scatter points with colormap, symbol customization, +and group transforms. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot3d.SceneWindow import SceneWindow, items + + +def main(): + app = qt.QApplication([]) + + window = SceneWindow(backend="pygfx") + window.setWindowTitle("3D Scatter Plot") + + scene = window.getSceneWidget() + scene.setBackgroundColor((0.15, 0.15, 0.2, 1.0)) + scene.setForegroundColor((0.9, 0.9, 0.9, 1.0)) + scene.setTextColor((0.9, 0.9, 0.9, 1.0)) + + # Generate clustered 3D scatter data + n_per_cluster = 500 + clusters = [] + centers = [(-2, -2, -2), (2, 2, 2), (-2, 2, 0), (2, -2, 0)] + for cx, cy, cz in centers: + x = numpy.random.normal(cx, 0.8, n_per_cluster) + y = numpy.random.normal(cy, 0.8, n_per_cluster) + z = numpy.random.normal(cz, 0.8, n_per_cluster) + clusters.append((x, y, z)) + + x = numpy.concatenate([c[0] for c in clusters]) + y = numpy.concatenate([c[1] for c in clusters]) + z = numpy.concatenate([c[2] for c in clusters]) + values = numpy.sqrt(x**2 + y**2 + z**2) # distance from origin + + scatter = items.Scatter3D() + scatter.setData(x, y, z, values) + scatter.getColormap().setName("viridis") + scatter.setSymbol("o") + scatter.setSymbolSize(6) + scatter.setLabel("Clustered scatter") + + scene.addItem(scatter) + + window.resize(800, 600) + window.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/14_3d_volume.py b/examples/pygfx_backend/14_3d_volume.py new file mode 100644 index 0000000000..1b0e35f16c --- /dev/null +++ b/examples/pygfx_backend/14_3d_volume.py @@ -0,0 +1,57 @@ +"""3D scalar field volume with isosurfaces and cut plane. + +Demonstrates: ScalarField3D, isosurfaces at multiple levels, +interactive cut plane with colormap. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot3d.SceneWindow import SceneWindow, items + + +def main(): + app = qt.QApplication([]) + + window = SceneWindow(backend="pygfx") + window.setWindowTitle("3D Volume - Isosurfaces & Cut Plane") + + scene = window.getSceneWidget() + scene.setBackgroundColor((0.1, 0.1, 0.15, 1.0)) + scene.setForegroundColor((0.9, 0.9, 0.9, 1.0)) + scene.setTextColor((0.9, 0.9, 0.9, 1.0)) + + # Generate 3D Gaussian + noise volume + size = 64 + lin = numpy.linspace(-3, 3, size) + x, y, z = numpy.meshgrid(lin, lin, lin) + data = ( + numpy.exp(-(x**2 + y**2 + z**2)) + + 0.5 * numpy.exp(-((x - 1.5) ** 2 + (y - 1) ** 2 + (z + 1) ** 2) / 0.5) + + 0.02 * numpy.random.random((size, size, size)) + ) + + volume = items.ScalarField3D() + volume.setData(data.astype(numpy.float32)) + volume.setLabel("Gaussian blobs") + + # Add isosurfaces at different levels + volume.addIsosurface(0.3, "#FF660080") # orange, semi-transparent + volume.addIsosurface(0.6, "#3399FF80") # blue, semi-transparent + volume.addIsosurface(0.9, "#FF3366CC") # red, more opaque + + # Set up cut plane + cutPlane = volume.getCutPlanes()[0] + cutPlane.setVisible(True) + cutPlane.getColormap().setName("magma") + cutPlane.setNormal((0.0, 0.0, 1.0)) + cutPlane.moveToCenter() + + scene.addItem(volume) + + window.resize(800, 600) + window.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/15_3d_surface.py b/examples/pygfx_backend/15_3d_surface.py new file mode 100644 index 0000000000..6f381456f6 --- /dev/null +++ b/examples/pygfx_backend/15_3d_surface.py @@ -0,0 +1,51 @@ +"""3D surface (height map) visualization. + +Demonstrates: 2D scatter as solid surface with height map, +wireframe, and points modes side by side. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot3d.SceneWindow import SceneWindow + + +def main(): + app = qt.QApplication([]) + + window = SceneWindow(backend="pygfx") + window.setWindowTitle("3D Surface - Height Map") + + scene = window.getSceneWidget() + scene.setBackgroundColor((0.12, 0.12, 0.18, 1.0)) + scene.setForegroundColor((0.9, 0.9, 0.9, 1.0)) + scene.setTextColor((0.9, 0.9, 0.9, 1.0)) + + # Generate surface data on a grid + n = 50 + x = numpy.linspace(-3, 3, n) + y = numpy.linspace(-3, 3, n) + xx, yy = numpy.meshgrid(x, y) + xx = xx.ravel() + yy = yy.ravel() + values = numpy.sin(xx) * numpy.cos(yy) * numpy.exp(-(xx**2 + yy**2) / 8) + + modes = ["solid", "lines", "points"] + for i, mode in enumerate(modes): + scatter2d = scene.add2DScatter(xx, yy, values) + scatter2d.setTranslation(i * 8.0, 0.0, 0.0) + scatter2d.setHeightMap(True) + scatter2d.setVisualization(mode) + scatter2d.getColormap().setName("coolwarm") + scatter2d.setLabel(f"Surface ({mode})") + if mode == "points": + scatter2d.setSymbolSize(4) + if mode == "lines": + scatter2d.setLineWidth(1.5) + + window.resize(1000, 600) + window.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/16_3d_mesh.py b/examples/pygfx_backend/16_3d_mesh.py new file mode 100644 index 0000000000..e92a5a6f8c --- /dev/null +++ b/examples/pygfx_backend/16_3d_mesh.py @@ -0,0 +1,113 @@ +"""3D mesh and primitive shapes. + +Demonstrates: custom Mesh with triangle data, Box, Cylinder, Hexagon +primitives with transforms and grouping. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot3d.SceneWindow import SceneWindow, items + + +def make_sphere_mesh(radius=1.0, n_lat=20, n_lon=20): + """Generate sphere triangle mesh vertices and per-vertex colors.""" + positions = [] + colorvals = [] + + for i in range(n_lat): + theta0 = numpy.pi * i / n_lat + theta1 = numpy.pi * (i + 1) / n_lat + for j in range(n_lon): + phi0 = 2 * numpy.pi * j / n_lon + phi1 = 2 * numpy.pi * (j + 1) / n_lon + + # Two triangles per quad + p00 = [ + radius * numpy.sin(theta0) * numpy.cos(phi0), + radius * numpy.sin(theta0) * numpy.sin(phi0), + radius * numpy.cos(theta0), + ] + p10 = [ + radius * numpy.sin(theta1) * numpy.cos(phi0), + radius * numpy.sin(theta1) * numpy.sin(phi0), + radius * numpy.cos(theta1), + ] + p01 = [ + radius * numpy.sin(theta0) * numpy.cos(phi1), + radius * numpy.sin(theta0) * numpy.sin(phi1), + radius * numpy.cos(theta0), + ] + p11 = [ + radius * numpy.sin(theta1) * numpy.cos(phi1), + radius * numpy.sin(theta1) * numpy.sin(phi1), + radius * numpy.cos(theta1), + ] + + positions.extend([p00, p10, p11, p00, p11, p01]) + # Color based on latitude + c = float(i) / n_lat + color = [c, 0.3, 1.0 - c, 1.0] + colorvals.extend([color] * 6) + + return ( + numpy.array(positions, dtype=numpy.float32), + numpy.array(colorvals, dtype=numpy.float32), + ) + + +def main(): + app = qt.QApplication([]) + + window = SceneWindow(backend="pygfx") + window.setWindowTitle("3D Mesh & Primitives") + + scene = window.getSceneWidget() + scene.setBackgroundColor((0.1, 0.1, 0.15, 1.0)) + scene.setForegroundColor((0.9, 0.9, 0.9, 1.0)) + scene.setTextColor((0.9, 0.9, 0.9, 1.0)) + + # Custom sphere mesh + positions, vertex_colors = make_sphere_mesh(radius=2.0) + normals = positions / numpy.linalg.norm(positions, axis=1, keepdims=True) + + mesh = items.Mesh() + mesh.setData( + position=positions, + color=vertex_colors, + normal=normals, + mode="triangles", + ) + mesh.setLabel("Sphere mesh") + scene.addItem(mesh) + + # Box primitive + box = items.Box() + box.setData(size=(2, 2, 2)) + box.color = (0.2, 0.8, 0.3, 0.8) + box.setTranslation(5, 0, 0) + box.setLabel("Box") + scene.addItem(box) + + # Cylinder primitive + cylinder = items.Cylinder() + cylinder.setData(radius=1.0, height=3.0) + cylinder.color = (0.8, 0.3, 0.2, 0.8) + cylinder.setTranslation(10, 0, 0) + cylinder.setLabel("Cylinder") + scene.addItem(cylinder) + + # Hexagon primitive + hexagon = items.Hexagon() + hexagon.setData(radius=1.5, height=2.0) + hexagon.color = (0.3, 0.3, 0.9, 0.8) + hexagon.setTranslation(15, 0, 0) + hexagon.setLabel("Hexagon") + scene.addItem(hexagon) + + window.resize(900, 600) + window.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/17_3d_image_heightmap.py b/examples/pygfx_backend/17_3d_image_heightmap.py new file mode 100644 index 0000000000..aba5603d82 --- /dev/null +++ b/examples/pygfx_backend/17_3d_image_heightmap.py @@ -0,0 +1,61 @@ +"""3D image and height map display. + +Demonstrates: ImageData with colormap, ImageRgba, and HeightMapData +displayed in a 3D scene with transforms. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot3d.SceneWindow import SceneWindow, items + + +def main(): + app = qt.QApplication([]) + + window = SceneWindow(backend="pygfx") + window.setWindowTitle("3D Images & Height Maps") + + scene = window.getSceneWidget() + scene.setBackgroundColor((0.12, 0.12, 0.18, 1.0)) + scene.setForegroundColor((0.9, 0.9, 0.9, 1.0)) + scene.setTextColor((0.9, 0.9, 0.9, 1.0)) + + size = 256 + + # 1. Grayscale image with colormap + xx, yy = numpy.meshgrid(numpy.linspace(-5, 5, size), numpy.linspace(-5, 5, size)) + data = numpy.sin(xx) * numpy.cos(yy) + + imageData = scene.addImage(data.astype(numpy.float32)) + imageData.setLabel("Grayscale (magma)") + imageData.getColormap().setName("magma") + imageData.setInterpolation("linear") + + # 2. RGBA image + rgba = numpy.zeros((size, size, 3), dtype=numpy.float32) + rgba[:, :, 0] = numpy.clip((xx + 5) / 10, 0, 1) # R: left-right gradient + rgba[:, :, 1] = numpy.clip(numpy.exp(-(xx**2 + yy**2) / 8), 0, 1) # G: center blob + rgba[:, :, 2] = numpy.clip((yy + 5) / 10, 0, 1) # B: bottom-top gradient + + imageRgba = scene.addImage(rgba) + imageRgba.setLabel("RGB image") + imageRgba.setTranslation(size + 20, 0, 0) + + # 3. Height map + heightData = numpy.exp(-(xx**2 + yy**2) / 4).astype(numpy.float32) + + heightMap = items.HeightMapData() + heightMap.setData(heightData) + heightMap.getColormap().setName("viridis") + heightMap.setTranslation(0, size + 20, 0) + heightMap.setScale(1, 1, 50) # exaggerate height + heightMap.setLabel("Height map") + scene.addItem(heightMap) + + window.resize(900, 600) + window.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/18_3d_clip_group.py b/examples/pygfx_backend/18_3d_clip_group.py new file mode 100644 index 0000000000..f9dedada14 --- /dev/null +++ b/examples/pygfx_backend/18_3d_clip_group.py @@ -0,0 +1,77 @@ +"""3D clipping plane and group transforms. + +Demonstrates: ClipPlane to slice through 3D objects, +GroupItem for shared transforms, multiple items in a scene. +""" + +import numpy +from silx.gui import qt +from silx.gui.plot3d.SceneWindow import SceneWindow, items + + +def main(): + app = qt.QApplication([]) + + window = SceneWindow(backend="pygfx") + window.setWindowTitle("3D Clipping & Groups") + + scene = window.getSceneWidget() + scene.setBackgroundColor((0.1, 0.1, 0.15, 1.0)) + scene.setForegroundColor((0.9, 0.9, 0.9, 1.0)) + scene.setTextColor((0.9, 0.9, 0.9, 1.0)) + + # Create a group for clipped items + group = items.GroupItem() + group.setLabel("Clipped group") + + # Add a clipping plane + clip = items.ClipPlane() + clip.setNormal((1.0, 0.3, 0.0)) + clip.setPoint((32, 32, 32)) + group.addItem(clip) + + # Add a 3D volume to the group + size = 64 + lin = numpy.linspace(-3, 3, size) + x, y, z = numpy.meshgrid(lin, lin, lin) + data = numpy.exp(-(x**2 + y**2 + z**2)).astype(numpy.float32) + + volume = items.ScalarField3D() + volume.setData(data) + volume.setLabel("Volume") + volume.addIsosurface(0.3, "#FF8800AA") + volume.addIsosurface(0.7, "#0088FFCC") + group.addItem(volume) + + # Add a 3D scatter to the same group (also gets clipped) + n = 2000 + sx = numpy.random.normal(32, 15, n).astype(numpy.float32) + sy = numpy.random.normal(32, 15, n).astype(numpy.float32) + sz = numpy.random.normal(32, 15, n).astype(numpy.float32) + sv = numpy.sqrt((sx - 32) ** 2 + (sy - 32) ** 2 + (sz - 32) ** 2) + + scatter = items.Scatter3D() + scatter.setData(sx, sy, sz, sv) + scatter.getColormap().setName("plasma") + scatter.setSymbol("o") + scatter.setSymbolSize(4) + scatter.setLabel("Scatter (clipped)") + group.addItem(scatter) + + scene.addItem(group) + + # Add an unclipped reference box outside the group + box = items.Box() + box.setData(size=(10, 10, 10)) + box.color = (0.5, 0.9, 0.5, 0.5) + box.setTranslation(80, 32, 32) + box.setLabel("Unclipped box") + scene.addItem(box) + + window.resize(900, 600) + window.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/examples/pygfx_backend/19_gpu_colormap_benchmark.py b/examples/pygfx_backend/19_gpu_colormap_benchmark.py new file mode 100644 index 0000000000..742c440b7d --- /dev/null +++ b/examples/pygfx_backend/19_gpu_colormap_benchmark.py @@ -0,0 +1,258 @@ +"""GPU colormap performance benchmark for pygfx backend. + +Pre-generates image frames, then streams them at maximum rate +to measure pure rendering throughput (no data generation overhead). + + plot_ms - addImage + _draw() (GPU pipeline) + other - Qt processEvents overhead + +Usage: + python 19_gpu_colormap_benchmark.py + python 19_gpu_colormap_benchmark.py --size 2048 --duration 10 +""" + +import time +import numpy as np +import argparse +from silx.gui import qt +from silx.gui.plot.PlotWindow import PlotWindow +from silx.gui.colors import Colormap + +NUM_PREGEN_FRAMES = 20 + + +def _pregenerate_frames(size, n=NUM_PREGEN_FRAMES): + """Pre-generate a pool of test frames.""" + frames = [] + for i in range(n): + t = i * 0.3 + cx = (np.sin(t) * 0.5 + 0.5) * size + cy = (np.cos(t) * 0.5 + 0.5) * size + sigma = size / 8 + y, x = np.ogrid[:size, :size] + img = np.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2 * sigma**2)) + img += 0.05 * np.random.random((size, size)) + frames.append(img.astype(np.float32)) + return frames + + +class StreamingBenchmark(qt.QWidget): + def __init__(self, image_size=1024, duration=5.0): + super().__init__() + self.setWindowTitle("pygfx GPU Colormap Benchmark") + self._duration = duration + self._image_size = image_size + + layout = qt.QVBoxLayout(self) + + # Controls + ctrl = qt.QHBoxLayout() + ctrl.addWidget(qt.QLabel("Size:")) + self._size_combo = qt.QComboBox() + self._size_combo.addItems(["256", "512", "1024", "2048", "4096"]) + self._size_combo.setCurrentText(str(image_size)) + ctrl.addWidget(self._size_combo) + + ctrl.addWidget(qt.QLabel("Norm:")) + self._norm_combo = qt.QComboBox() + self._norm_combo.addItems(["linear", "log", "sqrt", "gamma", "arcsinh"]) + ctrl.addWidget(self._norm_combo) + + ctrl.addStretch() + self._status = qt.QLabel("Ready") + self._status.setMinimumWidth(400) + font = self._status.font() + font.setPointSize(13) + font.setBold(True) + self._status.setFont(font) + ctrl.addWidget(self._status) + + self._start_btn = qt.QPushButton("Start") + self._stop_btn = qt.QPushButton("Stop") + self._stop_btn.setEnabled(False) + ctrl.addWidget(self._start_btn) + ctrl.addWidget(self._stop_btn) + layout.addLayout(ctrl) + + # Plot with toolbar (includes colormap dialog button) + self._plot = PlotWindow( + backend="pygfx", colormap=True, mask=False, roi=False, fit=False + ) + self._plot.setGraphTitle("pygfx streaming") + self._plot.setKeepDataAspectRatio(True) + layout.addWidget(self._plot) + + # Results table + self._results_text = qt.QTextEdit() + self._results_text.setReadOnly(True) + self._results_text.setMaximumHeight(200) + self._results_text.setFontFamily("monospace") + self._results_text.setText( + "Results will appear here after each run.\n" + "Try different sizes and normalizations to compare.\n\n" + "plot_ms = addImage + _draw() (GPU pipeline)\n" + "other = Qt processEvents overhead\n" + "total = plot + other (should ~ 1000/FPS)" + ) + layout.addWidget(self._results_text) + + # State + self._timer = qt.QTimer(self) + self._timer.timeout.connect(self._tick) + self._frame_count = 0 + self._t_start = 0.0 + self._frame_plot = [] + self._frame_other = [] + self._frames = [] + self._results = [] + + self._start_btn.clicked.connect(self._start) + self._stop_btn.clicked.connect(self._stop) + + def _start(self): + size = int(self._size_combo.currentText()) + norm = self._norm_combo.currentText() + + # Set colormap + if norm == "log": + cm = Colormap("viridis", normalization="log", vmin=0.01, vmax=1.5) + elif norm == "gamma": + cm = Colormap("viridis", normalization="gamma", vmin=0.0, vmax=1.5) + cm.setGammaNormalizationParameter(2.2) + elif norm == "arcsinh": + cm = Colormap("viridis", normalization="arcsinh", vmin=-0.5, vmax=1.5) + else: + cm = Colormap("viridis", normalization=norm, vmin=0.0, vmax=1.5) + + self._plot.setDefaultColormap(cm) + self._image_size = size + self._frame_count = 0 + self._frame_plot = [] + self._frame_other = [] + + # Pre-generate frames + self._status.setText( + f"Generating {NUM_PREGEN_FRAMES} frames ({size}x{size})..." + ) + qt.QApplication.processEvents() + self._frames = _pregenerate_frames(size) + + # Warm-up frame + self._plot.addImage(self._frames[0], legend="bench", resetzoom=True) + qt.QApplication.processEvents() + + self._t_start = time.perf_counter() + self._last_fps_time = self._t_start + + self._start_btn.setEnabled(False) + self._stop_btn.setEnabled(True) + self._size_combo.setEnabled(False) + self._norm_combo.setEnabled(False) + self._status.setText(f"Running: {size}x{size} {norm}...") + self._timer.start(0) + + def _stop(self): + self._timer.stop() + elapsed = time.perf_counter() - self._t_start + n = max(len(self._frame_plot), 1) + avg_fps = n / elapsed if elapsed > 0 else 0 + + plot = np.array(self._frame_plot) * 1000 + other = np.array(self._frame_other) * 1000 + + avg_plot = float(np.mean(plot)) if len(plot) else 0 + avg_other = float(np.mean(other)) if len(other) else 0 + + size = self._image_size + norm = self._norm_combo.currentText() + + self._results.append((size, norm, avg_fps, avg_plot, avg_other, n, elapsed)) + + # Update results table + lines = [ + f"{'Size':>6} {'Norm':>8} {'FPS':>7} " + f"{'plot_ms':>8} {'other':>7} {'total':>7} " + f"{'Frames':>7} {'Time':>5}" + ] + lines.append("-" * 62) + for s, no, fps, pm, om, fr, t in self._results: + lines.append( + f"{s:>6} {no:>8} {fps:>7.1f} " + f"{pm:>8.2f} {om:>7.2f} {pm + om:>7.2f} " + f"{fr:>7} {t:>4.1f}s" + ) + self._results_text.setText("\n".join(lines)) + + # Also print to console + print("\n".join(lines[-1:])) + + self._status.setText( + f"Done: {avg_fps:.1f} FPS | " + f"plot {avg_plot:.1f} + other {avg_other:.1f}ms" + ) + self._start_btn.setEnabled(True) + self._stop_btn.setEnabled(False) + self._size_combo.setEnabled(True) + self._norm_combo.setEnabled(True) + + def _tick(self): + img = self._frames[self._frame_count % len(self._frames)] + + # --- Plot update --- + t0 = time.perf_counter() + self._plot.addImage(img, legend="bench", resetzoom=False) + t1 = time.perf_counter() + + # --- Qt event processing --- + qt.QApplication.processEvents() + t2 = time.perf_counter() + + self._frame_plot.append(t1 - t0) + self._frame_other.append(t2 - t1) + self._frame_count += 1 + + # Update status every 0.5s + if t2 - self._last_fps_time >= 0.5: + n = self._frame_count + elapsed = t2 - self._t_start + fps = n / elapsed if elapsed > 0 else 0 + avg_plot = np.mean(self._frame_plot) * 1000 + avg_other = np.mean(self._frame_other) * 1000 + self._status.setText( + f"{self._image_size}x{self._image_size} | FPS: {fps:.1f} | " + f"plot {avg_plot:.1f} + other {avg_other:.1f}ms" + ) + self._last_fps_time = t2 + + # Auto-stop after duration + if t2 - self._t_start >= self._duration: + self._stop() + + +def main(): + parser = argparse.ArgumentParser(description="pygfx GPU colormap benchmark") + parser.add_argument( + "-s", + "--size", + type=int, + default=1024, + help="Initial image size (default: 1024)", + ) + parser.add_argument( + "-d", + "--duration", + type=float, + default=5.0, + help="Seconds per run (default: 5)", + ) + args = parser.parse_args() + + app = qt.QApplication([]) + w = StreamingBenchmark(image_size=args.size, duration=args.duration) + w.resize(900, 700) + w.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 1c7e6a180e..26ed0dff9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,8 +51,12 @@ opencl = [ 'pyopencl', # For silx.opencl 'Mako', # For pyopencl reduction ] +pygfx = [ + 'pygfx >=0.16.0,<0.17.0', # For silx.gui.plot (pygfx backend) + 'rendercanvas >=2.6.1', # For silx.gui.plot (pygfx backend) +] full = [ - 'silx[opencl]', + 'silx[opencl,pygfx]', 'qtconsole', # For silx.gui.console 'matplotlib >= 3.6', # For silx.gui.plot 'PyOpenGL', # For silx.gui.plot3d diff --git a/src/silx/_config.py b/src/silx/_config.py index ec1d05bd5f..01ea984243 100644 --- a/src/silx/_config.py +++ b/src/silx/_config.py @@ -45,6 +45,7 @@ class Config: - 'matplotlib' (default) or 'mpl' - 'opengl', 'gl' + - 'pygfx', 'wgpu' (requires pygfx and rendercanvas packages) - 'none' - A :class:`silx.gui.plot.backend.BackendBase.BackendBase` class - A callable returning backend class or binding name diff --git a/src/silx/conftest.py b/src/silx/conftest.py index e2162d0ce8..e37838312a 100644 --- a/src/silx/conftest.py +++ b/src/silx/conftest.py @@ -46,6 +46,13 @@ def pytest_addoption(parser): action="store_false", help="Disable tests using OpenGL", ) + parser.addoption( + "--no-pygfx", + dest="pygfx", + default=True, + action="store_false", + help="Disable tests using pygfx", + ) parser.addoption( "--no-opencl", dest="opencl", @@ -120,6 +127,16 @@ def use_opengl(test_options): pytest.skip(test_options.WITH_GL_TEST_REASON, allow_module_level=True) +@pytest.fixture(scope="session") +def use_pygfx(test_options): + """Fixture to flag test using pygfx. + + This can be skipped with `--no-pygfx`. + """ + if not test_options.WITH_PYGFX_TEST: + pytest.skip(test_options.WITH_PYGFX_TEST_REASON, allow_module_level=True) + + @pytest.fixture(scope="session") def use_opencl(test_options): """Fixture to flag test using a OpenCL. diff --git a/src/silx/gui/plot/PlotWidget.py b/src/silx/gui/plot/PlotWidget.py index 6fd34b6159..3093b34e06 100755 --- a/src/silx/gui/plot/PlotWidget.py +++ b/src/silx/gui/plot/PlotWidget.py @@ -534,6 +534,28 @@ def __getBackendClass(self, backend: BackendType) -> BackendBase: _logger.debug("Backtrace", exc_info=True) raise RuntimeError("OpenGL backend is not available") + elif backend in ("pygfx", "wgpu"): + import os + import sys + + if sys.platform.startswith("linux"): + if not os.environ.get("DISPLAY", ""): + raise RuntimeError( + "pygfx backend is not available: " + "DISPLAY environment variable not set" + ) + if os.environ.get("XDG_SESSION_TYPE", "") == "wayland": + raise RuntimeError( + "pygfx backend is not available: " + "Wayland sessions are not supported" + ) + + try: + from .backends.BackendPygfx import BackendPygfx as backendClass + except ImportError: + _logger.debug("Backtrace", exc_info=True) + raise RuntimeError("pygfx backend is not available") + elif backend == "none": from .backends.BackendBase import BackendBase as backendClass @@ -566,11 +588,12 @@ def setBackend(self, backend: BackendType): - 'matplotlib' and 'mpl': Matplotlib with Qt. - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1) + - 'pygfx' and 'wgpu': pygfx/WGPU backend (requires pygfx and rendercanvas) - 'none': No backend, to run headless for testing purpose. :param backend: The backend to use, in: - 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none', + 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'pygfx', 'wgpu', 'none', a :class:`BackendBase.BackendBase` class. If multiple backends are provided, the first available one is used. :raises ValueError: Unsupported backend descriptor diff --git a/src/silx/gui/plot/backends/BackendPygfx.py b/src/silx/gui/plot/backends/BackendPygfx.py new file mode 100644 index 0000000000..a8ec3916bc --- /dev/null +++ b/src/silx/gui/plot/backends/BackendPygfx.py @@ -0,0 +1,2802 @@ +# /*########################################################################## +# +# Copyright (c) 2024 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""pygfx (WGPU) Plot backend.""" + +from __future__ import annotations + +__authors__ = ["S. Kim"] +__license__ = "MIT" + +import logging +import math +import re +import threading + +import numpy +import wgpu + +from rendercanvas.qt import QRenderWidget +import pygfx as gfx + +from .. import items +from .._utils import FLOAT32_MINPOS +from . import BackendBase +from ... import colors +from ... import qt +from ._PlotFrameCore import PlotFrame2DCore +from silx.gui.colors import RGBAColorType + +_logger = logging.getLogger(__name__) + +_MATHDEFAULT_RE = re.compile(r"\$\\mathdefault\{([^}]*)\}\$") + + +def _stripMathDefault(text): + """Strip matplotlib's $\\mathdefault{...}$ LaTeX wrapping from tick labels.""" + if text is None: + return text + return _MATHDEFAULT_RE.sub(r"\1", text) + + +# Dash pattern mapping: silx linestyle -> pygfx dash_pattern +# pygfx dash_pattern is a tuple of (dash, gap, ...) relative to line thickness +_DASH_PATTERNS = { + "": None, + " ": None, + "-": None, # solid + "--": (3.7, 1.6, 3.7, 1.6), + "-.": (6.4, 1.6, 1, 1.6), + ":": (1, 1.65, 1, 1.65), +} + + +def _lineStyleToDashPattern(linestyle): + """Convert silx linestyle to pygfx dash_pattern tuple.""" + if linestyle is None or linestyle in ("", " "): + return None + if isinstance(linestyle, tuple) and len(linestyle) == 2: + # Custom (offset, (on, off, on, off, ...)) + return linestyle[1] + return _DASH_PATTERNS.get(linestyle) + + +# silx symbol -> pygfx marker shape mapping +_SYMBOL_MAP = { + "o": "circle", + ".": "circle", # smaller via size + ",": "square", # pixel + "+": "plus", + "x": "cross", + "d": "diamond", + "s": "square", + "^": "triangle_up", + "v": "triangle_down", + "<": "triangle_left", + ">": "triangle_right", + "*": "asterisk6", +} + + +def _rgbaToGfxColor(color): + """Convert silx RGBA color (4-tuple of 0..1 floats) to pygfx Color.""" + if color is None: + return gfx.Color(1, 1, 1, 1) + if isinstance(color, str): + color = colors.rgba(color) + if len(color) == 3: + return gfx.Color(*color, 1.0) + return gfx.Color(*color) + + +# Item classes ################################################################ + + +class _PygfxCurveItem: + """Manages pygfx scene objects for a single curve.""" + + def __init__( + self, + x, + y, + color, + gapcolor, + symbol, + linewidth, + linestyle, + yaxis, + xerror, + yerror, + fill, + alpha, + symbolsize, + baseline, + ): + self.yaxis = yaxis + self.group = gfx.Group() + self._lineObj = None + self._gapLineObj = None + self._pointsObj = None + self._errorGroup = None + self._fillObj = None + + x = numpy.asarray(x, dtype=numpy.float32) + y = numpy.asarray(y, dtype=numpy.float32) + + # Per-vertex color handling + if isinstance(color, numpy.ndarray) and color.ndim == 2: + perVertexColor = True + vertexColors = numpy.asarray(color, dtype=numpy.float32) + if vertexColors.shape[1] == 3: + vertexColors = numpy.column_stack( + [ + vertexColors, + numpy.full(len(vertexColors), alpha, dtype=numpy.float32), + ] + ) + uniformColor = gfx.Color(1, 1, 1, 1) + else: + perVertexColor = False + vertexColors = None + rgba = colors.rgba(color) + uniformColor = gfx.Color(rgba[0], rgba[1], rgba[2], rgba[3] * alpha) + + # Line + dashPattern = _lineStyleToDashPattern(linestyle) + hasLine = linestyle not in (None, "", " ") + if hasLine and len(x) > 1: + positions = numpy.zeros((len(x), 3), dtype=numpy.float32) + positions[:, 0] = x + positions[:, 1] = y + + lineKwargs = {} + if perVertexColor: + lineKwargs["colors"] = vertexColors + + geom = gfx.Geometry(positions=positions, **lineKwargs) + mat = gfx.LineMaterial( + thickness=max(linewidth, 1.0), + color=uniformColor, + color_mode="vertex" if perVertexColor else "uniform", + dash_pattern=dashPattern if dashPattern else (), + ) + self._lineObj = gfx.Line(geom, mat) + self.group.add(self._lineObj) + + # Gap color line (behind the dashed line via z-offset) + if gapcolor is not None and dashPattern: + gapPositions = positions.copy() + gapPositions[:, 2] = -0.1 # slightly behind + gapRgba = colors.rgba(gapcolor) + gapMat = gfx.LineMaterial( + thickness=max(linewidth, 1.0), + color=gfx.Color(*gapRgba), + ) + self._gapLineObj = gfx.Line( + gfx.Geometry(positions=gapPositions), gapMat + ) + self.group.add(self._gapLineObj) + + # Symbol / Points + hasSymbol = symbol not in (None, "", " ") + if hasSymbol: + positions = numpy.zeros((len(x), 3), dtype=numpy.float32) + positions[:, 0] = x + positions[:, 1] = y + + markerShape = _SYMBOL_MAP.get(symbol, "circle") + pointSize = symbolsize if symbol != "," else 1.0 + if symbol == ".": + pointSize = max(pointSize * 0.5, 1.0) + + pointKwargs = {} + if perVertexColor: + pointKwargs["colors"] = vertexColors + + geom = gfx.Geometry(positions=positions, **pointKwargs) + mat = gfx.PointsMarkerMaterial( + marker=markerShape, + size=pointSize, + color=uniformColor, + color_mode="vertex" if perVertexColor else "uniform", + edge_width=0.5, + edge_color=uniformColor, + ) + self._pointsObj = gfx.Points(geom, mat) + self.group.add(self._pointsObj) + + # Error bars + if xerror is not None or yerror is not None: + self._errorGroup = gfx.Group() + errSegments = self._buildErrorBarSegments(x, y, xerror, yerror) + if len(errSegments) > 0: + errGeom = gfx.Geometry(positions=errSegments.astype(numpy.float32)) + errMat = gfx.LineSegmentMaterial( + thickness=1.0, + color=uniformColor, + ) + errLine = gfx.Line(errGeom, errMat) + self._errorGroup.add(errLine) + self.group.add(self._errorGroup) + + # Fill between curve and baseline + if fill and len(x) >= 2: + self._fillObj = self._buildFill(x, y, baseline, uniformColor, alpha) + if self._fillObj is not None: + self._fillObj.local.z = -0.2 # behind curve line + self.group.add(self._fillObj) + + @staticmethod + def _buildErrorBarSegments(x, y, xerror, yerror): + """Build line segments for error bars.""" + parts = [] + + if yerror is not None: + yerror = numpy.asarray(yerror, dtype=numpy.float64) + if yerror.ndim == 2 and yerror.shape[1] == 1: + yerror = numpy.ravel(yerror) + if yerror.ndim == 0: + yErrMinus = numpy.full_like(y, yerror) + yErrPlus = yErrMinus + elif yerror.ndim == 1: + yErrMinus = yerror + yErrPlus = yerror + else: + yErrMinus = yerror[0] + yErrPlus = yerror[1] + n = len(x) + seg = numpy.empty((n * 2, 3), dtype=numpy.float64) + seg[0::2, 0] = x + seg[0::2, 1] = y - yErrMinus + seg[0::2, 2] = 0 + seg[1::2, 0] = x + seg[1::2, 1] = y + yErrPlus + seg[1::2, 2] = 0 + parts.append(seg) + + if xerror is not None: + xerror = numpy.asarray(xerror, dtype=numpy.float64) + if xerror.ndim == 2 and xerror.shape[1] == 1: + xerror = numpy.ravel(xerror) + if xerror.ndim == 0: + xErrMinus = numpy.full_like(x, xerror) + xErrPlus = xErrMinus + elif xerror.ndim == 1: + xErrMinus = xerror + xErrPlus = xerror + else: + xErrMinus = xerror[0] + xErrPlus = xerror[1] + n = len(x) + seg = numpy.empty((n * 2, 3), dtype=numpy.float64) + seg[0::2, 0] = x - xErrMinus + seg[0::2, 1] = y + seg[0::2, 2] = 0 + seg[1::2, 0] = x + xErrPlus + seg[1::2, 1] = y + seg[1::2, 2] = 0 + parts.append(seg) + + if parts: + return numpy.concatenate(parts) + return numpy.empty((0, 3), dtype=numpy.float64) + + @staticmethod + def _buildFill(x, y, baseline, color, alpha): + """Build a filled mesh between curve and baseline.""" + if baseline is None: + baseY = numpy.zeros_like(y) + elif isinstance(baseline, numpy.ndarray): + baseY = baseline + else: + baseY = numpy.full_like(y, float(baseline)) + + n = len(x) + # Create triangle strip: for each segment, two triangles + vertices = [] + indices = [] + for i in range(n): + vertices.append([x[i], y[i], 0]) + vertices.append([x[i], baseY[i], 0]) + + for i in range(n - 1): + idx = i * 2 + indices.append([idx, idx + 1, idx + 2]) + indices.append([idx + 1, idx + 3, idx + 2]) + + if not indices: + return None + + vertices = numpy.array(vertices, dtype=numpy.float32) + indices = numpy.array(indices, dtype=numpy.int32) + + fillColor = gfx.Color(color.r, color.g, color.b, alpha * 0.5) + geom = gfx.Geometry(positions=vertices, indices=indices) + mat = gfx.MeshBasicMaterial(color=fillColor, side="both") + return gfx.Mesh(geom, mat) + + +def _fastColormapRange(data, colormap): + """Fast colormap range for common cases (avoids slow normalizer pipeline).""" + vmin = colormap.getVMin() + vmax = colormap.getVMax() + if vmin is not None and vmax is not None: + return float(vmin), float(vmax) + + # Fast path for linear + minmax (most common streaming case) + norm = colormap.getNormalization() + mode = colormap.getAutoscaleMode() + if norm == "linear" and mode == "minmax": + if vmin is None: + vmin = float(numpy.nanmin(data)) + if vmax is None: + vmax = float(numpy.nanmax(data)) + if vmin >= vmax: + vmax = vmin + 1.0 + return vmin, vmax + + # Fallback to full pipeline (log, sqrt, percentile, etc.) + return colormap.getColormapRange(data) + + +# GPU colormap helpers ######################################################## + + +def _colormapToLUT(colormap): + """Extract a 256x4 float32 LUT from a silx Colormap. + + :param colormap: silx Colormap object + :returns: (lut, nanColor) where lut is (256, 4) float32 in [0, 1] + and nanColor is (4,) float32 RGBA + """ + lut_u8 = colormap.getNColors(nbColors=256) # (256, 4) uint8 + lut = lut_u8.astype(numpy.float32) / 255.0 + + qNanColor = colormap.getNaNColor() + nanColor = numpy.array( + [qNanColor.redF(), qNanColor.greenF(), qNanColor.blueF(), qNanColor.alphaF()], + dtype=numpy.float32, + ) + return lut, nanColor + + +def _prepareScalarForGPU(data, normalization, vmin, vmax, gamma): + """Apply normalization pre-processing on CPU for GPU colormap path. + + For linear and gamma, no pre-processing is needed — pygfx handles them + natively via clim and gamma material parameters. + For log/sqrt/arcsinh, apply the transform to both data and clim bounds + so that pygfx's linear clim mapping produces the correct result. + + :param data: 2D numpy array (scalar image data) + :param normalization: one of "linear", "log", "sqrt", "gamma", "arcsinh" + :param vmin: colormap lower bound + :param vmax: colormap upper bound + :param gamma: gamma parameter (used only for "gamma" normalization) + :returns: (scalar_data, clim, use_gamma) ready for GPU + scalar_data: float32 2D array + clim: (float, float) for ImageBasicMaterial + use_gamma: float for ImageBasicMaterial.gamma + """ + scalar = numpy.asarray(data, dtype=numpy.float32) + + if normalization == "linear": + return scalar, (float(vmin), float(vmax)), 1.0 + + elif normalization == "gamma": + # pygfx gamma: pow(normalized, 1/gamma), but silx gamma means + # pow(normalized, gamma). So pass 1/gamma to pygfx. + return scalar, (float(vmin), float(vmax)), 1.0 / gamma + + elif normalization == "log": + minPos = max(vmin, FLOAT32_MINPOS) if vmin > 0 else FLOAT32_MINPOS + scalar = numpy.log10(numpy.clip(scalar, minPos, None)) + clim = ( + float(numpy.log10(max(vmin, minPos))), + float(numpy.log10(max(vmax, minPos))), + ) + return scalar, clim, 1.0 + + elif normalization == "sqrt": + scalar = numpy.sqrt(numpy.clip(scalar, 0, None)) + clim = (float(numpy.sqrt(max(vmin, 0))), float(numpy.sqrt(max(vmax, 0)))) + return scalar, clim, 1.0 + + elif normalization == "arcsinh": + scalar = numpy.arcsinh(scalar) + clim = (float(numpy.arcsinh(vmin)), float(numpy.arcsinh(vmax))) + return scalar, clim, 1.0 + + else: + _logger.warning("Unknown normalization %r, using linear", normalization) + return scalar, (float(vmin), float(vmax)), 1.0 + + +def _handleNaN(scalar_data, clim, lut, nanColor): + """Replace NaN pixels with a sentinel value and set LUT[0] to nanColor. + + The sentinel is placed below clim[0] so it maps to LUT index 0. + With wrap='clamp', values below clim[0] also map to LUT[0]. + + :param scalar_data: float32 2D array (may be modified in-place) + :param clim: (vmin, vmax) tuple + :param lut: (256, 4) float32 array (modified in-place) + :param nanColor: (4,) float32 RGBA + :returns: (scalar_data, clim) with sentinel applied + """ + hasNan = numpy.any(numpy.isnan(scalar_data)) + if not hasNan: + return scalar_data, clim + + scalar_data = scalar_data.copy() + vmin, vmax = clim + + # Sentinel: well below vmin so it maps to LUT[0] + rng = vmax - vmin if vmax != vmin else 1.0 + sentinel = vmin - rng * 0.01 + + # Replace NaN with sentinel + nanMask = numpy.isnan(scalar_data) + scalar_data[nanMask] = sentinel + + # Adjust clim so sentinel maps to ~index 0 and vmin maps to ~index 1+ + # LUT[0] = nanColor, LUT[1..255] = original LUT[0..254] + newLut = numpy.empty_like(lut) + newLut[0] = nanColor + # Remap: compress original 256 entries into indices 1..255 + indices = numpy.linspace(0, 255, 255).astype(numpy.int32) + newLut[1:] = lut[indices] + + lut[:] = newLut + + # Expand clim so sentinel→0, vmin→~1/256, vmax→255/256 + newVmin = sentinel + # vmin should map to index ~1 out of 256 + # index = (val - newVmin) / (newVmax - newVmin) * 255 + # For val=vmin, index=1: 1 = (vmin - sentinel) / (newVmax - sentinel) * 255 + # newVmax = sentinel + (vmin - sentinel) * 255 + newVmax = sentinel + (vmax - sentinel) * 256.0 / 255.0 + + return scalar_data, (float(newVmin), float(newVmax)) + + +# WGSL Compute shaders ######################################################## + +_MINMAX_SHADER = """ +struct Params { + num_elements: u32, +} + +@group(0) @binding(0) var input_data: array; +@group(0) @binding(1) var output_data: array; +@group(0) @binding(2) var params: Params; + +var s_min: array; +var s_max: array; +var s_min_pos: array; + +@compute @workgroup_size(256) +fn main(@builtin(local_invocation_id) lid: vec3, + @builtin(workgroup_id) wgid: vec3) { + let tid = lid.x; + let gid = wgid.x * 256u + tid; + let stride = 256u * ((params.num_elements + 255u) / 256u); + + // Initialize with identity values + var local_min: f32 = 3.402823e+38; + var local_max: f32 = -3.402823e+38; + var local_min_pos: f32 = 3.402823e+38; + + // Grid-stride loop: each thread processes multiple elements + var idx = gid; + while (idx < params.num_elements) { + let val = input_data[idx]; + // Skip NaN and Inf + if (!isNan(val) && !isInf(val)) { + local_min = min(local_min, val); + local_max = max(local_max, val); + if (val > 0.0) { + local_min_pos = min(local_min_pos, val); + } + } + idx += stride; + } + + s_min[tid] = local_min; + s_max[tid] = local_max; + s_min_pos[tid] = local_min_pos; + workgroupBarrier(); + + // Tree reduction + var step = 128u; + while (step > 0u) { + if (tid < step) { + s_min[tid] = min(s_min[tid], s_min[tid + step]); + s_max[tid] = max(s_max[tid], s_max[tid + step]); + s_min_pos[tid] = min(s_min_pos[tid], s_min_pos[tid + step]); + } + workgroupBarrier(); + step = step >> 1u; + } + + // Write workgroup result + if (tid == 0u) { + let out_idx = wgid.x * 3u; + output_data[out_idx] = s_min[0]; + output_data[out_idx + 1u] = s_max[0]; + output_data[out_idx + 2u] = s_min_pos[0]; + } +} + +fn isNan(v: f32) -> bool { + return !(v == v); +} + +fn isInf(v: f32) -> bool { + return (v == 3.402823e+38) || (v == -3.402823e+38) || abs(v) > 3.4e+38; +} +""" + +_HISTOGRAM_SHADER = """ +struct Params { + num_elements: u32, + data_min: f32, + data_max: f32, + num_bins: u32, + norm_mode: u32, // 0=linear, 1=log10, 2=sqrt, 3=arcsinh + _pad1: u32, + _pad2: u32, + _pad3: u32, +} + +@group(0) @binding(0) var input_data: array; +@group(0) @binding(1) var histogram: array>; +@group(0) @binding(2) var params: Params; + +fn apply_norm(val: f32, mode: u32) -> f32 { + switch (mode) { + case 1u: { // log10 + if (val <= 0.0) { return -1e30; } + return log2(val) * 0.30102999566; // log2(x) / log2(10) + } + case 2u: { return sqrt(max(val, 0.0)); } // sqrt + case 3u: { return asinh(val); } // arcsinh + default: { return val; } // linear/gamma + } +} + +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let range = params.data_max - params.data_min; + if (range <= 0.0) { + return; + } + + // Grid-stride loop for large data (>65535 workgroups) + let total_threads = 256u * ((params.num_elements + 255u) / 256u); + var idx = gid.x; + while (idx < params.num_elements) { + let val = input_data[idx]; + + // Skip NaN + if (val == val) { + let transformed = apply_norm(val, params.norm_mode); + var normalized = (transformed - params.data_min) / range; + normalized = clamp(normalized, 0.0, 0.999999); + let bin = u32(normalized * f32(params.num_bins)); + atomicAdd(&histogram[bin], 1u); + } + + idx += total_threads; + } +} +""" + +# Normalization mode constants for histogram shader +_HIST_NORM_LINEAR = 0 +_HIST_NORM_LOG = 1 +_HIST_NORM_SQRT = 2 +_HIST_NORM_ARCSINH = 3 + + +class _WgpuComputeHelper: + """GPU compute helper for min/max reduction and histogram computation.""" + + _instance = None + + @classmethod + def get(cls): + """Get or create the singleton compute helper.""" + if cls._instance is None: + try: + cls._instance = cls() + except Exception: + _logger.debug("Failed to create GPU compute helper", exc_info=True) + cls._instance = False # Sentinel: tried and failed + if cls._instance is False: + return None + return cls._instance + + def __init__(self): + adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance") + self._device = adapter.request_device_sync() + + # Create minmax pipeline + minmax_module = self._device.create_shader_module(code=_MINMAX_SHADER) + self._minmax_bgl = self._device.create_bind_group_layout( + entries=[ + { + "binding": 0, + "visibility": wgpu.ShaderStage.COMPUTE, + "buffer": {"type": "read-only-storage"}, + }, + { + "binding": 1, + "visibility": wgpu.ShaderStage.COMPUTE, + "buffer": {"type": "storage"}, + }, + { + "binding": 2, + "visibility": wgpu.ShaderStage.COMPUTE, + "buffer": {"type": "uniform"}, + }, + ] + ) + self._minmax_pipeline = self._device.create_compute_pipeline( + layout=self._device.create_pipeline_layout( + bind_group_layouts=[self._minmax_bgl] + ), + compute={"module": minmax_module, "entry_point": "main"}, + ) + + # Create histogram pipeline + hist_module = self._device.create_shader_module(code=_HISTOGRAM_SHADER) + self._hist_bgl = self._device.create_bind_group_layout( + entries=[ + { + "binding": 0, + "visibility": wgpu.ShaderStage.COMPUTE, + "buffer": {"type": "read-only-storage"}, + }, + { + "binding": 1, + "visibility": wgpu.ShaderStage.COMPUTE, + "buffer": {"type": "storage"}, + }, + { + "binding": 2, + "visibility": wgpu.ShaderStage.COMPUTE, + "buffer": {"type": "uniform"}, + }, + ] + ) + self._hist_pipeline = self._device.create_compute_pipeline( + layout=self._device.create_pipeline_layout( + bind_group_layouts=[self._hist_bgl] + ), + compute={"module": hist_module, "entry_point": "main"}, + ) + + def compute_minmax(self, data): + """Compute (min, minPositive, max) using GPU reduction. + + :param data: numpy array (will be flattened to float32) + :returns: (min, minPositive, max) tuple of floats, or None on failure + """ + flat = numpy.ascontiguousarray(data.ravel(), dtype=numpy.float32) + num_elements = len(flat) + if num_elements == 0: + return None + + workgroup_size = 256 + num_workgroups = min( + (num_elements + workgroup_size - 1) // workgroup_size, 65535 + ) + + # Input buffer + input_buf = self._device.create_buffer_with_data( + data=flat.tobytes(), + usage=wgpu.BufferUsage.STORAGE, + ) + + # Output buffer: 3 floats per workgroup (min, max, minPos) + output_size = num_workgroups * 3 * 4 # float32 + output_buf = self._device.create_buffer( + size=output_size, + usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_SRC, + ) + + # Params uniform + params = numpy.array([num_elements], dtype=numpy.uint32) + params_buf = self._device.create_buffer_with_data( + data=params.tobytes(), + usage=wgpu.BufferUsage.UNIFORM, + ) + + # Bind group + bind_group = self._device.create_bind_group( + layout=self._minmax_bgl, + entries=[ + {"binding": 0, "resource": {"buffer": input_buf}}, + {"binding": 1, "resource": {"buffer": output_buf}}, + {"binding": 2, "resource": {"buffer": params_buf}}, + ], + ) + + # Dispatch + encoder = self._device.create_command_encoder() + compute_pass = encoder.begin_compute_pass() + compute_pass.set_pipeline(self._minmax_pipeline) + compute_pass.set_bind_group(0, bind_group) + compute_pass.dispatch_workgroups(num_workgroups) + compute_pass.end() + + # Readback + readback_buf = self._device.create_buffer( + size=output_size, + usage=wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.MAP_READ, + ) + encoder.copy_buffer_to_buffer(output_buf, 0, readback_buf, 0, output_size) + self._device.queue.submit([encoder.finish()]) + + readback_buf.map_sync(wgpu.MapMode.READ) + result_bytes = readback_buf.read_mapped() + result = numpy.frombuffer(result_bytes, dtype=numpy.float32).copy() + readback_buf.unmap() + + # CPU final reduction of per-workgroup results + result = result.reshape(-1, 3) + final_min = float(numpy.min(result[:, 0])) + final_max = float(numpy.max(result[:, 1])) + min_pos_vals = result[:, 2] + valid_pos = min_pos_vals[min_pos_vals < 3.4e38] + final_min_pos = ( + float(numpy.min(valid_pos)) if len(valid_pos) > 0 else float("inf") + ) + + # Clean up + input_buf.destroy() + output_buf.destroy() + params_buf.destroy() + readback_buf.destroy() + + return (final_min, final_min_pos, final_max) + + def compute_histogram(self, data, data_min, data_max, num_bins=256, norm_mode=0): + """Compute histogram using GPU atomic operations. + + :param data: numpy array (will be flattened to float32) + :param data_min: histogram lower bound (in normalized space) + :param data_max: histogram upper bound (in normalized space) + :param num_bins: number of bins (default 256) + :param norm_mode: 0=linear, 1=log10, 2=sqrt, 3=arcsinh + :returns: (counts, bin_edges) or None on failure + """ + flat = numpy.ascontiguousarray(data.ravel(), dtype=numpy.float32) + num_elements = len(flat) + if num_elements == 0: + return None + + workgroup_size = 256 + num_workgroups = min( + (num_elements + workgroup_size - 1) // workgroup_size, 65535 + ) + + # Input buffer + input_buf = self._device.create_buffer_with_data( + data=flat.tobytes(), + usage=wgpu.BufferUsage.STORAGE, + ) + + # Histogram buffer (zero-initialized) + hist_size = num_bins * 4 # uint32 + hist_buf = self._device.create_buffer( + size=hist_size, + usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_SRC, + ) + + # Params uniform (8 x u32/f32 = 32 bytes, matches Params struct) + params = numpy.zeros(8, dtype=numpy.float32) + params_view = params.view(numpy.uint32) + params_view[0] = num_elements + params[1] = numpy.float32(data_min) + params[2] = numpy.float32(data_max) + params_view[3] = num_bins + params_view[4] = norm_mode + # [5], [6], [7] = padding + params_buf = self._device.create_buffer_with_data( + data=params.tobytes(), + usage=wgpu.BufferUsage.UNIFORM, + ) + + # Bind group + bind_group = self._device.create_bind_group( + layout=self._hist_bgl, + entries=[ + {"binding": 0, "resource": {"buffer": input_buf}}, + {"binding": 1, "resource": {"buffer": hist_buf}}, + {"binding": 2, "resource": {"buffer": params_buf}}, + ], + ) + + # Dispatch + encoder = self._device.create_command_encoder() + compute_pass = encoder.begin_compute_pass() + compute_pass.set_pipeline(self._hist_pipeline) + compute_pass.set_bind_group(0, bind_group) + compute_pass.dispatch_workgroups(num_workgroups) + compute_pass.end() + + # Readback + readback_buf = self._device.create_buffer( + size=hist_size, + usage=wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.MAP_READ, + ) + encoder.copy_buffer_to_buffer(hist_buf, 0, readback_buf, 0, hist_size) + self._device.queue.submit([encoder.finish()]) + + readback_buf.map_sync(wgpu.MapMode.READ) + result_bytes = readback_buf.read_mapped() + counts = numpy.frombuffer(result_bytes, dtype=numpy.uint32).copy() + readback_buf.unmap() + + bin_edges = numpy.linspace(data_min, data_max, num_bins + 1) + + # Clean up + input_buf.destroy() + hist_buf.destroy() + params_buf.destroy() + readback_buf.destroy() + + return (counts, bin_edges) + + +# Async compute for streaming ################################################## + + +class _AsyncCompute: + """Non-blocking async computation for streaming image data. + + The render thread never blocks. It submits data and reads the latest + completed result. A single worker thread processes requests, always + skipping to the newest data (stale requests are dropped). + + Stats and histogram are computed on full data (no subsampling needed + since computation runs off the render thread). GPU histogram at + 4096x4096 sustains ~50Hz; stats use optimized CPU (~20Hz for 4K). + """ + + def __init__(self, gpu_compute=None): + self._gpu_compute = gpu_compute + + # Latest results (read by render thread) + self._stats_result = None + self._hist_result = None + + # Pending requests (written by render thread, read by worker) + self._pending_stats_data = None + self._pending_hist_request = None # (data, data_min, data_max, num_bins) + + # Lock protects pending slots and results + self._lock = threading.Lock() + + # Worker thread + self._running = True + self._event = threading.Event() # Signals new work available + self._thread = threading.Thread(target=self._worker, daemon=True) + self._thread.start() + + def shutdown(self): + """Stop the worker thread.""" + self._running = False + self._event.set() + + def submit_stats(self, data): + """Submit data for async stats computation. Non-blocking. + + :param data: numpy array (a reference is kept until processed) + """ + with self._lock: + self._pending_stats_data = data + self._event.set() + + def submit_histogram(self, data, data_min, data_max, num_bins=256, norm_mode=0): + """Submit data for async histogram computation. Non-blocking. + + :param data: numpy array + :param data_min: histogram lower bound (in normalized space) + :param data_max: histogram upper bound (in normalized space) + :param num_bins: number of bins + :param norm_mode: 0=linear, 1=log10, 2=sqrt, 3=arcsinh + """ + with self._lock: + self._pending_hist_request = (data, data_min, data_max, num_bins, norm_mode) + self._event.set() + + def get_stats(self): + """Return the latest computed stats, or None. Non-blocking.""" + return self._stats_result + + def get_histogram(self): + """Return the latest computed histogram, or None. Non-blocking.""" + return self._hist_result + + def invalidate(self): + """Clear cached results (e.g., when colormap changes).""" + with self._lock: + self._stats_result = None + self._hist_result = None + + def _worker(self): + """Worker thread: processes the latest pending request.""" + while self._running: + self._event.wait() + self._event.clear() + + if not self._running: + break + + # Grab latest pending work (drop stale) + with self._lock: + stats_data = self._pending_stats_data + self._pending_stats_data = None + hist_req = self._pending_hist_request + self._pending_hist_request = None + + # Process stats + if stats_data is not None: + result = self._compute_stats(stats_data) + if result is not None: + self._stats_result = result + + # Process histogram + if hist_req is not None: + data, dmin, dmax, nbins, nmode = hist_req + result = self._compute_histogram(data, dmin, dmax, nbins, nmode) + if result is not None: + self._hist_result = result + + def _compute_stats(self, data): + """Compute (min, minPositive, max) on full data.""" + try: + data = numpy.asarray(data) + if data.size == 0: + return None + min_ = float(numpy.nanmin(data)) + max_ = float(numpy.nanmax(data)) + if not (numpy.isfinite(min_) and numpy.isfinite(max_)): + return None + if min_ > 0: + minPositive = min_ + else: + pos = data[data > 0] + minPositive = float(numpy.min(pos)) if len(pos) > 0 else float("inf") + return (min_, minPositive, max_) + except Exception: + return None + + def _compute_histogram(self, data, data_min, data_max, num_bins, norm_mode=0): + """Compute histogram, preferring GPU.""" + try: + # GPU path + if self._gpu_compute is not None: + result = self._gpu_compute.compute_histogram( + data, data_min, data_max, num_bins, norm_mode=norm_mode + ) + if result is not None: + return result + # CPU fallback + data = numpy.asarray(data) + flat = data.ravel() + finite = flat[numpy.isfinite(flat)] + counts, edges = numpy.histogram( + finite, bins=num_bins, range=(data_min, data_max) + ) + return (counts, edges) + except Exception: + return None + + +# Image item ################################################################## + + +class _PygfxImageItem: + """Manages pygfx scene objects for a single image.""" + + def __init__(self, data, origin, scale, colormap, alpha): + self.group = gfx.Group() + self.yaxis = "left" + self._imageObj = None + self._scalarShape = None + self._cmapName = None + self._cmapTexture = None + self._gpuColormapInfo = None # Set when using GPU colormap path + self._origin = origin + self._scale = scale + self._dataShape = numpy.asarray(data).shape[:2] + + self._build(data, origin, scale, colormap, alpha) + + def _build(self, data, origin, scale, colormap, alpha): + data = numpy.asarray(data) + self._origin = origin + self._scale = scale + self._dataShape = data.shape[:2] + + if data.ndim == 2: + self._buildScalar(data, origin, scale, colormap, alpha) + elif data.ndim == 3 and data.shape[2] in (3, 4): + self._buildRGBA(data, origin, scale, alpha) + else: + _logger.warning("Unsupported image data shape: %s", data.shape) + + def _buildScalar(self, data, origin, scale, colormap, alpha): + self._scalarShape = data.shape + + # Data: upload scalar float32 directly (no CPU colormap) + if data.dtype == numpy.float32 and data.flags["C_CONTIGUOUS"]: + scalarData = data + else: + scalarData = numpy.ascontiguousarray(data, dtype=numpy.float32) + + # Range: fast path for linear+minmax + if colormap is not None: + vmin, vmax = _fastColormapRange(data, colormap) + cmapTex = self._getOrCreateCmapTexture(colormap, alpha) + else: + vmin = float(numpy.nanmin(data)) + vmax = float(numpy.nanmax(data)) + if vmin == vmax: + vmax = vmin + 1.0 + cmapTex = None + + if self._imageObj is None: + # First time: create GPU objects + tex = gfx.Texture(scalarData, dim=2) + geom = gfx.Geometry(grid=tex) + mat = gfx.ImageBasicMaterial( + clim=(vmin, vmax), + map=cmapTex, + interpolation="nearest", + ) + self._imageObj = gfx.Image(geom, mat) + self.group.add(self._imageObj) + else: + # Reuse: update texture data + clim (no GPU object creation) + self._imageObj.geometry.grid.set_data(scalarData) + self._imageObj.material.clim = (vmin, vmax) + if cmapTex is not None: + self._imageObj.material.map = cmapTex + + ox, oy = origin + sx, sy = scale + self._imageObj.local.position = (ox, oy, 0) + self._imageObj.local.scale = (sx, sy, 1) + + def updateData(self, data, clim=None): + """Fast path: update only the texture data (no item system overhead). + + Requires the image object to already exist and data shape to match. + + :param data: New image data (2D array) + :param clim: (vmin, vmax) tuple for color limits, or None to compute from data + """ + if self._imageObj is None: + return + if data.dtype == numpy.float32 and data.flags["C_CONTIGUOUS"]: + scalarData = data + else: + scalarData = numpy.ascontiguousarray(data, dtype=numpy.float32) + self._imageObj.geometry.grid.set_data(scalarData) + if clim is None: + dmin = float(numpy.nanmin(data)) + dmax = float(numpy.nanmax(data)) + if dmin >= dmax: + dmax = dmin + 1.0 + clim = (dmin, dmax) + + self._imageObj.material.clim = clim + + def _buildRGBA(self, data, origin, scale, alpha): + self._scalarShape = None + + if data.dtype == numpy.float64: + data = data.astype(numpy.float32) + if data.dtype in (numpy.float32, numpy.float64): + rgbaData = (numpy.clip(data, 0, 1) * 255).astype(numpy.uint8) + else: + rgbaData = numpy.asarray(data, dtype=numpy.uint8) + if rgbaData.shape[2] == 3: + alphaChannel = numpy.full(rgbaData.shape[:2] + (1,), 255, dtype=numpy.uint8) + rgbaData = numpy.concatenate([rgbaData, alphaChannel], axis=-1) + + rgbaFloat = rgbaData.astype(numpy.float32) / 255.0 + if alpha < 1.0: + rgbaFloat = rgbaFloat.copy() + rgbaFloat[:, :, 3] *= alpha + rgbaFloat = numpy.ascontiguousarray(rgbaFloat) + + geom = gfx.Geometry(grid=gfx.Texture(rgbaFloat, dim=2)) + mat = gfx.ImageBasicMaterial(interpolation="nearest") + self._imageObj = gfx.Image(geom, mat) + self.group.add(self._imageObj) + + ox, oy = origin + sx, sy = scale + self._imageObj.local.position = (ox, oy, 0) + self._imageObj.local.scale = (sx, sy, 1) + + def _getOrCreateCmapTexture(self, colormap, alpha): + """Cache colormap LUT texture, recreate only when colormap changes.""" + name = colormap.getName() + if name == self._cmapName and self._cmapTexture is not None: + return self._cmapTexture + + lut = colormap.getNColors() # (256, 4) uint8 RGBA + lutFloat = lut.astype(numpy.float32) / 255.0 + if alpha < 1.0: + lutFloat = lutFloat.copy() + lutFloat[:, 3] *= alpha + self._cmapTexture = gfx.Texture(lutFloat, dim=1) + self._cmapName = name + return self._cmapTexture + + def _initGPUColormap(self, data, origin, scale, colormap, alpha): + """Initialize image using GPU-native colormap rendering. + + Uploads scalar data as a 1-channel texture and uses pygfx's + ImageBasicMaterial.map for GPU-side colormap application. + """ + normalization = colormap.getNormalization() + cmapRange = colormap.getColormapRange(data) + vmin, vmax = cmapRange + gamma = colormap.getGammaNormalizationParameter() + + # 1. Normalization pre-processing + scalar_data, clim, use_gamma = _prepareScalarForGPU( + data, normalization, vmin, vmax, gamma + ) + + # 2. Build LUT and handle NaN + lut, nanColor = _colormapToLUT(colormap) + scalar_data, clim = _handleNaN(scalar_data, clim, lut, nanColor) + + # 3. Apply alpha to LUT + if alpha < 1.0: + lut = lut.copy() + lut[:, 3] *= alpha + + # 4. Create GPU objects + scalar_data = numpy.ascontiguousarray(scalar_data) + lut_tex = gfx.Texture(lut, dim=1) + cmap_map = gfx.TextureMap(lut_tex, filter="nearest", wrap="clamp") + + geom = gfx.Geometry(grid=gfx.Texture(scalar_data, dim=2)) + mat = gfx.ImageBasicMaterial( + map=cmap_map, + clim=clim, + gamma=use_gamma, + interpolation="nearest", + ) + self._imageObj = gfx.Image(geom, mat) + + # Position and scale + ox, oy = origin + sx, sy = scale + self._imageObj.local.position = (ox, oy, 0) + self._imageObj.local.scale = (sx, sy, 1) + + self.group.add(self._imageObj) + + # Store info for dynamic updates (clim/LUT changes without re-upload) + self._gpuColormapInfo = { + "material": mat, + "lut_texture": lut_tex, + "normalization": normalization, + "vmin": vmin, + "vmax": vmax, + } + + +class _PygfxTrianglesItem: + """Manages pygfx scene objects for triangles.""" + + def __init__(self, x, y, triangles, color, alpha): + self.group = gfx.Group() + self.yaxis = "left" + + x = numpy.asarray(x, dtype=numpy.float32) + y = numpy.asarray(y, dtype=numpy.float32) + triangles = numpy.asarray(triangles, dtype=numpy.int32) + + self._x = x + self._y = y + self._triangles = triangles + + positions = numpy.zeros((len(x), 3), dtype=numpy.float32) + positions[:, 0] = x + positions[:, 1] = y + + color = numpy.asarray(color, dtype=numpy.float32) + if color.ndim == 2: + if color.shape[1] == 3: + color = numpy.column_stack( + [color, numpy.full(len(color), alpha, dtype=numpy.float32)] + ) + geom = gfx.Geometry(positions=positions, indices=triangles, colors=color) + mat = gfx.MeshBasicMaterial(color_mode="vertex", side="both") + else: + rgba = colors.rgba(color) + geom = gfx.Geometry(positions=positions, indices=triangles) + mat = gfx.MeshBasicMaterial( + color=gfx.Color(rgba[0], rgba[1], rgba[2], rgba[3] * alpha), + side="both", + ) + + self._meshObj = gfx.Mesh(geom, mat) + self.group.add(self._meshObj) + + +class _PygfxShapeItem(dict): + """Manages pygfx scene objects for shapes.""" + + def __init__( + self, + x, + y, + shape, + color, + fill, + overlay, + linewidth, + linestyle, + gapcolor, + ): + super().__init__() + + if shape not in ("polygon", "rectangle", "line", "vline", "hline", "polylines"): + raise NotImplementedError(f"Unsupported shape {shape}") + + x = numpy.asarray(x, dtype=numpy.float32) + y = numpy.asarray(y, dtype=numpy.float32) + + if shape == "rectangle": + xMin, xMax = x + x = numpy.array((xMin, xMin, xMax, xMax), dtype=numpy.float32) + yMin, yMax = y + y = numpy.array((yMin, yMax, yMax, yMin), dtype=numpy.float32) + + fill = fill if shape != "polylines" else False + + rgba = colors.rgba(color) + dashPattern = _lineStyleToDashPattern(linestyle) + + self.update( + { + "shape": shape, + "color": rgba, + "fill": fill, + "x": x, + "y": y, + "linewidth": linewidth, + "overlay": overlay, + } + ) + + self.group = gfx.Group() + + gfxColor = gfx.Color(*rgba) + + # Build outline + if shape in ("polygon", "rectangle"): + positions = numpy.zeros((len(x) + 1, 3), dtype=numpy.float32) + positions[:-1, 0] = x + positions[:-1, 1] = y + positions[-1, 0] = x[0] + positions[-1, 1] = y[0] + elif shape == "polylines": + positions = numpy.zeros((len(x), 3), dtype=numpy.float32) + positions[:, 0] = x + positions[:, 1] = y + elif shape in ("line", "hline", "vline"): + positions = numpy.zeros((len(x), 3), dtype=numpy.float32) + positions[:, 0] = x + positions[:, 1] = y + else: + positions = numpy.zeros((len(x), 3), dtype=numpy.float32) + positions[:, 0] = x + positions[:, 1] = y + + if len(positions) >= 2: + # Gap color line: solid line behind the dashed foreground line. + # Must be at a lower z to pass the strict '<' depth test. + if gapcolor is not None and dashPattern: + gapPositions = positions.copy() + gapPositions[:, 2] = -0.1 # slightly behind + gapRgba = colors.rgba(gapcolor) + gapMat = gfx.LineMaterial( + thickness=max(linewidth, 1.0), + color=gfx.Color(*gapRgba), + ) + gapLineObj = gfx.Line(gfx.Geometry(positions=gapPositions), gapMat) + self.group.add(gapLineObj) + + # Foreground line (dashed or solid) at z=0 (in front of gap line) + geom = gfx.Geometry(positions=positions) + mat = gfx.LineMaterial( + thickness=max(linewidth, 1.0), + color=gfxColor, + dash_pattern=dashPattern if dashPattern else (), + ) + lineObj = gfx.Line(geom, mat) + self.group.add(lineObj) + + # Build fill for closed shapes + if fill and shape in ("polygon", "rectangle") and len(x) >= 3: + fillObj = self._buildPolygonFill(x, y, rgba) + if fillObj is not None: + fillObj.local.z = -0.2 # behind lines + self.group.add(fillObj) + + @staticmethod + def _buildPolygonFill(x, y, rgba): + """Create a semi-transparent polygon fill using a triangle fan mesh.""" + n = len(x) + if n < 3: + return None + + # Sort vertices by angle from centroid to avoid bowtie patterns + cx, cy = numpy.nanmean(x), numpy.nanmean(y) + angles = numpy.arctan2(y - cy, x - cx) + order = numpy.argsort(angles) + x = x[order] + y = y[order] + + # Triangle fan from vertex 0 + positions = numpy.zeros((n, 3), dtype=numpy.float32) + positions[:, 0] = x + positions[:, 1] = y + + indices = numpy.zeros(((n - 2), 3), dtype=numpy.uint32) + for i in range(n - 2): + indices[i] = [0, i + 1, i + 2] + + fillColor = gfx.Color(rgba[0], rgba[1], rgba[2], 0.3) + geom = gfx.Geometry(indices=indices, positions=positions) + mat = gfx.MeshBasicMaterial( + color=fillColor, + side="both", + depth_write=False, + ) + return gfx.Mesh(geom, mat) + + +class _PygfxMarkerItem(dict): + """Manages pygfx scene objects for markers.""" + + def __init__( + self, + x, + y, + text, + color, + symbol, + symbolsize, + linewidth, + linestyle, + constraint, + yaxis, + font, + bgcolor, + ): + super().__init__() + + if symbol is None: + symbol = "+" + + # Apply constraint + isConstraint = constraint is not None and x is not None and y is not None + if isConstraint: + x, y = constraint(x, y) + + dashPattern = _lineStyleToDashPattern(linestyle) + + self.update( + { + "x": x, + "y": y, + "text": text, + "color": colors.rgba(color), + "constraint": constraint if isConstraint else None, + "symbol": symbol, + "symbolsize": symbolsize, + "linewidth": linewidth, + "linestyle": linestyle, + "dashpattern": dashPattern, + "yaxis": yaxis, + "font": font, + "bgcolor": bgcolor, + } + ) + + self.group = gfx.Group() + self._lineObj = None + self._textObj = None + rgba = colors.rgba(color) + gfxColor = gfx.Color(*rgba) + + if x is not None and y is not None: + # Point marker + positions = numpy.array([[x, y, 0]], dtype=numpy.float32) + markerShape = _SYMBOL_MAP.get(symbol, "plus") + geom = gfx.Geometry(positions=positions) + mat = gfx.PointsMarkerMaterial( + marker=markerShape, + size=symbolsize, + color=gfxColor, + edge_width=1.0, + edge_color=gfxColor, + ) + self._pointsObj = gfx.Points(geom, mat) + self.group.add(self._pointsObj) + + +# BackendPygfx ################################################################ + + +class BackendPygfx(BackendBase.BackendBase, QRenderWidget): + """pygfx/WGPU-based Plot backend. + + Uses pygfx for GPU-accelerated rendering via WGPU (Vulkan/Metal/DX12). + """ + + _TEXT_MARKER_PADDING = 4 + VSYNC = True + """Enable VSync (default True). Set to False before creating the plot + to unlock frame rates beyond the monitor refresh rate.""" + + PRESENT_METHOD = "screen" + """Present method for rendering. "screen" uses direct GPU rendering + (~3x faster), "image" uses CPU readback (works with remote desktops). + Set before creating the plot.""" + + def __init__(self, plot, parent=None): + QRenderWidget.__init__( + self, + parent=parent, + present_method=self.PRESENT_METHOD, + vsync=self.VSYNC, + ) + BackendBase.BackendBase.__init__(self, plot, parent) + + # Match OpenGLWidget: a layout is needed for Qt to respect sizeHint + layout = qt.QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + + # Accept mouse events without requiring focus first (match OpenGL backend) + self.setFocusPolicy(qt.Qt.NoFocus) + + # Raise max FPS for responsive interaction (zoom, pan, drag) + self.set_update_mode("ondemand", max_fps=240) + + self._defaultFont = None + + self._backgroundColor = (1.0, 1.0, 1.0, 1.0) + self._dataBackgroundColor = (1.0, 1.0, 1.0, 1.0) + + self._keepDataAspectRatio = False + self._crosshairCursor = None + self._mousePosInPixels = None + + # pygfx rendering objects + self._renderer = gfx.WgpuRenderer(self, pixel_ratio=4) + self._scene = gfx.Scene() + + # Camera: orthographic for 2D plotting + self._camera = gfx.OrthographicCamera(640, 480, maintain_aspect=False) + + # Scene hierarchy + self._bgGroup = gfx.Group() + self._dataGroup = gfx.Group() + self._overlayGroup = gfx.Group() + self._frameGroup = gfx.Group() + + # Shift overlays forward in z so they always render in front of data. + # Camera z-range is wide (near=-100..far=100), so z=10 is safe. + self._overlayGroup.local.z = 10 + + self._scene.add(self._bgGroup) + self._scene.add(self._dataGroup) + self._scene.add(self._overlayGroup) + self._scene.add(self._frameGroup) + + # PlotFrame for coordinate transforms + self._plotFrame = PlotFrame2DCore( + foregroundColor=(0.0, 0.0, 0.0, 1.0), + gridColor=(0.7, 0.7, 0.7, 1.0), + marginRatios=(0.15, 0.1, 0.1, 0.15), + font=self._getDefaultFont(), + ) + self._plotFrame.size = ( + int(self.getDevicePixelRatio() * 640), + int(self.getDevicePixelRatio() * 480), + ) + + # Screen-space scene for frame/axes rendering (PR 9) + self._screenScene = gfx.Scene() + self._screenBg = gfx.Background( + None, gfx.BackgroundMaterial(gfx.Color(1, 1, 1, 1)) + ) + self._screenScene.add(self._screenBg) + self._screenFrameGroup = gfx.Group() + self._screenScene.add(self._screenFrameGroup) + self._screenCamera = gfx.OrthographicCamera(maintain_aspect=False) + self._cachedBgColor = (1.0, 1.0, 1.0, 1.0) + + # Frame rendering objects (populated by _updateFrame) + self._frameLines = None + self._gridLines = None + self._frameTexts = [] + self._titleText = None + + # Crosshair cursor lines + self._crosshairHLine = None + self._crosshairVLine = None + + self._reusableImageItem = None # Pool for image item reuse + + # GPU compute helper (lazy singleton) + self._gpuCompute = None + + # Async compute for streaming (non-blocking stats/histogram) + self._asyncCompute = None # Lazy init + + self.request_draw(self._draw) + self.setAutoFillBackground(False) + self.setMouseTracking(True) + + def _getGpuCompute(self): + """Get or create the GPU compute helper (lazy initialization).""" + if self._gpuCompute is None: + self._gpuCompute = _WgpuComputeHelper.get() + return self._gpuCompute + + def _getAsyncCompute(self): + """Get or create the async compute helper.""" + if self._asyncCompute is None: + self._asyncCompute = _AsyncCompute(gpu_compute=self._getGpuCompute()) + return self._asyncCompute + + def _computeGpuDataStats(self, data): + """Submit data for async stats computation and return latest result. + + Called from ColormapMixIn._setColormappedData() to pre-fill the + autoscale range cache. Non-blocking: submits work to background + thread and returns the most recent completed result. + + :param data: numpy array + :returns: (min, minPositive, max) or None + """ + if data is None: + return None + ac = self._getAsyncCompute() + ac.submit_stats(data) + return ac.get_stats() + + def _computeGpuHistogram(self, data, data_min, data_max, num_bins=256, norm_mode=0): + """Submit data for async histogram computation and return latest result. + + Non-blocking: submits work to background thread and returns + the most recent completed histogram. + + :param data: numpy array + :param data_min: histogram lower bound (in normalized space) + :param data_max: histogram upper bound (in normalized space) + :param num_bins: number of bins + :param norm_mode: 0=linear, 1=log10, 2=sqrt, 3=arcsinh + :returns: (counts, bin_edges) or None + """ + ac = self._getAsyncCompute() + ac.submit_histogram(data, data_min, data_max, num_bins, norm_mode) + return ac.get_histogram() + + def _getDefaultFont(self): + if self._defaultFont is None: + app = qt.QApplication.instance() + if app is not None: + self._defaultFont = app.font() + else: + self._defaultFont = qt.QFont() + return self._defaultFont + + def getDevicePixelRatio(self): + return self.devicePixelRatioF() + + def getDotsPerInch(self): + screen = self.screen() + if screen is not None: + return screen.logicalDotsPerInch() * self.getDevicePixelRatio() + return 92 + + # Drawing ############################################################### + + def _draw(self): + plot = self._plotRef() + if plot is None: + return + + with plot._paintContext(): + self._syncPlotFrame() + self._syncCamera() + self._updateFrame() + self._updateMarkers() + self._updateCrosshair() + + # First pass: render frame (background + axes) in full widget + self._renderer.render(self._screenScene, self._screenCamera, flush=False) + + # Second pass: render data scene in plot area viewport only + dpr = self.getDevicePixelRatio() + left, top = self._plotFrame.plotOrigin + pw, ph = self._plotFrame.plotSize + # Viewport rect is in logical pixels + plotRect = (left / dpr, top / dpr, pw / dpr, ph / dpr) + self._renderer.render( + self._scene, self._camera, rect=plotRect, flush=True, clear=False + ) + + def _syncPlotFrame(self): + """Sync plot frame size with widget size.""" + dpr = self.getDevicePixelRatio() + w = int(self.width() * dpr) + h = int(self.height() * dpr) + if (w, h) != self._plotFrame.size: + self._plotFrame.size = (w, h) + self._plotFrame.devicePixelRatio = dpr + self._plotFrame.dotsPerInch = self.getDotsPerInch() + + def _syncCamera(self): + """Update camera to match the current data ranges.""" + trRanges = self._plotFrame.transformedDataRanges + xMin, xMax = trRanges.x + yMin, yMax = trRanges.y + + if self._plotFrame.isXAxisInverted: + xMin, xMax = xMax, xMin + if self._plotFrame.isYAxisInverted: + yMin, yMax = yMax, yMin + + # Ensure non-zero extent to avoid camera errors + if xMin == xMax: + xMin -= 0.5 + xMax += 0.5 + if yMin == yMax: + yMin -= 0.5 + yMax += 0.5 + + # show_rect(left, right, top, bottom) + # height = bottom - top; positive height means Y increases upward + # top=yMin, bottom=yMax → yMax at top of viewport, yMin at bottom + extent = max(abs(xMax - xMin), abs(yMax - yMin), 1.0) + self._camera.show_rect(xMin, xMax, yMin, yMax, depth=extent) + + # Populate projection matrix caches so isDirty returns False + # (pygfx doesn't use OpenGL projection matrices, but isDirty checks them) + _ = self._plotFrame.transformedDataProjMat + _ = self._plotFrame.transformedDataY2ProjMat + + def _updateFrame(self): + """Update axes, ticks, grid, labels in screen space.""" + # Update background color only when changed + bgColor = self._backgroundColor + if self._cachedBgColor != bgColor: + self._screenBg.material = gfx.BackgroundMaterial(gfx.Color(*bgColor)) + self._cachedBgColor = bgColor + + if not self._plotFrame.isDirty: + return # Frame unchanged, keep cached objects + + # Clear previous frame objects (frame group only, not markers/crosshair) + for child in list(self._screenFrameGroup.children): + self._screenFrameGroup.remove(child) + + if self._plotFrame.margins == self._plotFrame._NoDisplayMargins: + return + + w, h = self._plotFrame.size + if w <= 2 or h <= 2: + return + + # Set screen camera to pixel coordinates (Y=0 at top, Y=h at bottom) + # show_rect(left, right, top, bottom): + # PlotFrameCore uses Y=0=top, Y=h=bottom (pixel convention) + # In pygfx: height = bottom - top, so top=h, bottom=0 flips Y axis + extent = max(w, h, 1.0) + self._screenCamera.show_rect(0, w, h, 0, depth=extent) + + # Build vertices and labels from the core + vertices, gridVertices, labelDicts = self._plotFrame._buildVerticesAndLabels() + self._plotFrame._clearDirty() + + # Render grid lines + if len(gridVertices) >= 2: + gridColor = gfx.Color(*self._plotFrame.gridColor) + geom = gfx.Geometry( + positions=numpy.column_stack( + [gridVertices, numpy.zeros(len(gridVertices), dtype=numpy.float32)] + ) + ) + mat = gfx.LineSegmentMaterial(thickness=1.0, color=gridColor) + gridLine = gfx.Line(geom, mat) + self._screenFrameGroup.add(gridLine) + + # Render frame lines (axes) + if len(vertices) >= 2: + fgColor = gfx.Color(*self._plotFrame.foregroundColor) + geom = gfx.Geometry( + positions=numpy.column_stack( + [vertices, numpy.zeros(len(vertices), dtype=numpy.float32)] + ) + ) + mat = gfx.LineSegmentMaterial(thickness=1.0, color=fgColor) + frameLine = gfx.Line(geom, mat) + self._screenFrameGroup.add(frameLine) + + # Render text labels (tick labels, axis titles, main title) + for labelDict in labelDicts: + text = labelDict.get("text", "") + if not text: + continue + # Strip matplotlib LaTeX formatting + text = _stripMathDefault(text) + + lx = labelDict["x"] + ly = labelDict["y"] + labelColor = labelDict.get("color", (0, 0, 0, 1)) + rotate = labelDict.get("rotate", 0) + + # Map alignment strings to pygfx anchor + align = labelDict.get("align", "center") + valign = labelDict.get("valign", "center") + anchor = self._mapAnchor(align, valign) + + fontSize = 12.0 + font = labelDict.get("font") + if font is not None: + ps = font.pointSizeF() + if ps > 0: + fontSize = ps + else: + px = font.pixelSize() + if px > 0: + fontSize = px * 72.0 / self.getDotsPerInch() + + textObj = gfx.Text( + text=text, + material=gfx.TextMaterial(color=gfx.Color(*labelColor)), + font_size=fontSize, + anchor=anchor, + screen_space=True, + ) + textObj.local.position = (lx, ly, 0) + + if rotate: + import pylinalg as la + + # Negate angle because screen camera flips Y + textObj.local.rotation = la.quat_from_axis_angle( + (0, 0, 1), math.radians(-rotate) + ) + + self._screenFrameGroup.add(textObj) + + def _updateMarkers(self): + """Update marker lines and text labels in screen space.""" + plot = self._plotRef() + if plot is None: + return + + pixelOffset = 3 + + for plotItem in self.getItemsFromBackToFront(condition=lambda i: i.isVisible()): + if plotItem._backendRenderer is None: + continue + item = plotItem._backendRenderer + if not isinstance(item, _PygfxMarkerItem): + continue + + xCoord = item["x"] + yCoord = item["y"] + yAxis = item.get("yaxis", "left") + color = item["color"] + linewidth = item["linewidth"] + dashPattern = item["dashpattern"] + + # Remove old line and text from the screen scene + if item._lineObj is not None: + if item._lineObj.parent is not None: + item._lineObj.parent.remove(item._lineObj) + item._lineObj = None + if item._textObj is not None: + if item._textObj.parent is not None: + item._textObj.parent.remove(item._textObj) + item._textObj = None + + gfxColor = gfx.Color(*color) + + if xCoord is None or yCoord is None: + # hline or vline marker — render in screen space + if xCoord is None: + # Horizontal line at y + pixelPos = self._plotFrame.dataToPixel( + 0.5 * sum(self._plotFrame.dataRanges[0]), + yCoord, + axis=yAxis, + ) + if pixelPos is None: + continue + left = self._plotFrame.margins.left + right = self._plotFrame.size[0] - self._plotFrame.margins.right + positions = numpy.array( + [ + [left, pixelPos[1], 0], + [right, pixelPos[1], 0], + ], + dtype=numpy.float32, + ) + + if item["text"] is not None: + tx = right - pixelOffset + ty = pixelPos[1] - pixelOffset + textObj = gfx.Text( + material=gfx.TextMaterial(color=gfxColor), + text=item["text"], + font_size=self._getDefaultFont().pointSizeF() or 10, + anchor="bottom-right", + screen_space=True, + ) + textObj.local.position = (tx, ty, 0) + item._textObj = textObj + self._screenScene.add(textObj) + else: + # Vertical line at x + yRange = self._plotFrame.dataRanges[1 if yAxis == "left" else 2] + pixelPos = self._plotFrame.dataToPixel( + xCoord, + 0.5 * sum(yRange), + axis=yAxis, + ) + if pixelPos is None: + continue + top = self._plotFrame.margins.top + bottom = self._plotFrame.size[1] - self._plotFrame.margins.bottom + positions = numpy.array( + [ + [pixelPos[0], top, 0], + [pixelPos[0], bottom, 0], + ], + dtype=numpy.float32, + ) + + if item["text"] is not None: + tx = pixelPos[0] + pixelOffset + ty = top + pixelOffset + textObj = gfx.Text( + material=gfx.TextMaterial(color=gfxColor), + text=item["text"], + font_size=self._getDefaultFont().pointSizeF() or 10, + anchor="top-left", + screen_space=True, + ) + textObj.local.position = (tx, ty, 0) + item._textObj = textObj + self._screenScene.add(textObj) + + geom = gfx.Geometry(positions=positions) + mat = gfx.LineMaterial( + thickness=max(linewidth, 1.0), + color=gfxColor, + dash_pattern=dashPattern if dashPattern else (), + ) + item._lineObj = gfx.Line(geom, mat) + self._screenScene.add(item._lineObj) + + else: + # Point marker — text label in screen space + if item["text"] is not None: + pixelPos = self._plotFrame.dataToPixel( + xCoord, + yCoord, + axis=yAxis, + ) + if pixelPos is None: + continue + tx = pixelPos[0] + pixelOffset + ty = pixelPos[1] + pixelOffset + textObj = gfx.Text( + material=gfx.TextMaterial(color=gfxColor), + text=item["text"], + font_size=self._getDefaultFont().pointSizeF() or 10, + anchor="top-left", + screen_space=True, + ) + textObj.local.position = (tx, ty, 0) + item._textObj = textObj + self._screenScene.add(textObj) + + def _updateCrosshair(self): + """Update crosshair cursor lines.""" + # Remove old crosshair + if self._crosshairHLine is not None: + if self._crosshairHLine in self._screenScene.children: + self._screenScene.remove(self._crosshairHLine) + self._crosshairHLine = None + if self._crosshairVLine is not None: + if self._crosshairVLine in self._screenScene.children: + self._screenScene.remove(self._crosshairVLine) + self._crosshairVLine = None + + if self._crosshairCursor is None or self._mousePosInPixels is None: + return + + color, linewidth = self._crosshairCursor + gfxColor = gfx.Color(*color) + mx, my = self._mousePosInPixels + + w, h = self._plotFrame.size + left, top = self._plotFrame.plotOrigin + pw, ph = self._plotFrame.plotSize + + # Horizontal line + hPositions = numpy.array( + [ + [left, my, 0], + [left + pw, my, 0], + ], + dtype=numpy.float32, + ) + hGeom = gfx.Geometry(positions=hPositions) + hMat = gfx.LineMaterial(thickness=linewidth, color=gfxColor) + self._crosshairHLine = gfx.Line(hGeom, hMat) + self._screenScene.add(self._crosshairHLine) + + # Vertical line + vPositions = numpy.array( + [ + [mx, top, 0], + [mx, top + ph, 0], + ], + dtype=numpy.float32, + ) + vGeom = gfx.Geometry(positions=vPositions) + vMat = gfx.LineMaterial(thickness=linewidth, color=gfxColor) + self._crosshairVLine = gfx.Line(vGeom, vMat) + self._screenScene.add(self._crosshairVLine) + + @staticmethod + def _mapAnchor(align, valign): + """Map silx align/valign strings to pygfx anchor string.""" + vmap = {"top": "top", "bottom": "bottom", "center": "middle"} + hmap = {"left": "left", "right": "right", "center": "center"} + v = vmap.get(str(valign), "middle") + h = hmap.get(str(align), "center") + return f"{v}-{h}" + + # QWidget events ######################################################## + + _MOUSE_BTNS = { + qt.Qt.LeftButton: "left", + qt.Qt.RightButton: "right", + qt.Qt.MiddleButton: "middle", + } + + def sizeHint(self): + return qt.QSize(8 * 80, 6 * 80) + + def minimumSizeHint(self): + return qt.QSize(0, 0) + + def enterEvent(self, event): + # WA_NativeWindow (from screen present mode) requires OS-level focus. + # Activate the top-level window when the mouse enters so that + # mouse events and cursor changes work without an extra click. + topLevel = self.window() + if topLevel is not None: + topLevel.activateWindow() + super().enterEvent(event) + + def mousePressEvent(self, event): + if event.button() not in self._MOUSE_BTNS: + return super().mousePressEvent(event) + x, y = qt.getMouseEventPosition(event) + self._plot.onMousePress(x, y, self._MOUSE_BTNS[event.button()]) + event.accept() + + def mouseMoveEvent(self, event): + qtPos = qt.getMouseEventPosition(event) + + previousMousePosInPixels = self._mousePosInPixels + if qtPos == self._mouseInPlotArea(*qtPos): + dpr = self.getDevicePixelRatio() + devicePos = qtPos[0] * dpr, qtPos[1] * dpr + self._mousePosInPixels = devicePos + else: + self._mousePosInPixels = None + + if ( + self._crosshairCursor is not None + and previousMousePosInPixels != self._mousePosInPixels + ): + self._plot._setDirtyPlot(overlayOnly=True) + + self._plot.onMouseMove(*qtPos) + event.accept() + + def mouseReleaseEvent(self, event): + if event.button() not in self._MOUSE_BTNS: + return super().mouseReleaseEvent(event) + x, y = qt.getMouseEventPosition(event) + self._plot.onMouseRelease(x, y, self._MOUSE_BTNS[event.button()]) + event.accept() + + def wheelEvent(self, event): + delta = event.angleDelta().y() + angleInDegrees = delta / 8.0 + x, y = qt.getMouseEventPosition(event) + self._plot.onMouseWheel(x, y, angleInDegrees) + event.accept() + + def leaveEvent(self, _): + self._plot.onMouseLeaveWidget() + + def resizeEvent(self, event): + super().resizeEvent(event) + w, h = self.width(), self.height() + if w == 0 or h == 0: + return + dpr = self.getDevicePixelRatio() + self._plotFrame.size = (int(w * dpr), int(h * dpr)) + + # Store current ranges + previousXRange = self.getGraphXLimits() + previousYRange = self.getGraphYLimits(axis="left") + previousYRightRange = self.getGraphYLimits(axis="right") + + # Re-apply current data ranges to the new size (same as OpenGL backend) + (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self._plotFrame.dataRanges + self.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + + # If plot range has changed, then emit signal + if previousXRange != self.getGraphXLimits(): + self._plot.getXAxis()._emitLimitsChanged() + if previousYRange != self.getGraphYLimits(axis="left"): + self._plot.getYAxis(axis="left")._emitLimitsChanged() + if previousYRightRange != self.getGraphYLimits(axis="right"): + self._plot.getYAxis(axis="right")._emitLimitsChanged() + + # Backend API: Log transform helpers ##################################### + + def _logTransformX(self, x): + """Apply log10 if X axis is log scale.""" + if not self._plotFrame.xAxis.isLog: + return x + x = numpy.array(x, copy=True, dtype=numpy.float64) + mask = x < FLOAT32_MINPOS + x[mask] = numpy.nan + with numpy.errstate(divide="ignore"): + return numpy.log10(x).astype(numpy.float32) + + def _logTransformY(self, y, yaxis="left"): + """Apply log10 if Y axis is log scale.""" + isLog = ( + self._plotFrame.yAxis.isLog + if yaxis == "left" + else self._plotFrame.y2Axis.isLog + ) + if not isLog: + return y + y = numpy.array(y, copy=True, dtype=numpy.float64) + mask = y < FLOAT32_MINPOS + y[mask] = numpy.nan + with numpy.errstate(divide="ignore"): + return numpy.log10(y).astype(numpy.float32) + + # Backend API: Add methods ############################################## + + def addCurve( + self, + x, + y, + color, + gapcolor, + symbol, + linewidth, + linestyle, + yaxis, + xerror, + yerror, + fill, + alpha, + symbolsize, + baseline, + ): + x = numpy.asarray(x, dtype=numpy.float64) + y = numpy.asarray(y, dtype=numpy.float64) + + # Log transform errors before coordinates + if self._plotFrame.xAxis.isLog and xerror is not None: + xerror = numpy.asarray(xerror, dtype=numpy.float32) + logX = numpy.log10(x) + if xerror.ndim == 2: + xErrMinus, xErrPlus = xerror[0], xerror[1] + else: + xErrMinus, xErrPlus = xerror, xerror + with numpy.errstate(divide="ignore", invalid="ignore"): + xErrMinus = logX - numpy.log10(x - xErrMinus) + xErrPlus = numpy.log10(x + xErrPlus) - logX + xerror = numpy.array((xErrMinus, xErrPlus), dtype=numpy.float32) + + isYLog = (yaxis == "left" and self._plotFrame.yAxis.isLog) or ( + yaxis == "right" and self._plotFrame.y2Axis.isLog + ) + if isYLog and yerror is not None: + yerror = numpy.asarray(yerror, dtype=numpy.float32) + logY = numpy.log10(y) + if yerror.ndim == 2: + yErrMinus, yErrPlus = yerror[0], yerror[1] + else: + yErrMinus, yErrPlus = yerror, yerror + with numpy.errstate(divide="ignore", invalid="ignore"): + yErrMinus = logY - numpy.log10(y - yErrMinus) + yErrPlus = numpy.log10(y + yErrPlus) - logY + yerror = numpy.array((yErrMinus, yErrPlus), dtype=numpy.float32) + + x = self._logTransformX(x) + y = self._logTransformY(y, yaxis) + + if baseline is not None and isYLog: + if isinstance(baseline, numpy.ndarray): + baseline = self._logTransformY(baseline, yaxis) + else: + bl = float(baseline) + if bl > 0: + baseline = math.log10(bl) + else: + baseline = numpy.nan + + item = _PygfxCurveItem( + x, + y, + color, + gapcolor, + symbol, + linewidth, + linestyle, + yaxis, + xerror, + yerror, + fill, + alpha, + symbolsize, + baseline, + ) + self._dataGroup.add(item.group) + return item + + def addImage(self, data, origin, scale, colormap, alpha): + data = numpy.asarray(data) + ox, oy = origin + sx, sy = scale + h, w = data.shape[:2] + + if self._plotFrame.xAxis.isLog: + xMin = ox + xMax = ox + w * sx + if xMin > 0 and xMax > 0: + logXMin = math.log10(xMin) + logXMax = math.log10(xMax) + ox = logXMin + sx = (logXMax - logXMin) / w + + if self._plotFrame.yAxis.isLog: + yMin = oy + yMax = oy + h * sy + if yMin > 0 and yMax > 0: + logYMin = math.log10(yMin) + logYMax = math.log10(yMax) + oy = logYMin + sy = (logYMax - logYMin) / h + + # Reuse pooled item if shape matches (avoids GPU object recreation) + reuse = self._reusableImageItem + if reuse is not None and data.ndim == 2 and reuse._scalarShape == data.shape: + self._reusableImageItem = None + reuse._build(data, (ox, oy), (sx, sy), colormap, alpha) + self._dataGroup.add(reuse.group) + return reuse + + self._reusableImageItem = None + item = _PygfxImageItem(data, (ox, oy), (sx, sy), colormap, alpha) + self._dataGroup.add(item.group) + return item + + def addTriangles(self, x, y, triangles, color, alpha): + x = self._logTransformX(numpy.asarray(x, dtype=numpy.float64)) + y = self._logTransformY(numpy.asarray(y, dtype=numpy.float64)) + item = _PygfxTrianglesItem(x, y, triangles, color, alpha) + self._dataGroup.add(item.group) + return item + + def addShape( + self, + x, + y, + shape, + color, + fill, + overlay, + linestyle, + linewidth, + gapcolor, + ): + x = self._logTransformX(numpy.asarray(x, dtype=numpy.float64)) + y = self._logTransformY(numpy.asarray(y, dtype=numpy.float64)) + # Ensure overlay outlines (e.g. zoom selection) are clearly visible + if overlay and linewidth < 2.0: + linewidth = 2.0 + item = _PygfxShapeItem( + x, + y, + shape, + color, + fill, + overlay, + linewidth, + linestyle, + gapcolor, + ) + if overlay: + self._overlayGroup.add(item.group) + else: + self._dataGroup.add(item.group) + return item + + def addMarker( + self, + x: float | None, + y: float | None, + text: str | None, + color: str, + symbol: str | None, + symbolsize: float, + linestyle: str | tuple[float, tuple[float, ...] | None], + linewidth: float, + constraint, + yaxis: str, + font: qt.QFont, + bgcolor: RGBAColorType | None, + ) -> object: + # Log transform marker coordinates + if x is not None and self._plotFrame.xAxis.isLog: + x = math.log10(x) if x > 0 else numpy.nan + if y is not None: + isYLog = ( + self._plotFrame.yAxis.isLog + if yaxis == "left" + else self._plotFrame.y2Axis.isLog + ) + if isYLog: + y = math.log10(y) if y > 0 else numpy.nan + + item = _PygfxMarkerItem( + x, + y, + text, + color, + symbol, + symbolsize, + linewidth, + linestyle, + constraint, + yaxis, + font, + bgcolor, + ) + + self._overlayGroup.add(item.group) + return item + + # Backend API: Remove #################################################### + + def remove(self, item): + if hasattr(item, "group"): + # Check Y2 axis visibility + if hasattr(item, "yaxis") and item.yaxis == "right": + y2AxisItems = ( + i + for i in self._plot.getItems() + if isinstance(i, items.YAxisMixIn) and i.getYAxis() == "right" + ) + self._plotFrame.isY2Axis = next(y2AxisItems, None) is not None + + # Pool scalar image items for reuse (avoids GPU object recreation) + if isinstance(item, _PygfxImageItem) and item._scalarShape is not None: + self._reusableImageItem = item + + group = item.group + if group.parent is not None: + group.parent.remove(group) + + # Backend API: Interaction ############################################### + + _QT_CURSORS = { + BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor, + BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor, + BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor, + BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor, + BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor, + } + + def setGraphCursorShape(self, cursor): + if cursor is None: + super().unsetCursor() + else: + cursor = self._QT_CURSORS[cursor] + super().setCursor(qt.QCursor(cursor)) + + def setGraphCursor(self, flag, color, linewidth, linestyle): + if flag: + color = colors.rgba(color) + crosshairCursor = color, linewidth + else: + crosshairCursor = None + + if crosshairCursor != self._crosshairCursor: + self._crosshairCursor = crosshairCursor + + _PICK_OFFSET = 3 + + def _mouseInPlotArea(self, x, y): + """Returns closest visible position in the plot.""" + left, top, width, height = self.getPlotBoundsInPixels() + return ( + numpy.clip(x, left, left + width - 1), + numpy.clip(y, top, top + height - 1), + ) + + def pickItem(self, x, y, item): + dataPos = self._plot.pixelToData(x, y, axis="left", check=True) + if dataPos is None: + return None + + if item is None: + _logger.error("No item provided for picking") + return None + + # Pick markers + if isinstance(item, _PygfxMarkerItem): + yaxis = item["yaxis"] + pixelPos = self._plot.dataToPixel( + item["x"], item["y"], axis=yaxis, check=False + ) + if pixelPos is None: + return None + + if item["x"] is None: # Horizontal line + pt1 = self._plot.pixelToData( + x, y - self._PICK_OFFSET, axis=yaxis, check=False + ) + pt2 = self._plot.pixelToData( + x, y + self._PICK_OFFSET, axis=yaxis, check=False + ) + isPicked = min(pt1[1], pt2[1]) <= item["y"] <= max(pt1[1], pt2[1]) + + elif item["y"] is None: # Vertical line + pt1 = self._plot.pixelToData( + x - self._PICK_OFFSET, y, axis=yaxis, check=False + ) + pt2 = self._plot.pixelToData( + x + self._PICK_OFFSET, y, axis=yaxis, check=False + ) + isPicked = min(pt1[0], pt2[0]) <= item["x"] <= max(pt1[0], pt2[0]) + + else: + isPicked = ( + numpy.fabs(x - pixelPos[0]) <= self._PICK_OFFSET + and numpy.fabs(y - pixelPos[1]) <= self._PICK_OFFSET + ) + + return (0,) if isPicked else None + + # Pick curves + if isinstance(item, _PygfxCurveItem): + return self._pickCurve(item, x, y) + + # Pick images + if isinstance(item, _PygfxImageItem): + return self._pickImage(item, dataPos) + + # Pick triangles + if isinstance(item, _PygfxTrianglesItem): + return self._pickTriangles(item, dataPos) + + return None + + def _pickCurve(self, item, x, y): + """Pick a curve item.""" + offset = self._PICK_OFFSET + + inAreaPos = self._mouseInPlotArea(x - offset, y - offset) + dataPos = self._plot.pixelToData( + inAreaPos[0], inAreaPos[1], axis=item.yaxis, check=True + ) + if dataPos is None: + return None + xPick0, yPick0 = dataPos + + inAreaPos = self._mouseInPlotArea(x + offset, y + offset) + dataPos = self._plot.pixelToData( + inAreaPos[0], inAreaPos[1], axis=item.yaxis, check=True + ) + if dataPos is None: + return None + xPick1, yPick1 = dataPos + + xPickMin = min(xPick0, xPick1) + xPickMax = max(xPick0, xPick1) + yPickMin = min(yPick0, yPick1) + yPickMax = max(yPick0, yPick1) + + # Get curve data from the line geometry + if item._lineObj is not None: + positions = item._lineObj.geometry.positions.data + xData = positions[:, 0] + yData = positions[:, 1] + elif item._pointsObj is not None: + positions = item._pointsObj.geometry.positions.data + xData = positions[:, 0] + yData = positions[:, 1] + else: + return None + + # Find points within the pick area + indices = numpy.where( + (xData >= xPickMin) + & (xData <= xPickMax) + & (yData >= yPickMin) + & (yData <= yPickMax) + )[0] + + if len(indices) > 0: + return indices + return None + + def _pickImage(self, item, dataPos): + """Pick an image item.""" + ox, oy = item._origin + sx, sy = item._scale + h, w = item._dataShape + + xMin = ox if sx >= 0 else ox + sx * w + xMax = ox + sx * w if sx >= 0 else ox + yMin = oy if sy >= 0 else oy + sy * h + yMax = oy + sy * h if sy >= 0 else oy + + x, y = dataPos + if x < xMin or x > xMax or y < yMin or y > yMax: + return None + + col = int((x - ox) / sx) if sx != 0 else 0 + row = int((y - oy) / sy) if sy != 0 else 0 + + col = numpy.clip(col, 0, w - 1) + row = numpy.clip(row, 0, h - 1) + + return (row,), (col,) + + def _pickTriangles(self, item, dataPos): + """Pick a triangles item.""" + x, y = dataPos + xPts = item._x + yPts = item._y + triangles = item._triangles + + if len(xPts) == 0 or len(triangles) == 0: + return None + + # Bounding box check + if x < xPts.min() or x > xPts.max() or y < yPts.min() or y > yPts.max(): + return None + + # Build triangle coordinates array (N, 3, 3) for intersection test + triCoords = numpy.zeros((len(triangles), 3, 3), dtype=numpy.float32) + triCoords[:, :, 0] = xPts[triangles] + triCoords[:, :, 1] = yPts[triangles] + + # Create vertical segment through clicked point + segment = numpy.array(((x, y, -1.0), (x, y, 1.0)), dtype=numpy.float32) + + from silx.gui._glutils.utils import segmentTrianglesIntersection + + indices = segmentTrianglesIntersection(segment, triCoords)[0] + if len(indices) == 0: + return None + + # Convert triangle indices to vertex indices + indices = numpy.unique(numpy.ravel(triangles[indices])) + + # Sort from furthest to closest + dists = (xPts[indices] - x) ** 2 + (yPts[indices] - y) ** 2 + indices = indices[numpy.flip(numpy.argsort(dists), axis=0)] + + return tuple(indices) + + # Backend API: Update curve ############################################## + + def setCurveColor(self, curve, color): + pass # TODO + + # Backend API: Widget #################################################### + + def getWidgetHandle(self): + return self + + def paintEvent(self, event): + # Flush dirty items inside the paint event, where GPU operations are + # safe (same pattern as OpenGL's paintGL). This ensures _backendRenderer + # is up-to-date before pick() is called. Qt's update() coalesces + # multiple calls, naturally batching mutations. + plot = self._plotRef() + if plot is not None and plot._getDirtyPlot(): + with plot._paintContext(): + pass + super().paintEvent(event) + + def postRedisplay(self): + self.request_draw(self._draw) + # Schedule a Qt paint event so processEvents() flushes dirty items. + # rendercanvas's request_draw() uses an async scheduler that may not + # fire during processEvents(). Qt's update() coalesces multiple calls, + # naturally batching mutations before the paint event fires. + qt.QWidget.update(self) + + def replot(self): + self.request_draw(self._draw) + qt.QWidget.update(self) + + def saveGraph(self, fileName, fileFormat, dpi): + if dpi is not None: + _logger.warning("saveGraph ignores dpi parameter") + + if fileFormat not in ["png", "ppm", "svg", "tif", "tiff"]: + raise NotImplementedError("Unsupported format: %s" % fileFormat) + + # Force a synchronous render + self._draw() + snapshot = self._renderer.snapshot() + + # snapshot is (H, W, 4) RGBA uint8 + from PIL import Image as PILImage + + img = PILImage.fromarray(snapshot) + if fileFormat in ("tif", "tiff"): + img.save(fileName, format="TIFF") + elif fileFormat == "ppm": + img.convert("RGB").save(fileName, format="PPM") + elif fileFormat == "svg": + raise NotImplementedError("SVG export not supported by pygfx backend") + else: + img.save(fileName, format=fileFormat.upper()) + + # Backend API: Labels #################################################### + + def setGraphTitle(self, title): + self._plotFrame.title = title + + def setGraphXLabel(self, label): + self._plotFrame.xAxis.title = label + + def setGraphYLabel(self, label, axis): + if axis == "left": + self._plotFrame.yAxis.title = label + else: + self._plotFrame.y2Axis.title = label + + # Backend API: Limits #################################################### + + def _setDataRanges(self, xlim=None, ylim=None, y2lim=None): + self._plotFrame.setDataRanges(xlim, ylim, y2lim) + + def _ensureAspectRatio(self, keepDim=None): + plotWidth, plotHeight = self._plotFrame.plotSize + if plotWidth <= 2 or plotHeight <= 2: + return + + if keepDim is None: + ranges = self._plot.getDataRange() + if ( + ranges.y is not None + and ranges.x is not None + and (ranges.y[1] - ranges.y[0]) != 0.0 + ): + dataRatio = (ranges.x[1] - ranges.x[0]) / float( + ranges.y[1] - ranges.y[0] + ) + plotRatio = plotWidth / float(plotHeight) + keepDim = "x" if dataRatio > plotRatio else "y" + else: + keepDim = "x" + + (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self._plotFrame.dataRanges + if keepDim == "y": + dataW = (yMax - yMin) * plotWidth / float(plotHeight) + xCenter = 0.5 * (xMin + xMax) + xMin = xCenter - 0.5 * dataW + xMax = xCenter + 0.5 * dataW + elif keepDim == "x": + dataH = (xMax - xMin) * plotHeight / float(plotWidth) + yCenter = 0.5 * (yMin + yMax) + yMin = yCenter - 0.5 * dataH + yMax = yCenter + 0.5 * dataH + y2Center = 0.5 * (y2Min + y2Max) + y2Min = y2Center - 0.5 * dataH + y2Max = y2Center + 0.5 * dataH + else: + raise RuntimeError("Unsupported dimension to keep: %s" % keepDim) + + self._setDataRanges(xlim=(xMin, xMax), ylim=(yMin, yMax), y2lim=(y2Min, y2Max)) + + def _setPlotBounds(self, xRange=None, yRange=None, y2Range=None, keepDim=None): + self._setDataRanges(xlim=xRange, ylim=yRange, y2lim=y2Range) + if self.isKeepDataAspectRatio(): + self._ensureAspectRatio(keepDim) + + def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None): + if y2min is None or y2max is None: + y2Range = None + else: + y2Range = y2min, y2max + self._setPlotBounds((xmin, xmax), (ymin, ymax), y2Range) + + def getGraphXLimits(self): + return self._plotFrame.dataRanges.x + + def setGraphXLimits(self, xmin, xmax): + self._setPlotBounds(xRange=(xmin, xmax), keepDim="x") + + def getGraphYLimits(self, axis): + assert axis in ("left", "right") + if axis == "left": + return self._plotFrame.dataRanges.y + else: + return self._plotFrame.dataRanges.y2 + + def setGraphYLimits(self, ymin, ymax, axis): + assert axis in ("left", "right") + if axis == "left": + self._setPlotBounds(yRange=(ymin, ymax), keepDim="y") + else: + self._setPlotBounds(y2Range=(ymin, ymax), keepDim="y") + + # Backend API: Axes ###################################################### + + def getXAxisTimeZone(self): + return self._plotFrame.xAxis.timeZone + + def setXAxisTimeZone(self, tz): + self._plotFrame.xAxis.timeZone = tz + + def isXAxisTimeSeries(self): + return self._plotFrame.xAxis.isTimeSeries + + def setXAxisTimeSeries(self, isTimeSeries): + self._plotFrame.xAxis.isTimeSeries = isTimeSeries + + def setXAxisLogarithmic(self, flag): + if flag != self._plotFrame.xAxis.isLog: + if flag and self._keepDataAspectRatio: + _logger.warning("KeepDataAspectRatio is ignored with log axes") + self._plotFrame.xAxis.isLog = flag + + def setYAxisLogarithmic(self, flag): + if flag != self._plotFrame.yAxis.isLog or flag != self._plotFrame.y2Axis.isLog: + if flag and self._keepDataAspectRatio: + _logger.warning("KeepDataAspectRatio is ignored with log axes") + self._plotFrame.yAxis.isLog = flag + self._plotFrame.y2Axis.isLog = flag + + def setYAxisInverted(self, flag: bool): + self._plotFrame.isYAxisInverted = flag + + def isYAxisInverted(self) -> bool: + return self._plotFrame.isYAxisInverted + + def setXAxisInverted(self, flag: bool): + self._plotFrame.isXAxisInverted = flag + + def isXAxisInverted(self) -> bool: + return self._plotFrame.isXAxisInverted + + def isYRightAxisVisible(self): + return self._plotFrame.isY2Axis + + def isKeepDataAspectRatio(self): + if self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog: + return False + return self._keepDataAspectRatio + + def setKeepDataAspectRatio(self, flag): + if flag and (self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog): + _logger.warning("KeepDataAspectRatio is ignored with log axes") + self._keepDataAspectRatio = flag + + def setGraphGrid(self, which): + assert which in (None, "major", "both") + self._plotFrame.grid = which is not None + + # Backend API: Data <-> Pixel ############################################ + + def dataToPixel(self, x, y, axis): + result = self._plotFrame.dataToPixel(x, y, axis) + if result is None: + return None + dpr = self.getDevicePixelRatio() + return tuple(value / dpr for value in result) + + def pixelToData(self, x, y, axis): + dpr = self.getDevicePixelRatio() + return self._plotFrame.pixelToData(x * dpr, y * dpr, axis) + + def getPlotBoundsInPixels(self): + dpr = self.getDevicePixelRatio() + return tuple( + int(value / dpr) + for value in self._plotFrame.plotOrigin + self._plotFrame.plotSize + ) + + # Backend API: Margins & Colors ########################################## + + def setAxesMargins(self, left: float, top: float, right: float, bottom: float): + self._plotFrame.marginRatios = left, top, right, bottom + + def setForegroundColors(self, foregroundColor, gridColor): + self._plotFrame.foregroundColor = foregroundColor + self._plotFrame.gridColor = gridColor + + def setBackgroundColors(self, backgroundColor, dataBackgroundColor): + self._backgroundColor = backgroundColor + self._dataBackgroundColor = dataBackgroundColor + + # Remove old background + if hasattr(self, "_bgObj") and self._bgObj is not None: + if self._bgObj in self._scene.children: + self._scene.remove(self._bgObj) + + # Update data scene background (plot area uses dataBackgroundColor) + if dataBackgroundColor is not None: + bgColor = gfx.Color(*dataBackgroundColor) + self._bgObj = gfx.Background(None, gfx.BackgroundMaterial(bgColor)) + self._scene.add(self._bgObj) + else: + self._bgObj = None diff --git a/src/silx/gui/plot/backends/_PlotFrameCore.py b/src/silx/gui/plot/backends/_PlotFrameCore.py new file mode 100644 index 0000000000..950e508508 --- /dev/null +++ b/src/silx/gui/plot/backends/_PlotFrameCore.py @@ -0,0 +1,1393 @@ +# /*########################################################################## +# +# Copyright (c) 2014-2023 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +Rendering-independent plot frame math and layout. + +Extracted from GLPlotFrame to be shared by OpenGL and pygfx backends. +Provides coordinate transforms, tick generation, margin/layout calculation, +and grid vertex computation without any OpenGL or rendering dependencies. +""" + +from __future__ import annotations + +__authors__ = ["T. Vincent"] +__license__ = "MIT" + +import datetime as dt +import math +import weakref +import logging +import numbers +from collections import namedtuple + +import numpy + +from ... import qt +from ...utils.matplotlib import DefaultTickFormatter +from .._utils import checkAxisLimits, FLOAT32_MINPOS +from .._utils.ticklayout import niceNumbersAdaptative, niceNumbersForLog10 +from .._utils.dtime_ticklayout import ( + DtUnit, + bestUnit, + calcTicksAdaptive, + formatDatetimes, +) +from .._utils.dtime_ticklayout import timestamp + +_logger = logging.getLogger(__name__) + + +# PlotAxisCore ################################################################ + + +class PlotAxisCore: + """Represents a 1D axis of the plot (rendering-independent). + + Provides tick computation, data range management, and label generation + without any GL or rendering dependencies. + """ + + def __init__( + self, + plotFrame, + tickLength=(0.0, 0.0), + foregroundColor=(0.0, 0.0, 0.0, 1.0), + labelAlign="center", + labelVAlign="center", + titleAlign="center", + titleVAlign="center", + orderOffsetAlign="center", + orderOffsetVAlign="center", + titleRotate=0, + titleOffset=(0.0, 0.0), + font: qt.QFont | None = None, + ): + self._tickFormatter = DefaultTickFormatter() + self._ticks = None + self._orderAndOffsetText = "" + + self._plotFrameRef = weakref.ref(plotFrame) + + self._isDateTime = False + self._timeZone = None + self._isLog = False + self._dataRange = 1.0, 100.0 + self._displayCoords = (0.0, 0.0), (1.0, 0.0) + self._title = "" + + self._tickLength = tickLength + self._foregroundColor = foregroundColor + self._labelAlign = labelAlign + self._labelVAlign = labelVAlign + self._orderOffsetAnchor = (1.0, 0.0) + self._orderOffsetAlign = orderOffsetAlign + self._orderOffsetVAlign = orderOffsetVAlign + self._titleAlign = titleAlign + self._titleVAlign = titleVAlign + self._titleRotate = titleRotate + self._titleOffset = titleOffset + self._font = font + + @property + def dataRange(self): + """The range of the data represented on the axis as a tuple + of 2 floats: (min, max).""" + return self._dataRange + + @property + def font(self) -> qt.QFont: + if self._font is None: + return qt.QApplication.instance().font() + return self._font + + @dataRange.setter + def dataRange(self, dataRange): + assert len(dataRange) == 2 + assert dataRange[0] <= dataRange[1] + dataRange = float(dataRange[0]), float(dataRange[1]) + + if dataRange != self._dataRange: + self._dataRange = dataRange + self._dirtyTicks() + + @property + def isLog(self): + """Whether the axis is using a log10 scale or not as a bool.""" + return self._isLog + + @isLog.setter + def isLog(self, isLog): + isLog = bool(isLog) + if isLog != self._isLog: + self._isLog = isLog + self._dirtyTicks() + + @property + def timeZone(self): + """Returns datetime.tzinfo that is used if this axis plots date times.""" + return self._timeZone + + @timeZone.setter + def timeZone(self, tz): + """Sets datetime.tzinfo that is used if this axis plots date times.""" + self._timeZone = tz + self._dirtyTicks() + + @property + def isTimeSeries(self): + """Whether the axis is showing floats as datetime objects""" + return self._isDateTime + + @isTimeSeries.setter + def isTimeSeries(self, isTimeSeries): + isTimeSeries = bool(isTimeSeries) + if isTimeSeries != self._isDateTime: + self._isDateTime = isTimeSeries + self._dirtyTicks() + + @property + def displayCoords(self): + """The coordinates of the start and end points of the axis + in display space (i.e., in pixels) as a tuple of 2 tuples of + 2 floats: ((x0, y0), (x1, y1)). + """ + return self._displayCoords + + @displayCoords.setter + def displayCoords(self, displayCoords): + assert len(displayCoords) == 2 + assert len(displayCoords[0]) == 2 + assert len(displayCoords[1]) == 2 + displayCoords = tuple(displayCoords[0]), tuple(displayCoords[1]) + if displayCoords != self._displayCoords: + self._displayCoords = displayCoords + self._dirtyTicks() + + @property + def devicePixelRatio(self): + """Returns the ratio between qt pixels and device pixels.""" + plotFrame = self._plotFrameRef() + return plotFrame.devicePixelRatio if plotFrame is not None else 1.0 + + @property + def dotsPerInch(self): + """Returns the screen DPI""" + plotFrame = self._plotFrameRef() + return plotFrame.dotsPerInch if plotFrame is not None else 92 + + @property + def title(self): + """The text label associated with this axis as a str.""" + return self._title + + @title.setter + def title(self, title): + if title != self._title: + self._title = title + self._dirtyPlotFrame() + + @property + def orderOffsetAnchor(self) -> tuple[float, float]: + """Anchor position for the tick order&offset text""" + return self._orderOffsetAnchor + + @orderOffsetAnchor.setter + def orderOffsetAnchor(self, position: tuple[float, float]): + if position != self._orderOffsetAnchor: + self._orderOffsetAnchor = position + self._dirtyTicks() + + @property + def titleOffset(self): + """Title offset in pixels (x: int, y: int)""" + return self._titleOffset + + @titleOffset.setter + def titleOffset(self, offset): + if offset != self._titleOffset: + self._titleOffset = offset + self._dirtyTicks() + + @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert ( + len(color) == 4 + ), f"foregroundColor must have length 4, got {len(self._foregroundColor)}" + if self._foregroundColor != color: + self._foregroundColor = color + self._dirtyTicks() + + @property + def ticks(self): + """Ticks as tuples: ((x, y) in display, dataPos, textLabel).""" + if self._ticks is None: + self._ticks = tuple(self._ticksGenerator()) + return self._ticks + + def getVerticesAndLabels(self): + """Create the list of vertices and associated text label descriptors. + + Returns plain dicts for labels instead of GL Text2D objects. + + :returns: A tuple: (list of 2D line vertices, list of label dicts). + Each label dict has keys: text, font, color, x, y, align, valign, + rotate, devicePixelRatio. + """ + vertices = list(self.displayCoords) # Add start and end points + labels = [] + + xTickLength, yTickLength = self._tickLength + xTickLength *= self.devicePixelRatio + yTickLength *= self.devicePixelRatio + for (xPixel, yPixel), dataPos, text in self.ticks: + if text is None: + tickScale = 0.5 + else: + tickScale = 1.0 + + label = { + "text": text, + "font": self.font, + "color": self._foregroundColor, + "x": xPixel - xTickLength, + "y": yPixel - yTickLength, + "align": self._labelAlign, + "valign": self._labelVAlign, + "rotate": 0, + "devicePixelRatio": self.devicePixelRatio, + } + labels.append(label) + + vertices.append((xPixel, yPixel)) + vertices.append( + (xPixel + tickScale * xTickLength, yPixel + tickScale * yTickLength) + ) + + (x0, y0), (x1, y1) = self.displayCoords + xAxisCenter = 0.5 * (x0 + x1) + yAxisCenter = 0.5 * (y0 + y1) + + xOffset, yOffset = self.titleOffset + + axisTitle = { + "text": self.title, + "font": self.font, + "color": self._foregroundColor, + "x": xAxisCenter + xOffset, + "y": yAxisCenter + yOffset, + "align": self._titleAlign, + "valign": self._titleVAlign, + "rotate": self._titleRotate, + "devicePixelRatio": self.devicePixelRatio, + } + labels.append(axisTitle) + + if self._orderAndOffsetText: + orderAndOffsetFont = self._orderAndOffsetFont(self.font) + + xOrderOffset, yOrderOffset = self.orderOffsetAnchor + labels.append( + { + "text": self._orderAndOffsetText, + "font": orderAndOffsetFont, + "color": self._foregroundColor, + "x": xOrderOffset, + "y": yOrderOffset, + "align": self._orderOffsetAlign, + "valign": self._orderOffsetVAlign, + "rotate": 0, + "devicePixelRatio": self.devicePixelRatio, + } + ) + return vertices, labels + + @staticmethod + def _orderAndOffsetFont(font: qt.QFont) -> qt.QFont: + """Returns a larger bold font""" + boldBiggerFont = qt.QFont(font) + boldBiggerFont.setWeight(qt.QFont.ExtraBold) + pointSize = boldBiggerFont.pointSizeF() + if pointSize > 0: + boldBiggerFont.setPointSizeF(1.1 * pointSize) + pixelSize = boldBiggerFont.pixelSize() + if pixelSize > 0: + boldBiggerFont.setPixelSize(int(1.1 * pixelSize)) + return boldBiggerFont + + def _dirtyPlotFrame(self): + """Dirty parent PlotFrame""" + plotFrame = self._plotFrameRef() + if plotFrame is not None: + plotFrame._dirty() + + def _dirtyTicks(self): + """Mark ticks as dirty and notify listener.""" + self._ticks = None + self._dirtyPlotFrame() + + @staticmethod + def _frange(start, stop, step): + """range for float (including stop).""" + while start <= stop: + yield start + start += step + + def _ticksGenerator(self): + """Generator of ticks as tuples: + ((x, y) in display, dataPos, textLabel). + """ + self._orderAndOffsetText = "" + + dataMin, dataMax = self.dataRange + if self.isLog and dataMin <= 0.0: + _logger.warning("Getting ticks while isLog=True and dataRange[0]<=0.") + dataMin = 1.0 + if dataMax < dataMin: + dataMax = 1.0 + + if dataMin != dataMax: # data range is not null + (x0, y0), (x1, y1) = self.displayCoords + + if self.isLog: + if self.isTimeSeries: + _logger.warning("Time series not implemented for log-scale") + + logMin, logMax = math.log10(dataMin), math.log10(dataMax) + tickMin, tickMax, step, _ = niceNumbersForLog10(logMin, logMax) + + xScale = (x1 - x0) / (logMax - logMin) + yScale = (y1 - y0) / (logMax - logMin) + + for logPos in self._frange(tickMin, tickMax, step): + if logMin <= logPos <= logMax: + dataPos = 10**logPos + xPixel = x0 + (logPos - logMin) * xScale + yPixel = y0 + (logPos - logMin) * yScale + text = "1e%+03d" % logPos + yield ((xPixel, yPixel), dataPos, text) + + if step == 1: + ticks = list(self._frange(tickMin, tickMax, step))[:-1] + for logPos in ticks: + dataOrigPos = 10**logPos + for index in range(2, 10): + dataPos = dataOrigPos * index + if dataMin <= dataPos <= dataMax: + logSubPos = math.log10(dataPos) + xPixel = x0 + (logSubPos - logMin) * xScale + yPixel = y0 + (logSubPos - logMin) * yScale + yield ((xPixel, yPixel), dataPos, None) + + else: + xScale = (x1 - x0) / (dataMax - dataMin) + yScale = (y1 - y0) / (dataMax - dataMin) + + nbPixels = ( + math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) / self.devicePixelRatio + ) + + # Density of 1.3 label per 92 pixels + # i.e., 1.3 label per inch on a 92 dpi screen + tickDensity = 1.3 * self.devicePixelRatio / self.dotsPerInch + + if not self.isTimeSeries: + tickMin, tickMax, step, _ = niceNumbersAdaptative( + dataMin, dataMax, nbPixels, tickDensity + ) + + visibleTickPositions = [ + pos + for pos in self._frange(tickMin, tickMax, step) + if dataMin <= pos <= dataMax + ] + self._tickFormatter.axis.set_view_interval(dataMin, dataMax) + self._tickFormatter.axis.set_data_interval(dataMin, dataMax) + texts = self._tickFormatter.format_ticks(visibleTickPositions) + self._orderAndOffsetText = self._tickFormatter.get_offset() + + for dataPos, text in zip(visibleTickPositions, texts): + xPixel = x0 + (dataPos - dataMin) * xScale + yPixel = y0 + (dataPos - dataMin) * yScale + yield ((xPixel, yPixel), dataPos, text) + + else: + # Time series + try: + dtMin = dt.datetime.fromtimestamp(dataMin, tz=self.timeZone) + dtMax = dt.datetime.fromtimestamp(dataMax, tz=self.timeZone) + except ValueError: + _logger.warning("Data range cannot be displayed with time axis") + return # Range is out of bound of the datetime + + if bestUnit( + (dtMax - dtMin).total_seconds() == DtUnit.MICRO_SECONDS + ): + # Special case for micro seconds: Reduce tick density + tickDensity = 1.0 * self.devicePixelRatio / self.dotsPerInch + + tickDateTimes, spacing, unit = calcTicksAdaptive( + dtMin, dtMax, nbPixels, tickDensity + ) + visibleDatetimes = tuple( + dt for dt in tickDateTimes if dtMin <= dt <= dtMax + ) + ticks = formatDatetimes(visibleDatetimes, spacing, unit) + + for tickDateTime, text in ticks.items(): + dataPos = timestamp(tickDateTime) + xPixel = x0 + (dataPos - dataMin) * xScale + yPixel = y0 + (dataPos - dataMin) * yScale + yield ((xPixel, yPixel), dataPos, text) + + +# PlotFrameCore ############################################################### + + +class PlotFrameCore: + """Base class for rendering-independent 2D frame layout. + + Provides margin computation, axis management, grid vertex generation, + and label generation without any GL dependencies. + """ + + _TICK_LENGTH_IN_PIXELS = 5 + _LINE_WIDTH = 1 + + _Margins = namedtuple("Margins", ("left", "right", "top", "bottom")) + + # Margins used when plot frame is not displayed + _NoDisplayMargins = _Margins(0, 0, 0, 0) + + def __init__(self, marginRatios, foregroundColor, gridColor, font: qt.QFont): + """ + :param marginRatios: + The ratios of margins around plot area for axis and labels. + (left, top, right, bottom) as float in [0., 1.] + :param foregroundColor: color used for the frame and labels. + :param gridColor: color used for grid lines. + :param font: Font used by the axes label + """ + self.__dirty = True + + self.__marginRatios = marginRatios + self.__marginsCache = None + + self._foregroundColor = foregroundColor + self._gridColor = gridColor + + self.axes = [] # List of PlotAxisCore to be updated by subclasses + + self._grid = False + self._size = 0.0, 0.0 + self._title = "" + self._font: qt.QFont = font + + self._devicePixelRatio = 1.0 + self._dpi = 92 + + @property + def isDirty(self): + """True if it needs to refresh, False otherwise.""" + return self.__dirty + + GRID_NONE = 0 + GRID_MAIN_TICKS = 1 + GRID_SUB_TICKS = 2 + GRID_ALL_TICKS = GRID_MAIN_TICKS + GRID_SUB_TICKS + + @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert ( + len(color) == 4 + ), f"foregroundColor must have length 4, got {len(self._foregroundColor)}" + if self._foregroundColor != color: + self._foregroundColor = color + for axis in self.axes: + axis.foregroundColor = color + self._dirty() + + @property + def gridColor(self): + """Color used for grid""" + return self._gridColor + + @gridColor.setter + def gridColor(self, color): + """Color used for grid""" + assert ( + len(color) == 4 + ), f"gridColor must have length 4, got {len(self._gridColor)}" + if self._gridColor != color: + self._gridColor = color + self._dirty() + + @property + def marginRatios(self): + """Plot margin ratios: (left, top, right, bottom) as 4 float in [0, 1].""" + return self.__marginRatios + + @marginRatios.setter + def marginRatios(self, ratios): + ratios = tuple(float(v) for v in ratios) + assert len(ratios) == 4 + for value in ratios: + assert 0.0 <= value <= 1.0 + assert ratios[0] + ratios[2] < 1.0 + assert ratios[1] + ratios[3] < 1.0 + + if self.__marginRatios != ratios: + self.__marginRatios = ratios + self.__marginsCache = None + self._dirty() + + @property + def margins(self): + """Margins in pixels around the plot.""" + if self.__marginsCache is None: + width, height = self.size + left, top, right, bottom = self.marginRatios + self.__marginsCache = self._Margins( + left=int(left * width), + right=int(right * width), + top=int(top * height), + bottom=int(bottom * height), + ) + return self.__marginsCache + + @property + def devicePixelRatio(self): + return self._devicePixelRatio + + @devicePixelRatio.setter + def devicePixelRatio(self, ratio): + if ratio != self._devicePixelRatio: + self._devicePixelRatio = ratio + self._dirty() + + @property + def dotsPerInch(self): + return self._dpi + + @dotsPerInch.setter + def dotsPerInch(self, dpi): + if dpi != self._dpi: + self._dpi = dpi + self._dirty() + + @property + def grid(self): + """Grid display mode: + - 0: No grid. + - 1: Grid on main ticks. + - 2: Grid on sub-ticks for log scale axes. + - 3: Grid on main and sub ticks.""" + return self._grid + + @grid.setter + def grid(self, grid): + assert grid in ( + self.GRID_NONE, + self.GRID_MAIN_TICKS, + self.GRID_SUB_TICKS, + self.GRID_ALL_TICKS, + ) + if grid != self._grid: + self._grid = grid + self._dirty() + + @property + def size(self): + """Size in device pixels of the plot area including margins.""" + return self._size + + @size.setter + def size(self, size): + assert len(size) == 2 + size = tuple(size) + if size != self._size: + self._size = size + self.__marginsCache = None + self._dirty() + + @property + def plotOrigin(self): + """Plot area origin (left, top) in widget coordinates in pixels.""" + return self.margins.left, self.margins.top + + @property + def plotSize(self): + """Plot area size (width, height) in pixels.""" + w, h = self.size + w -= self.margins.left + self.margins.right + h -= self.margins.top + self.margins.bottom + return w, h + + @property + def title(self): + """Main title as a str.""" + return self._title + + @title.setter + def title(self, title): + if title != self._title: + self._title = title + self._dirty() + + def _dirty(self): + self.__dirty = True + + def _clearDirty(self): + self.__dirty = False + + def _buildGridVertices(self): + if self._grid == self.GRID_NONE: + return [] + + elif self._grid == self.GRID_MAIN_TICKS: + + def test(text): + return text is not None + + elif self._grid == self.GRID_SUB_TICKS: + + def test(text): + return text is None + + elif self._grid == self.GRID_ALL_TICKS: + + def test(_): + return True + + else: + logging.warning("Wrong grid mode: %d" % self._grid) + return [] + + return self._buildGridVerticesWithTest(test) + + def _buildGridVerticesWithTest(self, test): + """Override in subclass to generate grid vertices""" + return [] + + def _buildVerticesAndLabels(self): + """Build vertices and labels for the frame. + + Returns (vertices as numpy array, gridVertices as numpy array, + labels as list of dicts). + """ + vertices = [] + labels = [] + + for axis in self.axes: + axisVertices, axisLabels = axis.getVerticesAndLabels() + vertices += axisVertices + labels += axisLabels + + vertices = numpy.array(vertices, dtype=numpy.float32) + + # Add main title + xTitle = (self.size[0] + self.margins.left - self.margins.right) // 2 + yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS + labels.append( + { + "text": self.title, + "font": self._font, + "color": self._foregroundColor, + "x": xTitle, + "y": yTitle, + "align": "center", + "valign": "bottom", + "rotate": 0, + "devicePixelRatio": self.devicePixelRatio, + } + ) + + # grid + gridVertices = numpy.array(self._buildGridVertices(), dtype=numpy.float32) + + return vertices, gridVertices, labels + + +# PlotFrame2DCore ############################################################# + + +class PlotFrame2DCore(PlotFrameCore): + """Rendering-independent 2D plot frame with coordinate transforms. + + Provides data-to-pixel / pixel-to-data conversions, axis inversion, + log scale, base vectors, and grid vertex computation. + """ + + _DataRanges = namedtuple("dataRanges", ("x", "y", "y2")) + + # Align constants as strings (rendering layer maps these to its own constants) + _ALIGN_CENTER = "center" + _ALIGN_LEFT = "left" + _ALIGN_RIGHT = "right" + _VALIGN_TOP = "top" + _VALIGN_BOTTOM = "bottom" + _VALIGN_CENTER = "center" + _ROTATE_270 = 270 + + def __init__(self, marginRatios, foregroundColor, gridColor, font: qt.QFont): + super().__init__(marginRatios, foregroundColor, gridColor, font) + self._font = font + + self.axes.append( + PlotAxisCore( + self, + tickLength=(0.0, -5.0), + foregroundColor=self._foregroundColor, + labelAlign=self._ALIGN_CENTER, + labelVAlign=self._VALIGN_TOP, + orderOffsetAlign=self._ALIGN_RIGHT, + orderOffsetVAlign=self._VALIGN_TOP, + titleAlign=self._ALIGN_CENTER, + titleVAlign=self._VALIGN_TOP, + titleRotate=0, + font=self._font, + ) + ) + + self._x2AxisCoords = () + + self.axes.append( + PlotAxisCore( + self, + tickLength=(5.0, 0.0), + foregroundColor=self._foregroundColor, + labelAlign=self._ALIGN_RIGHT, + labelVAlign=self._VALIGN_CENTER, + orderOffsetAlign=self._ALIGN_RIGHT, + orderOffsetVAlign=self._VALIGN_BOTTOM, + titleAlign=self._ALIGN_CENTER, + titleVAlign=self._VALIGN_BOTTOM, + titleRotate=self._ROTATE_270, + font=self._font, + ) + ) + + self._y2Axis = PlotAxisCore( + self, + tickLength=(-5.0, 0.0), + foregroundColor=self._foregroundColor, + labelAlign=self._ALIGN_LEFT, + labelVAlign=self._VALIGN_CENTER, + orderOffsetAlign=self._ALIGN_LEFT, + orderOffsetVAlign=self._VALIGN_BOTTOM, + titleAlign=self._ALIGN_CENTER, + titleVAlign=self._VALIGN_TOP, + titleRotate=self._ROTATE_270, + font=self._font, + ) + + self._isXAxisInverted = False + self._isYAxisInverted = False + + self._dataRanges = {"x": (1.0, 100.0), "y": (1.0, 100.0), "y2": (1.0, 100.0)} + + self._baseVectors = (1.0, 0.0), (0.0, 1.0) + + self._transformedDataRanges = None + self._transformedDataProjMat = None + self._transformedDataY2ProjMat = None + + def _dirty(self): + super()._dirty() + self._transformedDataRanges = None + self._transformedDataProjMat = None + self._transformedDataY2ProjMat = None + + @property + def isDirty(self): + """True if it needs to refresh, False otherwise.""" + return ( + super().isDirty + or self._transformedDataRanges is None + or self._transformedDataProjMat is None + or self._transformedDataY2ProjMat is None + ) + + @property + def xAxis(self): + return self.axes[0] + + @property + def yAxis(self): + return self.axes[1] + + @property + def y2Axis(self): + return self._y2Axis + + @property + def isY2Axis(self): + """Whether to display the right Y axis or not.""" + return len(self.axes) == 3 + + @isY2Axis.setter + def isY2Axis(self, isY2Axis): + if isY2Axis != self.isY2Axis: + if isY2Axis: + self.axes.append(self._y2Axis) + else: + self.axes = self.axes[:2] + + self._dirty() + + @property + def isYAxisInverted(self) -> bool: + """Whether Y axes are inverted or not as a bool.""" + return self._isYAxisInverted + + @isYAxisInverted.setter + def isYAxisInverted(self, value: bool): + value = bool(value) + if value != self._isYAxisInverted: + self._isYAxisInverted = value + self._dirty() + + @property + def isXAxisInverted(self) -> bool: + return self._isXAxisInverted + + @isXAxisInverted.setter + def isXAxisInverted(self, value: bool): + value = bool(value) + if value != self._isXAxisInverted: + self._isXAxisInverted = value + self._dirty() + + DEFAULT_BASE_VECTORS = (1.0, 0.0), (0.0, 1.0) + """Values of baseVectors for orthogonal axes.""" + + @property + def baseVectors(self): + """Coordinates of the X and Y axes in the orthogonal plot coords. + + Raises ValueError if corresponding matrix is singular. + + 2 tuples of 2 floats: (xx, xy), (yx, yy) + """ + return self._baseVectors + + @baseVectors.setter + def baseVectors(self, baseVectors): + self._dirty() + + (xx, xy), (yx, yy) = baseVectors + vectors = (float(xx), float(xy)), (float(yx), float(yy)) + + det = vectors[0][0] * vectors[1][1] - vectors[1][0] * vectors[0][1] + if det == 0.0: + raise ValueError("Singular matrix for base vectors: " + str(vectors)) + + if vectors != self._baseVectors: + self._baseVectors = vectors + self._dirty() + + def _updateTitleOffset(self): + """Update axes title offset according to margins""" + margins = self.margins + self.xAxis.titleOffset = 0, margins.bottom // 2 + self.yAxis.titleOffset = -3 * margins.left // 4, 0 + self.y2Axis.titleOffset = 3 * margins.right // 4, 0 + + # Override size and marginRatios setters to update titleOffsets + @PlotFrameCore.size.setter + def size(self, size): + PlotFrameCore.size.fset(self, size) + self._updateTitleOffset() + + @PlotFrameCore.marginRatios.setter + def marginRatios(self, ratios): + PlotFrameCore.marginRatios.fset(self, ratios) + self._updateTitleOffset() + + @property + def dataRanges(self): + """Ranges of data visible in the plot on x, y and y2 axes. + + This is different to the axes range when axes are not orthogonal. + + Type: ((xMin, xMax), (yMin, yMax), (y2Min, y2Max)) + """ + return self._DataRanges( + self._dataRanges["x"], self._dataRanges["y"], self._dataRanges["y2"] + ) + + def setDataRanges(self, x=None, y=None, y2=None): + """Set data range over each axes. + + The provided ranges are clipped to possible values + (i.e., 32 float range + positive range for log scale). + + :param x: (min, max) data range over X axis + :param y: (min, max) data range over Y axis + :param y2: (min, max) data range over Y2 axis + """ + if x is not None: + self._dataRanges["x"] = checkAxisLimits( + x[0], x[1], self.xAxis.isLog, name="x" + ) + + if y is not None: + self._dataRanges["y"] = checkAxisLimits( + y[0], y[1], self.yAxis.isLog, name="y" + ) + + if y2 is not None: + self._dataRanges["y2"] = checkAxisLimits( + y2[0], y2[1], self.y2Axis.isLog, name="y2" + ) + + self.xAxis.dataRange = self._dataRanges["x"] + self.yAxis.dataRange = self._dataRanges["y"] + self.y2Axis.dataRange = self._dataRanges["y2"] + + @property + def transformedDataRanges(self): + """Bounds of the displayed area in transformed data coordinates + (i.e., log scale applied if any as well as skew) + + 3-tuple of 2-tuple (min, max) for each axis: x, y, y2. + """ + if self._transformedDataRanges is None: + (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self.dataRanges + + if self.xAxis.isLog: + try: + xMin = math.log10(xMin) + except ValueError: + _logger.info("xMin: warning log10(%f)", xMin) + xMin = 0.0 + try: + xMax = math.log10(xMax) + except ValueError: + _logger.info("xMax: warning log10(%f)", xMax) + xMax = 0.0 + + if self.yAxis.isLog: + try: + yMin = math.log10(yMin) + except ValueError: + _logger.info("yMin: warning log10(%f)", yMin) + yMin = 0.0 + try: + yMax = math.log10(yMax) + except ValueError: + _logger.info("yMax: warning log10(%f)", yMax) + yMax = 0.0 + + try: + y2Min = math.log10(y2Min) + except ValueError: + _logger.info("yMin: warning log10(%f)", y2Min) + y2Min = 0.0 + try: + y2Max = math.log10(y2Max) + except ValueError: + _logger.info("yMax: warning log10(%f)", y2Max) + y2Max = 0.0 + + self._transformedDataRanges = self._DataRanges( + (xMin, xMax), (yMin, yMax), (y2Min, y2Max) + ) + + return self._transformedDataRanges + + @property + def transformedDataProjMat(self): + """Orthographic projection matrix for rendering transformed data + + :type: numpy.ndarray (4x4) + """ + if self._transformedDataProjMat is None: + xMin, xMax = self.transformedDataRanges.x + yMin, yMax = self.transformedDataRanges.y + + if self.isYAxisInverted: + yMax, yMin = yMin, yMax + + if self.isXAxisInverted: + xMax, xMin = xMin, xMax + + self._transformedDataProjMat = self._mat4Ortho( + xMin, xMax, yMin, yMax, 1, -1 + ) + + return self._transformedDataProjMat + + @property + def transformedDataY2ProjMat(self): + """Orthographic projection matrix for rendering transformed data + for the 2nd Y axis + + :type: numpy.ndarray (4x4) + """ + if self._transformedDataY2ProjMat is None: + xMin, xMax = self.transformedDataRanges.x + y2Min, y2Max = self.transformedDataRanges.y2 + + if self.isYAxisInverted: + y2Max, y2Min = y2Min, y2Max + + if self.isXAxisInverted: + xMax, xMin = xMin, xMax + + self._transformedDataY2ProjMat = self._mat4Ortho( + xMin, xMax, y2Min, y2Max, 1, -1 + ) + + return self._transformedDataY2ProjMat + + @staticmethod + def _mat4Ortho(left, right, bottom, top, near, far): + """Orthographic projection matrix (row-major). + + Equivalent to glutils.GLSupport.mat4Ortho but without GL dependency. + """ + if left == right or bottom == top or near == far: + return numpy.identity(4, dtype=numpy.float64) + + sx = 2.0 / (right - left) + sy = 2.0 / (top - bottom) + sz = -2.0 / (far - near) + tx = -(right + left) / (right - left) + ty = -(top + bottom) / (top - bottom) + tz = -(far + near) / (far - near) + + return numpy.array( + ( + (sx, 0.0, 0.0, tx), + (0.0, sy, 0.0, ty), + (0.0, 0.0, sz, tz), + (0.0, 0.0, 0.0, 1.0), + ), + dtype=numpy.float64, + ) + + @staticmethod + def _applyLog( + data: float | numpy.ndarray, isLog: bool + ) -> float | numpy.ndarray | None: + """Apply log to data filtering out""" + if not isLog: + return data + + if isinstance(data, numbers.Real): + return None if data < FLOAT32_MINPOS else math.log10(data) + + isBelowMin = data < FLOAT32_MINPOS + if numpy.any(isBelowMin): + data = numpy.array(data, copy=True, dtype=numpy.float64) + data[isBelowMin] = numpy.nan + + with numpy.errstate(divide="ignore"): + return numpy.log10(data) + + def dataToPixel(self, x, y, axis="left"): + """Convert data coordinate to widget pixel coordinate.""" + assert axis in ("left", "right") + + trBounds = self.transformedDataRanges + + xDataTr = self._applyLog(x, self.xAxis.isLog) + if xDataTr is None: + return None + + yDataTr = self._applyLog(y, self.yAxis.isLog) + if yDataTr is None: + return None + + # Non-orthogonal axes + if self.baseVectors != self.DEFAULT_BASE_VECTORS: + (xx, xy), (yx, yy) = self.baseVectors + skew_mat = numpy.array(((xx, yx), (xy, yy))) + + coords = numpy.dot(skew_mat, numpy.array((xDataTr, yDataTr))) + xDataTr, yDataTr = coords + + plotWidth, plotHeight = self.plotSize + + xOffset = ( + plotWidth * (xDataTr - trBounds.x[0]) / (trBounds.x[1] - trBounds.x[0]) + ) + if self.isXAxisInverted: + xPixel = self.size[0] - self.margins.right - xOffset + else: + xPixel = self.margins.left + xOffset + + usedAxis = trBounds.y if axis == "left" else trBounds.y2 + yOffset = plotHeight * (yDataTr - usedAxis[0]) / (usedAxis[1] - usedAxis[0]) + + if self.isYAxisInverted: + yPixel = self.margins.top + yOffset + else: + yPixel = self.size[1] - self.margins.bottom - yOffset + + return ( + ( + int(xPixel) + if isinstance(xPixel, numbers.Real) + else xPixel.astype(numpy.int64) + ), + ( + int(yPixel) + if isinstance(yPixel, numbers.Real) + else yPixel.astype(numpy.int64) + ), + ) + + def pixelToData(self, x, y, axis="left"): + """Convert pixel position to data coordinates. + + :param float x: X coord + :param float y: Y coord + :param str axis: Y axis to use in ('left', 'right') + :return: (x, y) position in data coords + """ + assert axis in ("left", "right") + + plotWidth, plotHeight = self.plotSize + + trBounds = self.transformedDataRanges + + if self.isXAxisInverted: + unscaledXData = self.size[0] - self.margins.right - x - 0.5 + else: + unscaledXData = x - self.margins.left + 0.5 + xData = trBounds.x[0] + unscaledXData / float(plotWidth) * ( + trBounds.x[1] - trBounds.x[0] + ) + + if self.isYAxisInverted: + unscaledYData = y - self.margins.top + 0.5 + else: + unscaledYData = self.size[1] - self.margins.bottom - y - 0.5 + usedAxis = trBounds.y if axis == "left" else trBounds.y2 + yData = usedAxis[0] + unscaledYData / float(plotHeight) * ( + usedAxis[1] - usedAxis[0] + ) + + # non-orthogonal axis + if self.baseVectors != self.DEFAULT_BASE_VECTORS: + (xx, xy), (yx, yy) = self.baseVectors + skew_mat = numpy.array(((xx, yx), (xy, yy))) + skew_mat = numpy.linalg.inv(skew_mat) + + coords = numpy.dot(skew_mat, numpy.array((xData, yData))) + xData, yData = coords + + if self.xAxis.isLog: + xData = pow(10, xData) + if self.yAxis.isLog: + yData = pow(10, yData) + + return xData, yData + + def _buildGridVerticesWithTest(self, test): + vertices = [] + + if self.baseVectors == self.DEFAULT_BASE_VECTORS: + for axis in self.axes: + for (xPixel, yPixel), data, text in axis.ticks: + if test(text): + vertices.append((xPixel, yPixel)) + if axis == self.xAxis: + vertices.append((xPixel, self.margins.top)) + elif axis == self.yAxis: + vertices.append((self.size[0] - self.margins.right, yPixel)) + else: # axis == self.y2Axis + vertices.append((self.margins.left, yPixel)) + + else: + # Get plot corners in data coords + plotLeft, plotTop = self.plotOrigin + plotWidth, plotHeight = self.plotSize + + corners = [ + (plotLeft, plotTop), + (plotLeft, plotTop + plotHeight), + (plotLeft + plotWidth, plotTop + plotHeight), + (plotLeft + plotWidth, plotTop), + ] + + for axis in self.axes: + if axis == self.xAxis: + cornersInData = numpy.array( + [self.pixelToData(x, y) for (x, y) in corners] + ) + borders = ( + (cornersInData[0], cornersInData[3]), # top + (cornersInData[1], cornersInData[0]), # left + (cornersInData[3], cornersInData[2]), + ) # right + + for (xPixel, yPixel), data, text in axis.ticks: + if test(text): + for (x0, y0), (x1, y1) in borders: + if min(x0, x1) <= data < max(x0, x1): + yIntersect = (data - x0) * (y1 - y0) / ( + x1 - x0 + ) + y0 + + pixelPos = self.dataToPixel(data, yIntersect) + if pixelPos is not None: + vertices.append((xPixel, yPixel)) + vertices.append(pixelPos) + break # Stop at first intersection + + else: # y or y2 axes + if axis == self.yAxis: + axis_name = "left" + cornersInData = numpy.array( + [self.pixelToData(x, y) for (x, y) in corners] + ) + borders = ( + (cornersInData[3], cornersInData[2]), # right + (cornersInData[0], cornersInData[3]), # top + (cornersInData[2], cornersInData[1]), + ) # bottom + + else: # axis == self.y2Axis + axis_name = "right" + corners = numpy.array( + [self.pixelToData(x, y, axis="right") for (x, y) in corners] + ) + borders = ( + (cornersInData[1], cornersInData[0]), # left + (cornersInData[0], cornersInData[3]), # top + (cornersInData[2], cornersInData[1]), + ) # bottom + + for (xPixel, yPixel), data, text in axis.ticks: + if test(text): + for (x0, y0), (x1, y1) in borders: + if min(y0, y1) <= data < max(y0, y1): + xIntersect = (data - y0) * (x1 - x0) / ( + y1 - y0 + ) + x0 + + pixelPos = self.dataToPixel( + xIntersect, data, axis=axis_name + ) + if pixelPos is not None: + vertices.append((xPixel, yPixel)) + vertices.append(pixelPos) + break # Stop at first intersection + + return vertices + + def _buildVerticesAndLabels(self): + width, height = self.size + + xLeft = self.margins.left - 0.5 + xRight = width - self.margins.right + 0.5 + yBottom = height - self.margins.bottom + 0.5 + yTop = self.margins.top - 0.5 + + self._x2AxisCoords = ((xLeft, yTop), (xRight, yTop)) + + # Set order&offset anchor **before** handling axis inversion + fontPixelSize = self._font.pixelSize() + if fontPixelSize == -1: + fontPixelSize = self._font.pointSizeF() / 72.0 * self.dotsPerInch + + self.axes[0].orderOffsetAnchor = ( + xRight, + yBottom + fontPixelSize * 1.2, + ) + self.axes[1].orderOffsetAnchor = ( + xLeft, + yTop - 4 * self.devicePixelRatio - fontPixelSize / 2.0, + ) + self._y2Axis.orderOffsetAnchor = ( + xRight, + yTop - 4 * self.devicePixelRatio - fontPixelSize / 2.0, + ) + + if self.isYAxisInverted: + yCoords = yTop, yBottom + else: + yCoords = yBottom, yTop + + if self.isXAxisInverted: + xCoords = xRight, xLeft + else: + xCoords = xLeft, xRight + + self.axes[0].displayCoords = ( + (xCoords[0], yBottom), + (xCoords[1], yBottom), + ) + + self.axes[1].displayCoords = ( + (xLeft, yCoords[0]), + (xLeft, yCoords[1]), + ) + + self._y2Axis.displayCoords = ( + (xRight, yCoords[0]), + (xRight, yCoords[1]), + ) + + vertices, gridVertices, labels = super()._buildVerticesAndLabels() + + # Adds vertices for borders without axis + extraVertices = [] + extraVertices += list(self._x2AxisCoords) + if not self.isY2Axis: + extraVertices += list(self._y2Axis.displayCoords) + + extraVertices = numpy.asarray(extraVertices, dtype=numpy.float32) + vertices = numpy.append(vertices, extraVertices, axis=0) + + return vertices, gridVertices, labels + + @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert ( + len(color) == 4 + ), f"foregroundColor must have length 4, got {len(self._foregroundColor)}" + if self._foregroundColor != color: + self._y2Axis.foregroundColor = color + PlotFrameCore.foregroundColor.fset(self, color) diff --git a/src/silx/gui/plot/test/conftest.py b/src/silx/gui/plot/test/conftest.py index 78475fb4c4..4995b5e069 100644 --- a/src/silx/gui/plot/test/conftest.py +++ b/src/silx/gui/plot/test/conftest.py @@ -40,4 +40,6 @@ def plotWidget(qWidgetFactory, request): backend = "mpl" # Backend was not defined if backend == "gl": request.getfixturevalue("use_opengl") # Skip test if OpenGL test disabled + elif backend == "pygfx": + request.getfixturevalue("use_pygfx") # Skip test if pygfx test disabled yield qWidgetFactory(PlotWidget, backend=backend) diff --git a/src/silx/gui/plot/test/test_item.py b/src/silx/gui/plot/test/test_item.py index 8a6db40289..a97c9360de 100644 --- a/src/silx/gui/plot/test/test_item.py +++ b/src/silx/gui/plot/test/test_item.py @@ -531,7 +531,7 @@ def testPlotWidgetAddShape(plotWidget): (4.0, (8.0, 4.0, 4.0, 4.0)), ), ) -@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True) +@pytest.mark.parametrize("plotWidget", ("mpl", "gl", "pygfx"), indirect=True) def testLineStyle(qapp_utils, plotWidget, linestyle): """Test different line styles for LineMixIn items""" plotWidget.setGraphTitle(f"Line style: {linestyle}") diff --git a/src/silx/gui/plot/test/test_plotwidget.py b/src/silx/gui/plot/test/test_plotwidget.py index 2db603a918..dfb806b16f 100755 --- a/src/silx/gui/plot/test/test_plotwidget.py +++ b/src/silx/gui/plot/test/test_plotwidget.py @@ -2054,6 +2054,66 @@ class TestPlotMarkerLog_Gl(TestPlotMarkerLog): backend = "gl" +@pytest.mark.usefixtures("use_pygfx") +class TestPlotWidget_Pygfx(TestPlotWidget): + backend = "pygfx" + + +@pytest.mark.usefixtures("use_pygfx") +class TestPlotImage_Pygfx(TestPlotImage): + backend = "pygfx" + + +@pytest.mark.usefixtures("use_pygfx") +class TestPlotCurve_Pygfx(TestPlotCurve): + backend = "pygfx" + + +@pytest.mark.usefixtures("use_pygfx") +class TestPlotHistogram_Pygfx(TestPlotHistogram): + backend = "pygfx" + + +@pytest.mark.usefixtures("use_pygfx") +class TestPlotScatter_Pygfx(TestPlotScatter): + backend = "pygfx" + + +@pytest.mark.usefixtures("use_pygfx") +class TestPlotMarker_Pygfx(TestPlotMarker): + backend = "pygfx" + + +@pytest.mark.usefixtures("use_pygfx") +class TestPlotItem_Pygfx(TestPlotItem): + backend = "pygfx" + + +@pytest.mark.usefixtures("use_pygfx") +class TestPlotAxes_Pygfx(TestPlotAxes): + backend = "pygfx" + + +@pytest.mark.usefixtures("use_pygfx") +class TestPlotEmptyLog_Pygfx(TestPlotEmptyLog): + backend = "pygfx" + + +@pytest.mark.usefixtures("use_pygfx") +class TestPlotCurveLog_Pygfx(TestPlotCurveLog): + backend = "pygfx" + + +@pytest.mark.usefixtures("use_pygfx") +class TestPlotImageLog_Pygfx(TestPlotImageLog): + backend = "pygfx" + + +@pytest.mark.usefixtures("use_pygfx") +class TestPlotMarkerLog_Pygfx(TestPlotMarkerLog): + backend = "pygfx" + + class TestSpecial_ExplicitMplBackend(TestSpecialBackend): backend = "mpl" @@ -2061,7 +2121,7 @@ class TestSpecial_ExplicitMplBackend(TestSpecialBackend): @pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning") @pytest.mark.filterwarnings("ignore:.* converting a masked element to nan.:UserWarning") @pytest.mark.filterwarnings("ignore:All-NaN axis encountered:RuntimeWarning") -@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True) +@pytest.mark.parametrize("plotWidget", ("mpl", "gl", "pygfx"), indirect=True) @pytest.mark.parametrize( "xerror,yerror", [ diff --git a/src/silx/gui/plot/test/test_plotwidgetactiveitem.py b/src/silx/gui/plot/test/test_plotwidgetactiveitem.py index 99285a8035..5af3f38983 100755 --- a/src/silx/gui/plot/test/test_plotwidgetactiveitem.py +++ b/src/silx/gui/plot/test/test_plotwidgetactiveitem.py @@ -35,7 +35,7 @@ from silx.gui.plot.items.curve import CurveStyle -@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True) +@pytest.mark.parametrize("plotWidget", ("mpl", "gl", "pygfx"), indirect=True) def testActiveCurveAndLabels(plotWidget): # Active curve handling off, no label change plotWidget.setActiveCurveHandling(False) @@ -85,7 +85,7 @@ def testActiveCurveAndLabels(plotWidget): plotWidget.setActiveCurveHandling(False) -@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True) +@pytest.mark.parametrize("plotWidget", ("mpl", "gl", "pygfx"), indirect=True) def testPlotActiveCurveSelectionMode(plotWidget): xData = numpy.arange(1000) yData = -500 + 100 * numpy.sin(xData) @@ -140,7 +140,7 @@ def testPlotActiveCurveSelectionMode(plotWidget): plotWidget.setActiveCurveHandling(False) -@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True) +@pytest.mark.parametrize("plotWidget", ("mpl", "gl", "pygfx"), indirect=True) def testActiveCurveStyle(plotWidget): """Test change of active curve style""" plotWidget.setActiveCurveHandling(True) @@ -192,7 +192,7 @@ def testActiveCurveStyle(plotWidget): plotWidget.setActiveCurveHandling(False) -@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True) +@pytest.mark.parametrize("plotWidget", ("mpl", "gl", "pygfx"), indirect=True) def testActiveImageAndLabels(plotWidget): # Active image handling always on, no API for toggling it plotWidget.getXAxis().setLabel("XLabel") @@ -232,7 +232,7 @@ def _checkSelection(selection, current=None, selected=()): assert selection.getSelectedItems() == selected -@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True) +@pytest.mark.parametrize("plotWidget", ("mpl", "gl", "pygfx"), indirect=True) def testSelectionSyncWithActiveItems(plotWidget): """Test update of PlotWidgetSelection according to active items""" listener = SignalListener() @@ -314,7 +314,7 @@ def testSelectionSyncWithActiveItems(plotWidget): _checkSelection(selection) -@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True) +@pytest.mark.parametrize("plotWidget", ("mpl", "gl", "pygfx"), indirect=True) def testSelectionWithItems(plotWidget): """Test init of selection on a plot with items""" plotWidget.addImage(((0, 1), (2, 3)), legend="image") @@ -331,7 +331,7 @@ def testSelectionWithItems(plotWidget): assert plotWidget.getActiveScatter() in selected -@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True) +@pytest.mark.parametrize("plotWidget", ("mpl", "gl", "pygfx"), indirect=True) def testSelectionSetCurrentItem(plotWidget): """Test setCurrentItem""" # Add items to the plot diff --git a/src/silx/gui/plot/test/test_plotwindow.py b/src/silx/gui/plot/test/test_plotwindow.py index 43d465192b..398e87d1da 100644 --- a/src/silx/gui/plot/test/test_plotwindow.py +++ b/src/silx/gui/plot/test/test_plotwindow.py @@ -166,9 +166,12 @@ def testSwitchBackend(self): ylimits = self.plot.getYAxis().getLimits() isKeepAspectRatio = self.plot.isKeepDataAspectRatio() - for backend in ("gl", "mpl"): - with self.subTest(): - self.plot.setBackend(backend) + for backend in ("gl", "mpl", "pygfx"): + with self.subTest(backend=backend): + try: + self.plot.setBackend(backend) + except Exception: + continue self.plot.replot() self.assertEqual(self.plot.getXAxis().getLimits(), xlimits) self.assertEqual(self.plot.getYAxis().getLimits(), ylimits) diff --git a/src/silx/gui/plot3d/Plot3DWidgetPygfx.py b/src/silx/gui/plot3d/Plot3DWidgetPygfx.py new file mode 100644 index 0000000000..a581d1dbe6 --- /dev/null +++ b/src/silx/gui/plot3d/Plot3DWidgetPygfx.py @@ -0,0 +1,427 @@ +"""pygfx-based 3D rendering widget, replacement for Plot3DWidget.""" + +import logging +import math + +import numpy + +from .. import qt +from ..colors import rgba + +_logger = logging.getLogger(__name__) + + +def _look_at_quaternion(eye, target, up=(0, 1, 0)): + """Compute quaternion (x, y, z, w) for camera at eye looking at target.""" + eye = numpy.asarray(eye, dtype=numpy.float64) + target = numpy.asarray(target, dtype=numpy.float64) + up = numpy.asarray(up, dtype=numpy.float64) + + forward = target - eye + fwd_len = numpy.linalg.norm(forward) + if fwd_len < 1e-10: + return (0.0, 0.0, 0.0, 1.0) + forward = forward / fwd_len + + right = numpy.cross(forward, up) + right_len = numpy.linalg.norm(right) + if right_len < 1e-6: + alt_up = ( + numpy.array([1.0, 0, 0]) + if abs(forward[1]) > 0.9 + else numpy.array([0, 1.0, 0]) + ) + right = numpy.cross(forward, alt_up) + right = right / numpy.linalg.norm(right) + else: + right = right / right_len + + up_actual = numpy.cross(right, forward) + + # Rotation matrix: camera local X=right, Y=up, -Z=forward + R = numpy.zeros((3, 3)) + R[:, 0] = right + R[:, 1] = up_actual + R[:, 2] = -forward + + return _mat3_to_quat(R) + + +def _mat3_to_quat(m): + """Convert 3x3 rotation matrix to quaternion (x, y, z, w).""" + tr = m[0, 0] + m[1, 1] + m[2, 2] + + if tr > 0: + s = math.sqrt(tr + 1.0) * 2 + w = 0.25 * s + x = (m[2, 1] - m[1, 2]) / s + y = (m[0, 2] - m[2, 0]) / s + z = (m[1, 0] - m[0, 1]) / s + elif m[0, 0] > m[1, 1] and m[0, 0] > m[2, 2]: + s = math.sqrt(1.0 + m[0, 0] - m[1, 1] - m[2, 2]) * 2 + w = (m[2, 1] - m[1, 2]) / s + x = 0.25 * s + y = (m[0, 1] + m[1, 0]) / s + z = (m[0, 2] + m[2, 0]) / s + elif m[1, 1] > m[2, 2]: + s = math.sqrt(1.0 + m[1, 1] - m[0, 0] - m[2, 2]) * 2 + w = (m[0, 2] - m[2, 0]) / s + x = (m[0, 1] + m[1, 0]) / s + y = 0.25 * s + z = (m[1, 2] + m[2, 1]) / s + else: + s = math.sqrt(1.0 + m[2, 2] - m[0, 0] - m[1, 1]) * 2 + w = (m[1, 0] - m[0, 1]) / s + x = (m[0, 2] + m[2, 0]) / s + y = (m[1, 2] + m[2, 1]) / s + z = 0.25 * s + + return (float(x), float(y), float(z), float(w)) + + +class _StubLight: + """Stub light for pygfx backend, compatible with _DirectionalLightProxy.""" + + direction = (0, -1, -1) + + def addListener(self, callback): + pass + + +class _ExtrinsicProxy: + """Proxy for camera extrinsic with reset(face=) API.""" + + _FACE_DIRECTIONS = { + "front": numpy.array([0.0, 0.0, 1.0]), + "back": numpy.array([0.0, 0.0, -1.0]), + "right": numpy.array([1.0, 0.0, 0.0]), + "left": numpy.array([-1.0, 0.0, 0.0]), + "top": numpy.array([0.0, 1.0, 0.001]), + "bottom": numpy.array([0.0, -1.0, 0.001]), + "side": numpy.array([1.0, 1.0, 1.0]), + } + + def __init__(self, widget): + self._widget = widget + + def reset(self, face="front"): + """Reset camera to a predefined viewpoint.""" + camera = self._widget._camera + scene = self._widget._scene + + direction = self._FACE_DIRECTIONS.get( + face, self._FACE_DIRECTIONS["front"] + ).copy() + direction = direction / numpy.linalg.norm(direction) + + center = self._widget._getSceneCenter() + + # First show_object to get proper distance + camera.show_object(scene) + pos = numpy.array(camera.local.position, dtype=numpy.float64) + distance = max(numpy.linalg.norm(pos - center), 1.0) + + # Reposition camera + new_pos = center + direction * distance + camera.local.position = tuple(new_pos) + + # Set rotation to look at center + up = ( + (0, 0, -1) + if face == "top" + else (0, 0, 1) if face == "bottom" else (0, 1, 0) + ) + quat = _look_at_quaternion(new_pos, center, up) + camera.local.rotation = quat + + # Recreate controller to pick up new camera state + gfx = self._widget._gfx + self._widget._controller = gfx.OrbitController( + camera, register_events=self._widget._renderer + ) + + +class _CameraProxy: + """Proxy for camera with extrinsic.reset(face=) API.""" + + def __init__(self, widget): + self.extrinsic = _ExtrinsicProxy(widget) + + +class _ViewportProxy: + """Proxy providing viewport-like API for pygfx widget.""" + + def __init__(self, widget): + self._widget = widget + self.camera = _CameraProxy(widget) + self.light = _StubLight() + + def orbitCamera(self, direction, angle=1.0): + """Rotate camera around scene center. + + :param str direction: 'up', 'down', 'left', 'right' + :param float angle: Rotation angle in degrees + """ + camera = self._widget._camera + pos = numpy.array(camera.local.position, dtype=numpy.float64) + center = self._widget._getSceneCenter() + rel = pos - center + distance = numpy.linalg.norm(rel) + if distance < 1e-6: + return + + rad = math.radians(angle) + + if direction in ("left", "right"): + sign = 1.0 if direction == "left" else -1.0 + c, s = math.cos(sign * rad), math.sin(sign * rad) + new_rel = numpy.array( + [ + rel[0] * c + rel[2] * s, + rel[1], + -rel[0] * s + rel[2] * c, + ] + ) + elif direction in ("up", "down"): + sign = 1.0 if direction == "up" else -1.0 + c, s = math.cos(sign * rad), math.sin(sign * rad) + new_rel = numpy.array( + [ + rel[0], + rel[1] * c - rel[2] * s, + rel[1] * s + rel[2] * c, + ] + ) + else: + return + + new_pos = center + new_rel + camera.local.position = tuple(new_pos) + + # Update rotation to look at center + quat = _look_at_quaternion(new_pos, center) + camera.local.rotation = quat + + +class Plot3DWidgetPygfx(qt.QWidget): + """3D scene widget using pygfx/wgpu for rendering. + + Drop-in replacement for Plot3DWidget with the same public API. + """ + + sigStyleChanged = qt.Signal(str) + sigInteractiveModeChanged = qt.Signal() + sigSceneClicked = qt.Signal(float, float) + + def __init__(self, parent=None): + super().__init__(parent) + import pygfx as gfx + from rendercanvas.qt import QRenderWidget + + self._gfx = gfx + + # Layout + layout = qt.QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + + # Render widget + self._renderWidget = QRenderWidget(self) + self._renderWidget.set_update_mode("continuous", max_fps=60) + layout.addWidget(self._renderWidget) + + # Renderer + self._renderer = gfx.WgpuRenderer(self._renderWidget) + + # Scene + self._scene = gfx.Scene() + + # Camera + self._camera = gfx.PerspectiveCamera(fov=50) + self._projection = "perspective" + + # Lights: ambient + directional (attached to camera) + ambient = gfx.AmbientLight(intensity=0.4) + self._scene.add(ambient) + + directional = gfx.DirectionalLight(intensity=0.8) + self._camera.add(directional) + self._scene.add(self._camera) + + # Controller + self._controller = gfx.OrbitController( + self._camera, register_events=self._renderer + ) + self._interactiveMode = "rotate" + + # Background + self._backgroundColor = (0.2, 0.2, 0.25, 1.0) + bg = gfx.BackgroundMaterial(gfx.Color(*self._backgroundColor)) + self._background = gfx.Background(None, bg) + self._scene.add(self._background) + + # Data group (items added here) + self._dataGroup = gfx.Group() + self._scene.add(self._dataGroup) + + # Viewport proxy for toolbar/action compatibility + self.viewport = _ViewportProxy(self) + + # Connect render loop + self._renderWidget.request_draw(self._animate) + + def _animate(self): + """Render callback.""" + self._renderer.render(self._scene, self._camera) + self._renderWidget.request_draw(self._animate) + + def _getSceneCenter(self): + """Estimate the center of the scene data.""" + try: + bbox = self._dataGroup.get_world_bounding_box() + if bbox is not None: + mn = numpy.array(bbox[0]) + mx = numpy.array(bbox[1]) + if numpy.all(numpy.isfinite(mn)) and numpy.all(numpy.isfinite(mx)): + return (mn + mx) / 2 + except (AttributeError, Exception): + pass + return numpy.array([0.0, 0.0, 0.0]) + + # --- Background color --- + + def setBackgroundColor(self, color): + """Set the background color. + + :param color: RGBA color + """ + color = rgba(color) + self._backgroundColor = color + + gfx = self._gfx + self._background.material = gfx.BackgroundMaterial(gfx.Color(*color)) + + def getBackgroundColor(self): + """Return the background color. + + :rtype: QColor + """ + return qt.QColor.fromRgbF(*self._backgroundColor) + + # --- Projection --- + + def setProjection(self, projection): + """Set the projection mode. + + :param str projection: 'perspective' or 'orthographic' + """ + gfx = self._gfx + if projection == "orthographic" and self._projection != "orthographic": + self._projection = "orthographic" + self._camera = gfx.OrthographicCamera() + self._controller = gfx.OrbitController( + self._camera, register_events=self._renderer + ) + self._scene.add(self._camera) + elif projection == "perspective" and self._projection != "perspective": + self._projection = "perspective" + self._camera = gfx.PerspectiveCamera(fov=50) + self._controller = gfx.OrbitController( + self._camera, register_events=self._renderer + ) + self._scene.add(self._camera) + + def getProjection(self): + """Return the current projection mode. + + :rtype: str + """ + return self._projection + + # --- Interactive mode --- + + def setInteractiveMode(self, mode): + """Set the interactive mode. + + :param str mode: 'rotate', 'pan', or None + """ + mode = mode if mode else "rotate" + if mode != self._interactiveMode: + self._interactiveMode = mode + self.sigInteractiveModeChanged.emit() + + def getInteractiveMode(self): + """Return the current interactive mode. + + :rtype: str + """ + return self._interactiveMode + + # --- View control --- + + def centerScene(self): + """Center the camera on the scene.""" + self._camera.show_object(self._scene) + + def resetZoom(self, face="front"): + """Reset camera to a preset view. + + :param str face: The face to show ('front', 'back', etc.) + """ + self._camera.show_object(self._scene) + + # --- Screenshot --- + + def grabGL(self): + """Render the scene and return a QImage. + + :returns: RGBA image as QImage + :rtype: QImage + """ + try: + snapshot = self._renderer.snapshot() + arr = numpy.ascontiguousarray(numpy.asarray(snapshot)) + h, w = arr.shape[:2] + image = qt.QImage(arr.data, w, h, w * 4, qt.QImage.Format_RGBA8888) + # copy() to own the data (detach from numpy buffer) + return image.copy() + except (AttributeError, Exception): + # No render has occurred yet + return qt.QImage() + + # --- Device pixel ratio --- + + def getDevicePixelRatio(self): + """Return the device pixel ratio. + + :rtype: float + """ + return self.devicePixelRatioF() + + # --- Fog (no-op for pygfx) --- + + def setFogMode(self, mode): + pass + + def getFogMode(self): + return None + + # --- Light mode (no-op, always has lights) --- + + def setLightMode(self, mode): + pass + + def getLightMode(self): + return "directional" + + # --- Orientation indicator (no-op for now) --- + + def setOrientationIndicatorVisible(self, visible): + pass + + def isOrientationIndicatorVisible(self): + return False + + # --- Valid check --- + + def isValid(self): + return True diff --git a/src/silx/gui/plot3d/SceneWidgetPygfx.py b/src/silx/gui/plot3d/SceneWidgetPygfx.py new file mode 100644 index 0000000000..3c2868033b --- /dev/null +++ b/src/silx/gui/plot3d/SceneWidgetPygfx.py @@ -0,0 +1,405 @@ +"""pygfx-based SceneWidget, replacement for SceneWidget.""" + +import logging + +import numpy + +from .. import qt +from ..colors import rgba + +from .Plot3DWidgetPygfx import Plot3DWidgetPygfx +from . import items +from .items._pygfx_sync import sync_item + +_logger = logging.getLogger(__name__) + + +class _StubSelection(qt.QObject): + """Stub selection for pygfx backend.""" + + sigCurrentChanged = qt.Signal(object, object) + + def getCurrentItem(self): + return None + + def setCurrentItem(self, item): + pass + + def _setSyncSelectionModel(self, model): + pass + + +class SceneWidgetPygfx(Plot3DWidgetPygfx): + """Widget displaying data sets in 3D using pygfx backend. + + Provides the same public API as SceneWidget for item management. + """ + + def __init__(self, parent=None): + super().__init__(parent) + self._pygfxObjects = {} # item -> pygfx WorldObject mapping + + self._textColor = (1.0, 1.0, 1.0, 1.0) + self._foregroundColor = (1.0, 1.0, 1.0, 1.0) + self._highlightColor = (0.7, 0.7, 0.0, 1.0) + + self._selection = _StubSelection(self) + + # Item management via GroupItem (reuses SceneModel infrastructure) + self._sceneGroup = items.GroupItem() + self._sceneGroup.setLabel("Data") + self._sceneGroup.sigItemAdded.connect(self._onGroupItemAdded) + self._sceneGroup.sigItemRemoved.connect(self._onGroupItemRemoved) + self._model = None + + # Axes and bounding box + gfx = self._gfx + self._axesGroup = gfx.Group() + self._scene.add(self._axesGroup) + self._rulers = None # (ruler_x, ruler_y, ruler_z) + self._bboxLine = None + + # --- Item management --- + + def addItem(self, item, index=None): + """Add an Item3D to the scene. + + :param Item3D item: The item to add + :param int index: Index at which to place the item (default: end) + """ + self._sceneGroup.addItem(item, index) + # Auto-fit camera to scene after adding items + self._camera.show_object(self._scene) + + def removeItem(self, item): + """Remove an Item3D from the scene. + + :param Item3D item: The item to remove + """ + self._sceneGroup.removeItem(item) + + def getItems(self): + """Return the list of items in the scene. + + :rtype: tuple + """ + return self._sceneGroup.getItems() + + def clearItems(self): + """Remove all items from the scene.""" + self._sceneGroup.clearItems() + + # --- visit (for GroupPropertiesWidget compatibility) --- + + def visit(self, included=True): + """Generator visiting scene items recursively. + + :param bool included: Whether to include self + """ + if included: + yield self + for item in self._sceneGroup.getItems(): + if hasattr(item, "visit"): + yield from item.visit(included=True) + else: + yield item + + # --- Convenience add methods --- + + def addVolume(self, data, copy=True, index=None): + """Add 3D data volume to the scene. + + :param data: 3D array (zyx order) + :param bool copy: Whether to copy the data + :param int index: Position index + :returns: ScalarField3D or ComplexField3D + """ + if data is not None: + data = numpy.asarray(data) + + if numpy.iscomplexobj(data): + volume = items.ComplexField3D() + else: + volume = items.ScalarField3D() + volume.setData(data, copy=copy) + self.addItem(volume, index) + return volume + + def add3DScatter(self, x, y, z, value, copy=True, index=None): + """Add 3D scatter data to the scene. + + :returns: Scatter3D item + """ + scatter3d = items.Scatter3D() + scatter3d.setData(x=x, y=y, z=z, value=value, copy=copy) + self.addItem(scatter3d, index) + return scatter3d + + def add2DScatter(self, x, y, value, copy=True, index=None): + """Add 2D scatter data to the scene. + + :returns: Scatter2D item + """ + scatter2d = items.Scatter2D() + scatter2d.setData(x=x, y=y, value=value, copy=copy) + self.addItem(scatter2d, index) + return scatter2d + + def addImage(self, data, copy=True, index=None): + """Add 2D image or RGBA image to the scene. + + :returns: ImageData or ImageRgba + """ + data = numpy.asarray(data) + if data.ndim == 2: + image = items.ImageData() + elif data.ndim == 3: + image = items.ImageRgba() + else: + raise ValueError("Unsupported array dimensions: %d" % data.ndim) + image.setData(data, copy=copy) + self.addItem(image, index) + return image + + # --- Colors --- + + def getTextColor(self): + return qt.QColor.fromRgbF(*self._textColor) + + def setTextColor(self, color): + color = rgba(color) + if color != self._textColor: + self._textColor = color + self.sigStyleChanged.emit("textColor") + + def getForegroundColor(self): + return qt.QColor.fromRgbF(*self._foregroundColor) + + def setForegroundColor(self, color): + color = rgba(color) + if color != self._foregroundColor: + self._foregroundColor = color + self.sigStyleChanged.emit("foregroundColor") + + def getHighlightColor(self): + return qt.QColor.fromRgbF(*self._highlightColor) + + def setHighlightColor(self, color): + color = rgba(color) + if color != self._highlightColor: + self._highlightColor = color + self.sigStyleChanged.emit("highlightColor") + + # --- Scene group --- + + def getSceneGroup(self): + """Return the GroupItem managing scene items. + + :rtype: GroupItem + """ + return self._sceneGroup + + # --- Picking (stub) --- + + def pickItems(self, x, y, condition=None): + """Stub for picking - not yet implemented for pygfx backend.""" + return iter([]) + + # --- Selection (stub) --- + + def selection(self): + """Return a stub selection object.""" + return self._selection + + def model(self): + """Return the SceneModel for the parameter tree. + + :rtype: SceneModel + """ + if self._model is None: + from ._model.model import SceneModel + + self._model = SceneModel(parent=self) + return self._model + + # --- Internal sync --- + + def _onGroupItemAdded(self, item): + """Handle item added to GroupItem.""" + item.sigItemChanged.connect(self._onItemChanged) + self._syncItem(item) + self._updateAxesAndBBox() + + def _onGroupItemRemoved(self, item): + """Handle item removed from GroupItem.""" + item.sigItemChanged.disconnect(self._onItemChanged) + self._unsyncItem(item) + self._updateAxesAndBBox() + + def _syncItem(self, item): + """Synchronize an Item3D to pygfx scene objects.""" + self._unsyncItem(item) + + obj = sync_item(item) + if obj is not None: + self._pygfxObjects[id(item)] = obj + self._dataGroup.add(obj) + + def _unsyncItem(self, item): + """Remove pygfx objects for an item.""" + key = id(item) + if key in self._pygfxObjects: + obj = self._pygfxObjects.pop(key) + try: + self._dataGroup.remove(obj) + except ValueError: + pass + + def _onItemChanged(self, event): + """Handle item property changes by re-syncing.""" + item = self.sender() + if item in self._sceneGroup.getItems(): + self._syncItem(item) + self._updateAxesAndBBox() + + def _resyncAll(self): + """Re-synchronize all items.""" + for obj in list(self._pygfxObjects.values()): + try: + self._dataGroup.remove(obj) + except ValueError: + pass + self._pygfxObjects.clear() + + for item in self._sceneGroup.getItems(): + obj = sync_item(item) + if obj is not None: + self._pygfxObjects[id(item)] = obj + self._dataGroup.add(obj) + + # --- 3D Axes and Bounding Box --- + + def _getDataBounds(self): + """Get bounding box of all data in the scene. + + :returns: (min, max) arrays or None if no data + """ + try: + bbox = self._dataGroup.get_world_bounding_box() + if bbox is not None: + mn = numpy.array(bbox[0], dtype=numpy.float64) + mx = numpy.array(bbox[1], dtype=numpy.float64) + if numpy.all(numpy.isfinite(mn)) and numpy.all(numpy.isfinite(mx)): + return mn, mx + except Exception: + pass + return None + + def _updateAxesAndBBox(self): + """Update 3D axes rulers and bounding box wireframe.""" + gfx = self._gfx + bounds = self._getDataBounds() + + if bounds is None: + # Remove existing axes/bbox + if self._rulers is not None: + for ruler in self._rulers: + self._axesGroup.remove(ruler) + self._rulers = None + if self._bboxLine is not None: + self._axesGroup.remove(self._bboxLine) + self._bboxLine = None + return + + mn, mx = bounds + + # Rulers: X (red), Y (green), Z (blue) at bbox edges, labels facing outward + if self._rulers is None: + self._rulers = ( + gfx.Ruler( + start_pos=tuple(mn), + end_pos=(mx[0], mn[1], mn[2]), + start_value=mn[0], + tick_side="right", + color=(1, 0, 0, 1), + line_width=2, + ), + gfx.Ruler( + start_pos=tuple(mn), + end_pos=(mn[0], mx[1], mn[2]), + start_value=mn[1], + tick_side="left", + color=(0, 1, 0, 1), + line_width=2, + ), + gfx.Ruler( + start_pos=tuple(mn), + end_pos=(mn[0], mn[1], mx[2]), + start_value=mn[2], + tick_side="right", + color=(0, 0, 1, 1), + line_width=2, + ), + ) + for ruler in self._rulers: + self._axesGroup.add(ruler) + else: + self._rulers[0].start_pos = tuple(mn) + self._rulers[0].end_pos = (mx[0], mn[1], mn[2]) + self._rulers[0].start_value = mn[0] + self._rulers[1].start_pos = tuple(mn) + self._rulers[1].end_pos = (mn[0], mx[1], mn[2]) + self._rulers[1].start_value = mn[1] + self._rulers[2].start_pos = tuple(mn) + self._rulers[2].end_pos = (mn[0], mn[1], mx[2]) + self._rulers[2].start_value = mn[2] + + # Bounding box wireframe (12 edges) + corners = numpy.array( + [ + [mn[0], mn[1], mn[2]], + [mx[0], mn[1], mn[2]], + [mx[0], mx[1], mn[2]], + [mn[0], mx[1], mn[2]], + [mn[0], mn[1], mx[2]], + [mx[0], mn[1], mx[2]], + [mx[0], mx[1], mx[2]], + [mn[0], mx[1], mx[2]], + ], + dtype=numpy.float32, + ) + edges = [ + (0, 1), + (1, 2), + (2, 3), + (3, 0), # bottom face + (4, 5), + (5, 6), + (6, 7), + (7, 4), # top face + (0, 4), + (1, 5), + (2, 6), + (3, 7), # vertical edges + ] + positions = numpy.array( + [corners[i] for edge in edges for i in edge], + dtype=numpy.float32, + ) + + if self._bboxLine is not None: + self._axesGroup.remove(self._bboxLine) + + self._bboxLine = gfx.Line( + gfx.Geometry(positions=positions), + gfx.LineSegmentMaterial(color=(0.6, 0.6, 0.6, 0.5), thickness=1), + ) + self._axesGroup.add(self._bboxLine) + + def _animate(self): + """Render callback with ruler updates.""" + if self._rulers: + size = self._renderWidget.get_logical_size() + for ruler in self._rulers: + ruler.update(self._camera, size) + super()._animate() diff --git a/src/silx/gui/plot3d/SceneWindow.py b/src/silx/gui/plot3d/SceneWindow.py index 05c4b31d8b..7fa8b6f6ae 100644 --- a/src/silx/gui/plot3d/SceneWindow.py +++ b/src/silx/gui/plot3d/SceneWindow.py @@ -98,15 +98,27 @@ def setPlot3DWidget(self, widget): class SceneWindow(qt.QMainWindow): - """OpenGL 3D scene widget with toolbars.""" + """OpenGL 3D scene widget with toolbars. - def __init__(self, parent=None): + :param parent: Parent widget + :param str backend: 'opengl' (default) or 'pygfx' + """ + + def __init__(self, parent=None, backend=None): super().__init__(parent) if parent is not None: # behave as a widget self.setWindowFlags(qt.Qt.Widget) - self._sceneWidget = SceneWidget() + self._backend = backend + + if backend == "pygfx": + from .SceneWidgetPygfx import SceneWidgetPygfx + + self._sceneWidget = SceneWidgetPygfx() + else: + self._sceneWidget = SceneWidget() + self.setCentralWidget(self._sceneWidget) # Add PositionInfoWidget to display picking info @@ -139,7 +151,8 @@ def __init__(self, parent=None): self._paramTreeView.setModel(self._sceneWidget.model()) selectionModel = self._paramTreeView.selectionModel() - self._sceneWidget.selection()._setSyncSelectionModel(selectionModel) + if selectionModel is not None: + self._sceneWidget.selection()._setSyncSelectionModel(selectionModel) paramDock = qt.QDockWidget() paramDock.setWindowTitle("Object parameters") diff --git a/src/silx/gui/plot3d/_pygfx_utils.py b/src/silx/gui/plot3d/_pygfx_utils.py new file mode 100644 index 0000000000..0509654959 --- /dev/null +++ b/src/silx/gui/plot3d/_pygfx_utils.py @@ -0,0 +1,142 @@ +"""Utility functions for pygfx 3D backend.""" + +import logging +import numpy + +_logger = logging.getLogger(__name__) + +# silx symbol -> pygfx marker shape mapping (same as 2D backend) +SYMBOL_MAP = { + "o": "circle", + ".": "circle", + ",": "square", + "+": "plus", + "x": "cross", + "d": "diamond", + "s": "square", + "^": "triangle_up", + "v": "triangle_down", + "<": "triangle_left", + ">": "triangle_right", + "*": "asterisk6", +} + + +def apply_colormap(colormap, data): + """Apply a silx Colormap to data, returning (N, 4) float32 RGBA array. + + :param colormap: silx Colormap object + :param data: 1D or 2D numpy array of scalar values + :returns: RGBA array with shape (*data.shape, 4), float32 in [0, 1] + """ + original_shape = data.shape + flat = numpy.asarray(data, dtype=numpy.float32).ravel() + + # Get colormap range + vmin, vmax = colormap.getColormapRange(flat) + + # Normalize data to [0, 1] + if vmax == vmin: + normalized = numpy.zeros_like(flat) + else: + normalized = numpy.clip((flat - vmin) / (vmax - vmin), 0, 1) + + # Get LUT (256 colors) + lut = colormap.getNColors(nbColors=256) # (256, 4) uint8 + lut_f = lut.astype(numpy.float32) / 255.0 + + # Map normalized values to LUT indices + indices = numpy.clip((normalized * 255).astype(int), 0, 255) + colors = lut_f[indices] + + # Reshape to match original data shape + return colors.reshape(*original_shape, 4) + + +def grid_to_triangles(H, W): + """Convert (H, W) grid to triangle index array. + + Creates two triangles per grid cell for a total of (H-1)*(W-1)*2 triangles. + + :param int H: Number of rows + :param int W: Number of columns + :returns: (N, 3) uint32 index array + """ + rows = numpy.arange(H - 1) + cols = numpy.arange(W - 1) + r, c = numpy.meshgrid(rows, cols, indexing="ij") + r = r.ravel() + c = c.ravel() + + # Vertex indices for each quad + v00 = r * W + c + v01 = r * W + (c + 1) + v10 = (r + 1) * W + c + v11 = (r + 1) * W + (c + 1) + + # Two triangles per quad + tri1 = numpy.column_stack([v00, v10, v11]) + tri2 = numpy.column_stack([v00, v11, v01]) + indices = numpy.vstack([tri1, tri2]).astype(numpy.uint32) + return indices + + +def compute_normals(positions, indices): + """Compute per-vertex normals from positions and triangle indices. + + :param positions: (N, 3) float32 array of vertex positions + :param indices: (M, 3) uint32 array of triangle indices + :returns: (N, 3) float32 array of normalized per-vertex normals + """ + positions = numpy.asarray(positions, dtype=numpy.float32) + indices = numpy.asarray(indices, dtype=numpy.uint32) + + normals = numpy.zeros_like(positions) + + v0 = positions[indices[:, 0]] + v1 = positions[indices[:, 1]] + v2 = positions[indices[:, 2]] + + # Face normals + face_normals = numpy.cross(v1 - v0, v2 - v0) + + # Accumulate face normals to vertices + for i in range(3): + numpy.add.at(normals, indices[:, i], face_normals) + + # Normalize + lengths = numpy.linalg.norm(normals, axis=1, keepdims=True) + lengths = numpy.maximum(lengths, 1e-10) + normals /= lengths + + return normals + + +def apply_transform(item, world_object): + """Apply Item3D's transforms to a pygfx WorldObject. + + Handles translation, scale, and rotation from DataItem3D. + + :param item: silx Item3D (DataItem3D) with transform methods + :param world_object: pygfx WorldObject to apply transforms to + """ + if not hasattr(item, "getTranslation"): + return + + tx, ty, tz = item.getTranslation() + world_object.local.position = (float(tx), float(ty), float(tz)) + + sx, sy, sz = item.getScale() + world_object.local.scale = (float(sx), float(sy), float(sz)) + + angle, axis = item.getRotation() + if angle != 0 and numpy.any(axis != 0): + import pylinalg as la + + axis_f = numpy.asarray(axis, dtype=numpy.float64) + norm = numpy.linalg.norm(axis_f) + if norm > 0: + axis_f /= norm + angle_rad = numpy.radians(float(angle)) + quat = la.quat_from_axis_angle(axis_f, angle_rad) + world_object.local.rotation = quat diff --git a/src/silx/gui/plot3d/items/_pygfx_sync.py b/src/silx/gui/plot3d/items/_pygfx_sync.py new file mode 100644 index 0000000000..7d364c8d4b --- /dev/null +++ b/src/silx/gui/plot3d/items/_pygfx_sync.py @@ -0,0 +1,754 @@ +"""Item3D -> pygfx WorldObject conversion functions.""" + +import logging + +import numpy + +import pygfx as gfx + +from .._pygfx_utils import ( + SYMBOL_MAP, + apply_colormap, + apply_transform, + compute_normals, + grid_to_triangles, +) + +_logger = logging.getLogger(__name__) + + +def sync_item(item, clip_planes=None): + """Convert an Item3D to a pygfx WorldObject based on its type. + + :param item: silx Item3D instance + :param clip_planes: list of (a, b, c, d) plane equations for clipping + :returns: pygfx WorldObject or None + """ + from . import ( + Scatter3D, + Scatter2D, + Mesh, + ColormapMesh, + Box, + Cylinder, + Hexagon, + ImageData, + ImageRgba, + HeightMapData, + HeightMapRGBA, + ScalarField3D, + GroupItem, + ClipPlane, + ) + + obj = None + + if isinstance(item, ClipPlane): + return None # Handled by sync_group + + elif isinstance(item, GroupItem): + obj = sync_group(item, clip_planes) + + elif isinstance(item, Scatter3D): + obj = sync_scatter3d(item) + + elif isinstance(item, Scatter2D): + obj = sync_scatter2d(item) + + elif isinstance(item, Box): + obj = sync_box(item) + + elif isinstance(item, Cylinder): + obj = sync_cylinder(item) + + elif isinstance(item, Hexagon): + obj = sync_hexagon(item) + + elif isinstance(item, ColormapMesh): + obj = sync_colormap_mesh(item) + + elif isinstance(item, Mesh): + obj = sync_mesh(item) + + elif isinstance(item, HeightMapData): + obj = sync_heightmap_data(item) + + elif isinstance(item, HeightMapRGBA): + obj = sync_heightmap_rgba(item) + + elif isinstance(item, ImageData): + obj = sync_image_data(item) + + elif isinstance(item, ImageRgba): + obj = sync_image_rgba(item) + + elif isinstance(item, ScalarField3D): + obj = sync_scalar_field_3d(item) + + else: + _logger.warning("Unsupported item type for pygfx sync: %s", type(item).__name__) + return None + + if obj is not None: + obj.visible = item.isVisible() + apply_transform(item, obj) + + # Apply clipping planes to materials + if clip_planes: + _apply_clip_planes(obj, clip_planes) + + return obj + + +def _apply_clip_planes(obj, clip_planes): + """Recursively apply clipping planes to all materials in an object tree.""" + if hasattr(obj, "material") and obj.material is not None: + if hasattr(obj.material, "clipping_planes"): + obj.material.clipping_planes = [tuple(p) for p in clip_planes] + obj.material.clipping_mode = "any" + if hasattr(obj, "children"): + for child in obj.children: + _apply_clip_planes(child, clip_planes) + + +# --- Mesh items --- + + +def sync_mesh(item): + """Convert a Mesh item to pygfx.Mesh. + + :param item: silx Mesh item + :returns: pygfx.Mesh + """ + positions = item.getPositionData(copy=False) + if positions is None or len(positions) == 0: + return None + + colors = item.getColorData(copy=False) + normals = item.getNormalData(copy=False) + indices = item.getIndices(copy=False) + mode = item.getDrawMode() + + positions = numpy.ascontiguousarray(positions, dtype=numpy.float32) + + # Handle color: can be single color or per-vertex + if colors is not None: + colors = numpy.asarray(colors, dtype=numpy.float32) + if colors.ndim == 1: + # Single color for all vertices + if len(colors) == 3: + color = (*colors, 1.0) + else: + color = tuple(colors) + geo = gfx.Geometry(positions=positions) + if normals is not None: + normals = numpy.ascontiguousarray(normals, dtype=numpy.float32) + if normals.ndim == 1: + # Broadcast single normal + normals = numpy.tile(normals, (len(positions), 1)) + geo = gfx.Geometry(positions=positions, normals=normals) + mat = gfx.MeshPhongMaterial(color=color) + else: + # Per-vertex colors + if colors.shape[1] == 3: + alpha = numpy.ones((len(colors), 1), dtype=numpy.float32) + colors = numpy.hstack([colors, alpha]) + colors = numpy.ascontiguousarray(colors, dtype=numpy.float32) + kwargs = {"positions": positions, "colors": colors} + if normals is not None: + normals = numpy.ascontiguousarray(normals, dtype=numpy.float32) + if normals.ndim == 1: + normals = numpy.tile(normals, (len(positions), 1)) + kwargs["normals"] = normals + geo = gfx.Geometry(**kwargs) + mat = gfx.MeshPhongMaterial(color_mode="vertex") + else: + geo = gfx.Geometry(positions=positions) + mat = gfx.MeshPhongMaterial(color=(0.8, 0.8, 0.8, 1.0)) + + # Handle triangle strip/fan by expanding to triangles + if mode == "triangle_strip" and indices is None: + indices = _strip_to_triangles(len(positions)) + elif mode == "fan" and indices is None: + indices = _fan_to_triangles(len(positions)) + elif mode == "triangles" and indices is None: + # pygfx requires explicit indices + indices = numpy.arange(len(positions), dtype=numpy.uint32).reshape(-1, 3) + + if indices is not None: + indices = numpy.ascontiguousarray(indices, dtype=numpy.uint32) + if indices.ndim == 1: + indices = indices.reshape(-1, 3) + geo.indices = gfx.Buffer(indices) + + return gfx.Mesh(geo, mat) + + +def sync_colormap_mesh(item): + """Convert a ColormapMesh item to pygfx.Mesh with colormapped vertex colors.""" + positions = item.getPositionData(copy=False) + if positions is None or len(positions) == 0: + return None + + values = item.getValueData(copy=False) + normals = item.getNormalData(copy=False) + indices = item.getIndices(copy=False) + + positions = numpy.ascontiguousarray(positions, dtype=numpy.float32) + + # Apply colormap to get per-vertex colors + colors = apply_colormap(item.getColormap(), values.ravel()) + colors = numpy.ascontiguousarray(colors, dtype=numpy.float32) + + kwargs = {"positions": positions, "colors": colors} + if normals is not None: + normals = numpy.ascontiguousarray(normals, dtype=numpy.float32) + if normals.ndim == 1: + normals = numpy.tile(normals, (len(positions), 1)) + kwargs["normals"] = normals + + geo = gfx.Geometry(**kwargs) + + if indices is not None: + indices = numpy.ascontiguousarray(indices, dtype=numpy.uint32) + if indices.ndim == 1: + indices = indices.reshape(-1, 3) + geo.indices = gfx.Buffer(indices) + + mat = gfx.MeshPhongMaterial(color_mode="vertex") + return gfx.Mesh(geo, mat) + + +def _strip_to_triangles(n): + """Convert triangle strip vertex count to triangle indices.""" + indices = [] + for i in range(n - 2): + if i % 2 == 0: + indices.append([i, i + 1, i + 2]) + else: + indices.append([i, i + 2, i + 1]) + return numpy.array(indices, dtype=numpy.uint32) + + +def _fan_to_triangles(n): + """Convert triangle fan vertex count to triangle indices.""" + indices = [] + for i in range(1, n - 1): + indices.append([0, i, i + 1]) + return numpy.array(indices, dtype=numpy.uint32) + + +def sync_box(item): + """Convert a Box item to pygfx.Mesh using box_geometry.""" + size = item.getSize() + color = item.getColor(copy=False) + positions = item.getPosition(copy=False) + + if len(color) == 3: + color_rgba = (*color, 1.0) + else: + color_rgba = tuple(color[:4]) + + group = gfx.Group() + + for pos in positions: + geo = gfx.box_geometry(float(size[0]), float(size[1]), float(size[2])) + mat = gfx.MeshPhongMaterial(color=color_rgba) + if color_rgba[3] < 1.0: + mat.opacity = color_rgba[3] + mat.transparent = True + mesh = gfx.Mesh(geo, mat) + mesh.local.position = (float(pos[0]), float(pos[1]), float(pos[2])) + group.add(mesh) + + if len(positions) == 1: + # For single box, return mesh directly (simpler transform) + mesh = group.children[0] + group.remove(mesh) + return mesh + + return group + + +def sync_cylinder(item): + """Convert a Cylinder item to pygfx.Mesh using cylinder_geometry.""" + radius = item.getRadius() + height = item.getHeight() + color = item.getColor(copy=False) + positions = item.getPosition(copy=False) + + if len(color) == 3: + color_rgba = (*color, 1.0) + else: + color_rgba = tuple(color[:4]) + + group = gfx.Group() + + for pos in positions: + geo = gfx.cylinder_geometry( + radius_bottom=float(radius), + radius_top=float(radius), + height=float(height), + radial_segments=20, + ) + mat = gfx.MeshPhongMaterial(color=color_rgba) + if color_rgba[3] < 1.0: + mat.opacity = color_rgba[3] + mat.transparent = True + mesh = gfx.Mesh(geo, mat) + mesh.local.position = (float(pos[0]), float(pos[1]), float(pos[2])) + group.add(mesh) + + if len(positions) == 1: + mesh = group.children[0] + group.remove(mesh) + return mesh + + return group + + +def sync_hexagon(item): + """Convert a Hexagon item to pygfx.Mesh using cylinder_geometry with 6 segments.""" + radius = item.getRadius() + height = item.getHeight() + color = item.getColor(copy=False) + positions = item.getPosition(copy=False) + + if len(color) == 3: + color_rgba = (*color, 1.0) + else: + color_rgba = tuple(color[:4]) + + group = gfx.Group() + + for pos in positions: + geo = gfx.cylinder_geometry( + radius_bottom=float(radius), + radius_top=float(radius), + height=float(height), + radial_segments=6, + ) + mat = gfx.MeshPhongMaterial(color=color_rgba) + if color_rgba[3] < 1.0: + mat.opacity = color_rgba[3] + mat.transparent = True + mesh = gfx.Mesh(geo, mat) + mesh.local.position = (float(pos[0]), float(pos[1]), float(pos[2])) + group.add(mesh) + + if len(positions) == 1: + mesh = group.children[0] + group.remove(mesh) + return mesh + + return group + + +# --- Scatter items --- + + +def sync_scatter3d(item): + """Convert a Scatter3D item to pygfx.Points.""" + x, y, z, value = item.getData(copy=False) + if x is None or len(x) == 0: + return None + + positions = numpy.column_stack( + [ + numpy.asarray(x, dtype=numpy.float32), + numpy.asarray(y, dtype=numpy.float32), + numpy.asarray(z, dtype=numpy.float32), + ] + ) + positions = numpy.ascontiguousarray(positions) + + # Apply colormap + colors = apply_colormap( + item.getColormap(), numpy.asarray(value, dtype=numpy.float32) + ) + colors = numpy.ascontiguousarray(colors, dtype=numpy.float32) + + geo = gfx.Geometry(positions=positions, colors=colors) + + symbol = item.getSymbol() + marker = SYMBOL_MAP.get(symbol, "circle") + size = float(item.getSymbolSize()) + + mat = gfx.PointsMarkerMaterial( + marker=marker, + size=size, + color_mode="vertex", + size_space="screen", + ) + + return gfx.Points(geo, mat) + + +def sync_scatter2d(item): + """Convert a Scatter2D item to pygfx WorldObject. + + Supports solid, lines, and points visualization modes. + """ + x = numpy.asarray(item.getXData(copy=False), dtype=numpy.float32) + y = numpy.asarray(item.getYData(copy=False), dtype=numpy.float32) + value = numpy.asarray(item.getValueData(copy=False), dtype=numpy.float32) + + if len(x) == 0: + return None + + height_map = item.isHeightMap() + z = value if height_map else numpy.zeros_like(x) + + positions = numpy.column_stack([x, y, z]) + positions = numpy.ascontiguousarray(positions, dtype=numpy.float32) + + colors = apply_colormap(item.getColormap(), value) + colors = numpy.ascontiguousarray(colors, dtype=numpy.float32) + + vis = item.getVisualization() + vis_name = vis.value if hasattr(vis, "value") else str(vis) + + if vis_name == "solid": + return _scatter2d_solid(positions, colors) + elif vis_name == "lines": + return _scatter2d_lines(item, positions, colors, x, y) + else: # points + return _scatter2d_points(item, positions, colors) + + +def _scatter2d_solid(positions, colors): + """Create solid surface from 2D scatter data using Delaunay triangulation.""" + try: + from scipy.spatial import Delaunay + + points_2d = positions[:, :2] + tri = Delaunay(points_2d) + indices = numpy.ascontiguousarray(tri.simplices.astype(numpy.uint32)) + except ImportError: + _logger.warning("scipy not available, falling back to grid triangulation") + # Try grid-based triangulation if data is on a grid + n = int(numpy.sqrt(len(positions))) + if n * n == len(positions): + indices = grid_to_triangles(n, n) + else: + return None + except Exception: + _logger.warning("Delaunay triangulation failed") + return None + + normals = compute_normals(positions, indices) + + geo = gfx.Geometry( + positions=positions, + normals=normals, + colors=colors, + indices=gfx.Buffer(indices), + ) + mat = gfx.MeshPhongMaterial(color_mode="vertex") + return gfx.Mesh(geo, mat) + + +def _scatter2d_lines(item, positions, colors, x, y): + """Create wireframe from 2D scatter data.""" + # Try to detect grid structure + unique_x = numpy.unique(x) + unique_y = numpy.unique(y) + nx, ny = len(unique_x), len(unique_y) + + if nx * ny == len(x): + # Grid data - create line segments along rows and columns + group = gfx.Group() + + # Reshape to grid + pos_grid = positions.reshape(ny, nx, 3) + col_grid = colors.reshape(ny, nx, 4) + + # Row lines + for j in range(ny): + row_pos = numpy.ascontiguousarray(pos_grid[j], dtype=numpy.float32) + row_col = numpy.ascontiguousarray(col_grid[j], dtype=numpy.float32) + geo = gfx.Geometry(positions=row_pos, colors=row_col) + mat = gfx.LineMaterial(thickness=item.getLineWidth(), color_mode="vertex") + group.add(gfx.Line(geo, mat)) + + # Column lines + for i in range(nx): + col_pos = numpy.ascontiguousarray(pos_grid[:, i], dtype=numpy.float32) + col_col = numpy.ascontiguousarray(col_grid[:, i], dtype=numpy.float32) + geo = gfx.Geometry(positions=col_pos, colors=col_col) + mat = gfx.LineMaterial(thickness=item.getLineWidth(), color_mode="vertex") + group.add(gfx.Line(geo, mat)) + + return group + else: + # Non-grid data - just connect points in order + geo = gfx.Geometry(positions=positions, colors=colors) + mat = gfx.LineMaterial(thickness=item.getLineWidth(), color_mode="vertex") + return gfx.Line(geo, mat) + + +def _scatter2d_points(item, positions, colors): + """Create point cloud from 2D scatter data.""" + geo = gfx.Geometry(positions=positions, colors=colors) + + symbol = item.getSymbol() if hasattr(item, "getSymbol") else "o" + marker = SYMBOL_MAP.get(symbol, "circle") + size = float(item.getSymbolSize()) if hasattr(item, "getSymbolSize") else 6.0 + + mat = gfx.PointsMarkerMaterial( + marker=marker, + size=size, + color_mode="vertex", + size_space="screen", + ) + return gfx.Points(geo, mat) + + +# --- Image items --- + + +def sync_image_data(item): + """Convert ImageData item to pygfx.Image.""" + data = item.getData(copy=False) + if data is None: + return None + + colors = apply_colormap(item.getColormap(), data) + colors = numpy.ascontiguousarray(colors, dtype=numpy.float32) + + tex = gfx.Texture(colors, dim=2) + geo = gfx.Geometry(grid=tex) + mat = gfx.ImageBasicMaterial(clim=(0, 1)) + return gfx.Image(geo, mat) + + +def sync_image_rgba(item): + """Convert ImageRgba item to pygfx.Image.""" + data = item.getData(copy=False) + if data is None: + return None + + data = numpy.asarray(data) + if data.dtype == numpy.uint8: + data = data.astype(numpy.float32) / 255.0 + + if data.ndim == 3 and data.shape[2] == 3: + # Add alpha channel + alpha = numpy.ones((*data.shape[:2], 1), dtype=numpy.float32) + data = numpy.concatenate([data, alpha], axis=2) + + data = numpy.ascontiguousarray(data, dtype=numpy.float32) + tex = gfx.Texture(data, dim=2) + geo = gfx.Geometry(grid=tex) + mat = gfx.ImageBasicMaterial(clim=(0, 1)) + return gfx.Image(geo, mat) + + +def sync_heightmap_data(item): + """Convert HeightMapData item to pygfx.Mesh (height field as triangle mesh).""" + height_data = item.getData(copy=False) + if height_data is None: + return None + + H, W = height_data.shape + y_idx, x_idx = numpy.mgrid[0:H, 0:W] + + positions = numpy.column_stack( + [ + x_idx.ravel().astype(numpy.float32), + y_idx.ravel().astype(numpy.float32), + height_data.ravel().astype(numpy.float32), + ] + ) + positions = numpy.ascontiguousarray(positions) + + indices = grid_to_triangles(H, W) + + # Use colormapped data if available, otherwise use height data + colormap_data = item.getColormappedData(copy=False) + if colormap_data is None or colormap_data.size == 0: + colormap_data = height_data + colors = apply_colormap(item.getColormap(), colormap_data.ravel()) + colors = numpy.ascontiguousarray(colors, dtype=numpy.float32) + + normals = compute_normals(positions, indices) + + geo = gfx.Geometry( + positions=positions, + normals=normals, + colors=colors, + indices=gfx.Buffer(indices), + ) + mat = gfx.MeshPhongMaterial(color_mode="vertex") + return gfx.Mesh(geo, mat) + + +def sync_heightmap_rgba(item): + """Convert HeightMapRGBA item to pygfx.Mesh.""" + height_data = item.getData(copy=False) + if height_data is None: + return None + + H, W = height_data.shape + y_idx, x_idx = numpy.mgrid[0:H, 0:W] + + positions = numpy.column_stack( + [ + x_idx.ravel().astype(numpy.float32), + y_idx.ravel().astype(numpy.float32), + height_data.ravel().astype(numpy.float32), + ] + ) + positions = numpy.ascontiguousarray(positions) + + indices = grid_to_triangles(H, W) + + color_data = item.getColorData(copy=False) + if color_data is not None: + color_data = numpy.asarray(color_data, dtype=numpy.float32) + if color_data.dtype == numpy.uint8: + color_data = color_data.astype(numpy.float32) / 255.0 + if color_data.ndim == 3 and color_data.shape[2] == 3: + alpha = numpy.ones((*color_data.shape[:2], 1), dtype=numpy.float32) + color_data = numpy.concatenate([color_data, alpha], axis=2) + colors = color_data.reshape(-1, 4) + else: + colors = numpy.ones((H * W, 4), dtype=numpy.float32) + + colors = numpy.ascontiguousarray(colors, dtype=numpy.float32) + normals = compute_normals(positions, indices) + + geo = gfx.Geometry( + positions=positions, + normals=normals, + colors=colors, + indices=gfx.Buffer(indices), + ) + mat = gfx.MeshPhongMaterial(color_mode="vertex") + return gfx.Mesh(geo, mat) + + +# --- Volume items --- + + +def sync_scalar_field_3d(item): + """Convert ScalarField3D item to pygfx.Group. + + Handles isosurfaces (marching cubes -> mesh) and cut planes (volume slice). + """ + data = item.getData(copy=False) + if data is None: + return None + + group = gfx.Group() + + # Isosurfaces -> marching cubes -> gfx.Mesh + for isosurface in item.getIsosurfaces(): + if not isosurface.isVisible(): + continue + + level = isosurface.getLevel() + color = isosurface.getColor() + + try: + from skimage.measure import marching_cubes + + verts, faces, _, _ = marching_cubes(data, level=level) + # marching_cubes returns (z, y, x) order; swap to (x, y, z) + offset + verts = verts[:, ::-1].copy() + 0.5 # z,y,x -> x,y,z and offset + verts = numpy.ascontiguousarray(verts.astype(numpy.float32)) + faces = numpy.ascontiguousarray(faces.astype(numpy.uint32)) + + normals = compute_normals(verts, faces) + + geo = gfx.Geometry( + positions=verts, + normals=normals, + indices=gfx.Buffer(faces), + ) + + # Parse color + r = color.redF() if hasattr(color, "redF") else color[0] + g = color.greenF() if hasattr(color, "greenF") else color[1] + b = color.blueF() if hasattr(color, "blueF") else color[2] + a = ( + color.alphaF() + if hasattr(color, "alphaF") + else (color[3] if len(color) > 3 else 1.0) + ) + + mat = gfx.MeshPhongMaterial(color=(r, g, b, a)) + mat.opacity = a + group.add(gfx.Mesh(geo, mat)) + + except ImportError: + _logger.warning("scikit-image not available for marching cubes") + except Exception as e: + _logger.warning("Marching cubes failed for level %s: %s", level, e) + + # Cut planes -> volume slice + for cut_plane in item.getCutPlanes(): + if not cut_plane.isVisible(): + continue + + try: + data_f32 = numpy.ascontiguousarray(data.astype(numpy.float32)) + tex = gfx.Texture(data_f32, dim=3) + geo = gfx.Geometry(grid=tex) + + normal = numpy.asarray(cut_plane.getNormal(), dtype=numpy.float64) + point = numpy.asarray(cut_plane.getPoint(), dtype=numpy.float64) + d = -numpy.dot(normal, point) + + # Build colormap texture for the slice + cmap = cut_plane.getColormap() + lut = cmap.getNColors(nbColors=256) # (256, 4) uint8 + lut_f = lut.astype(numpy.float32) / 255.0 + cmap_tex = gfx.Texture(lut_f, dim=1) + + vmin, vmax = cmap.getColormapRange(data_f32) + + mat = gfx.VolumeSliceMaterial( + plane=(float(normal[0]), float(normal[1]), float(normal[2]), float(d)), + map=cmap_tex, + clim=(float(vmin), float(vmax)), + ) + group.add(gfx.Volume(geo, mat)) + + except Exception as e: + _logger.warning("Cut plane rendering failed: %s", e) + + if len(group.children) == 0: + return None + + return group + + +# --- Group and clipping --- + + +def sync_group(item, clip_planes=None): + """Convert a GroupItem to pygfx.Group with recursive child sync. + + ClipPlane items in the group add clipping planes for subsequent siblings. + """ + from . import ClipPlane + + group = gfx.Group() + current_clips = list(clip_planes or []) + + for child in item.getItems(): + if isinstance(child, ClipPlane): + if child.isVisible(): + normal = numpy.asarray(child.getNormal(), dtype=numpy.float64) + point = numpy.asarray(child.getPoint(), dtype=numpy.float64) + d = -numpy.dot(normal, point) + current_clips.append( + (float(normal[0]), float(normal[1]), float(normal[2]), float(d)) + ) + else: + obj = sync_item(child, clip_planes=current_clips) + if obj is not None: + group.add(obj) + + return group diff --git a/src/silx/gui/plot3d/test/test_scenewindow.py b/src/silx/gui/plot3d/test/test_scenewindow.py index bbe24129ed..a7cf6a0feb 100644 --- a/src/silx/gui/plot3d/test/test_scenewindow.py +++ b/src/silx/gui/plot3d/test/test_scenewindow.py @@ -20,7 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # ###########################################################################*/ -"""Test SceneWindow""" +"""Test SceneWindow with OpenGL and pygfx backends""" __authors__ = ["T. Vincent"] __license__ = "MIT" @@ -38,27 +38,41 @@ from silx.gui.plot3d.SceneWindow import SceneWindow from silx.gui.plot3d.items import HeightMapData, HeightMapRGBA +# --- Parametrized fixture for both backends --- -@pytest.mark.usefixtures("use_opengl") -class TestSceneWindow(TestCaseQt, ParametricTestCase): - """Tests SceneWidget picking feature""" - def setUp(self): - super().setUp() - self.window = SceneWindow() - self.window.show() - self.qWaitForWindowExposed(self.window) +@pytest.fixture( + params=[ + pytest.param(None, id="opengl"), + pytest.param("pygfx", id="pygfx"), + ] +) +def scene_window(request, qapp, test_options): + """SceneWindow fixture parametrized by backend.""" + backend = request.param + if backend is None and not test_options.WITH_GL_TEST: + pytest.skip(test_options.WITH_GL_TEST_REASON) + if backend == "pygfx" and not test_options.WITH_PYGFX_TEST: + pytest.skip(test_options.WITH_PYGFX_TEST_REASON) - def tearDown(self): - self.qapp.processEvents() - self.window.setAttribute(qt.Qt.WA_DeleteOnClose) - self.window.close() - del self.window - super().tearDown() + window = SceneWindow(backend=backend) + window.show() + qapp.processEvents() + yield window + window.setAttribute(qt.Qt.WA_DeleteOnClose) + window.close() + qapp.processEvents() - def testAdd(self): - """Test add basic scene primitive""" - sceneWidget = self.window.getSceneWidget() + +# --- Tests for both backends --- + + +class TestSceneWindow: + """Tests SceneWindow features shared across backends""" + + def test_add(self, scene_window, qapp): + """Test add basic scene primitives""" + sceneWidget = scene_window.getSceneWidget() items = [] # RGB image @@ -67,7 +81,7 @@ def testAdd(self): ) image.setLabel("RGB image") items.append(image) - self.assertEqual(sceneWidget.getItems(), tuple(items)) + assert sceneWidget.getItems() == tuple(items) # Data image image = sceneWidget.addImage( @@ -75,7 +89,7 @@ def testAdd(self): ) image.setTranslation(10.0) items.append(image) - self.assertEqual(sceneWidget.getItems(), tuple(items)) + assert sceneWidget.getItems() == tuple(items) # 2D scatter scatter = sceneWidget.add2DScatter( @@ -84,7 +98,7 @@ def testAdd(self): scatter.setTranslation(0, 10) scatter.setScale(10, 10, 10) items.insert(0, scatter) - self.assertEqual(sceneWidget.getItems(), tuple(items)) + assert sceneWidget.getItems() == tuple(items) # 3D scatter scatter = sceneWidget.add3DScatter( @@ -93,7 +107,7 @@ def testAdd(self): scatter.setTranslation(10, 10) scatter.setScale(10, 10, 10) items.append(scatter) - self.assertEqual(sceneWidget.getItems(), tuple(items)) + assert sceneWidget.getItems() == tuple(items) # 3D array of float volume = sceneWidget.addVolume( @@ -102,9 +116,8 @@ def testAdd(self): volume.setTranslation(0, 0, 10) volume.setRotation(45, (0, 0, 1)) volume.addIsosurface(500, "red") - volume.getCutPlanes()[0].getColormap().setName("viridis") items.append(volume) - self.assertEqual(sceneWidget.getItems(), tuple(items)) + assert sceneWidget.getItems() == tuple(items) # 3D array of complex volume = sceneWidget.addVolume( @@ -115,10 +128,106 @@ def testAdd(self): volume.setComplexMode(volume.ComplexMode.REAL) volume.addIsosurface(500, (1.0, 0.0, 0.0, 0.5)) items.append(volume) - self.assertEqual(sceneWidget.getItems(), tuple(items)) + assert sceneWidget.getItems() == tuple(items) sceneWidget.resetZoom("front") + qapp.processEvents() + + def test_change_content(self, scene_window, qapp): + """Test add/remove/clear items""" + sceneWidget = scene_window.getSceneWidget() + items = [] + + # Add 2 images + image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10) + items.append(sceneWidget.addImage(image)) + items.append(sceneWidget.addImage(image)) + qapp.processEvents() + assert sceneWidget.getItems() == tuple(items) + + # Clear + sceneWidget.clearItems() + qapp.processEvents() + assert sceneWidget.getItems() == () + + # Add 2 images and remove first one + image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10) + sceneWidget.addImage(image) + items = (sceneWidget.addImage(image),) + qapp.processEvents() + + sceneWidget.removeItem(sceneWidget.getItems()[0]) + qapp.processEvents() + assert sceneWidget.getItems() == items + + def test_colors(self, scene_window, qapp): + """Test setting scene colors""" + sceneWidget = scene_window.getSceneWidget() + + color = qt.QColor(128, 128, 128) + sceneWidget.setBackgroundColor(color) + assert sceneWidget.getBackgroundColor() == color + + color = qt.QColor(0, 0, 0) + sceneWidget.setForegroundColor(color) + assert sceneWidget.getForegroundColor() == color + + color = qt.QColor(255, 0, 0) + sceneWidget.setTextColor(color) + assert sceneWidget.getTextColor() == color + + color = qt.QColor(0, 255, 0) + sceneWidget.setHighlightColor(color) + assert sceneWidget.getHighlightColor() == color + + qapp.processEvents() + + def test_interactive_mode(self, scene_window, qapp): + """Test changing interactive mode""" + sceneWidget = scene_window.getSceneWidget() + + for mode in ("rotate", "pan"): + sceneWidget.setInteractiveMode(mode) + qapp.processEvents() + assert sceneWidget.getInteractiveMode() == mode + + def test_model(self, scene_window, qapp): + """Test that model is properly set up""" + sceneWidget = scene_window.getSceneWidget() + model = sceneWidget.model() + assert model is not None + assert model.rowCount() == 2 # Settings + Data + + # Add item and check model updates + scatter = sceneWidget.add3DScatter( + *numpy.random.random(4000).astype(numpy.float32).reshape(4, -1) + ) + scatter.setLabel("Test scatter") + + # Data group should have children now + data_index = model.index(1, 0) + assert model.rowCount(data_index) > 0 + + +# --- OpenGL-only tests --- + + +@pytest.mark.usefixtures("use_opengl") +class TestSceneWindowOpenGL(TestCaseQt, ParametricTestCase): + """Tests specific to OpenGL backend""" + + def setUp(self): + super().setUp() + self.window = SceneWindow() + self.window.show() + self.qWaitForWindowExposed(self.window) + + def tearDown(self): self.qapp.processEvents() + self.window.setAttribute(qt.Qt.WA_DeleteOnClose) + self.window.close() + del self.window + super().tearDown() def testHeightMap(self): """Test height map items""" @@ -158,57 +267,8 @@ def testHeightMap(self): self.qapp.processEvents() sceneWidget.clearItems() - def testChangeContent(self): - """Test add/remove/clear items""" - sceneWidget = self.window.getSceneWidget() - items = [] - - # Add 2 images - image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10) - items.append(sceneWidget.addImage(image)) - items.append(sceneWidget.addImage(image)) - self.qapp.processEvents() - self.assertEqual(sceneWidget.getItems(), tuple(items)) - - # Clear - sceneWidget.clearItems() - self.qapp.processEvents() - self.assertEqual(sceneWidget.getItems(), ()) - - # Add 2 images and remove first one - image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10) - sceneWidget.addImage(image) - items = (sceneWidget.addImage(image),) - self.qapp.processEvents() - - sceneWidget.removeItem(sceneWidget.getItems()[0]) - self.qapp.processEvents() - self.assertEqual(sceneWidget.getItems(), items) - - def testColors(self): - """Test setting scene colors""" - sceneWidget = self.window.getSceneWidget() - - color = qt.QColor(128, 128, 128) - sceneWidget.setBackgroundColor(color) - self.assertEqual(sceneWidget.getBackgroundColor(), color) - - color = qt.QColor(0, 0, 0) - sceneWidget.setForegroundColor(color) - self.assertEqual(sceneWidget.getForegroundColor(), color) - - color = qt.QColor(255, 0, 0) - sceneWidget.setTextColor(color) - self.assertEqual(sceneWidget.getTextColor(), color) - - color = qt.QColor(0, 255, 0) - sceneWidget.setHighlightColor(color) - self.assertEqual(sceneWidget.getHighlightColor(), color) - - self.qapp.processEvents() - def testInteractiveMode(self): - """Test changing interactive mode""" + """Test changing interactive mode with mouse events""" sceneWidget = self.window.getSceneWidget() center = numpy.array((sceneWidget.width() // 2, sceneWidget.height() // 2)) diff --git a/src/silx/gui/plot3d/tools/PositionInfoWidget.py b/src/silx/gui/plot3d/tools/PositionInfoWidget.py index b1fd9c9199..6ba2a10434 100644 --- a/src/silx/gui/plot3d/tools/PositionInfoWidget.py +++ b/src/silx/gui/plot3d/tools/PositionInfoWidget.py @@ -37,6 +37,13 @@ from ..items import volume from ..SceneWidget import SceneWidget +try: + from ..SceneWidgetPygfx import SceneWidgetPygfx + + _SCENE_WIDGET_TYPES = (SceneWidget, SceneWidgetPygfx) +except ImportError: + _SCENE_WIDGET_TYPES = (SceneWidget,) + _logger = logging.getLogger(__name__) @@ -129,7 +136,7 @@ def setSceneWidget(self, widget): :param ~silx.gui.plot3d.SceneWidget.SceneWidget widget: 3D scene for which to display information """ - if widget is not None and not isinstance(widget, SceneWidget): + if widget is not None and not isinstance(widget, _SCENE_WIDGET_TYPES): raise ValueError("widget must be a SceneWidget or None") self._sceneWidgetRef = None if widget is None else weakref.ref(widget) diff --git a/src/silx/test/utils.py b/src/silx/test/utils.py index 5a16540ff3..dddb1da67c 100644 --- a/src/silx/test/utils.py +++ b/src/silx/test/utils.py @@ -69,6 +69,12 @@ def __init__(self): self.WITH_GL_TEST_REASON = "" """Reason for OpenGL tests are disabled if any""" + self.WITH_PYGFX_TEST = True + """pygfx tests are included""" + + self.WITH_PYGFX_TEST_REASON = "" + """Reason for pygfx tests are disabled if any""" + self.WITH_HIGH_MEM_TEST = False """Skip tests using too much memory""" @@ -117,6 +123,32 @@ def configure(self, parsed_options=None): self.WITH_GL_TEST = False self.WITH_GL_TEST_REASON = "OpenGL package not available" + if parsed_options is not None and not parsed_options.pygfx: + self.WITH_PYGFX_TEST = False + self.WITH_PYGFX_TEST_REASON = "Skipped by command line" + elif os.environ.get("WITH_PYGFX_TEST", "True") == "False": + self.WITH_PYGFX_TEST = False + self.WITH_PYGFX_TEST_REASON = "Skipped by WITH_PYGFX_TEST env var" + elif sys.platform.startswith("linux") and not os.environ.get("DISPLAY", ""): + self.WITH_PYGFX_TEST = False + self.WITH_PYGFX_TEST_REASON = "DISPLAY env variable not set" + else: + try: + import pygfx # noqa: F401 + except ImportError: + self.WITH_PYGFX_TEST = False + self.WITH_PYGFX_TEST_REASON = "pygfx package not available" + else: + try: + import pygfx as gfx + + gfx.renderers.wgpu.get_shared().device + except Exception: + self.WITH_PYGFX_TEST = False + self.WITH_PYGFX_TEST_REASON = ( + "pygfx wgpu device not available (no GPU)" + ) + if parsed_options is not None and parsed_options.high_mem: self.WITH_HIGH_MEM_TEST = True self.WITH_HIGH_MEM_TEST_REASON = ""