diff --git a/.gitignore b/.gitignore index a48eb4435..d6fd59c29 100644 --- a/.gitignore +++ b/.gitignore @@ -86,6 +86,9 @@ ipython_config.py # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version +.conda +bootstrap_requirements.txt +environment.yml # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. diff --git a/docs/examples/Plotting_Examples.ipynb b/docs/examples/Plotting_Examples.ipynb new file mode 100644 index 000000000..2e80677fd --- /dev/null +++ b/docs/examples/Plotting_Examples.ipynb @@ -0,0 +1,428 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "60e3ac1d", + "metadata": {}, + "source": [ + "# Plotting Examples with Matplotlib and Plotly\n", + "\n", + "This notebook demonstrates how to visualize broadbean pulse sequences using both **matplotlib** and **plotly** backends.\n", + "\n", + "## Table of Contents\n", + "\n", + "* [Setup](#Setup)\n", + "* [Creating Sample Pulses](#Creating-Sample-Pulses)\n", + "* [Matplotlib Backend](#Matplotlib-Backend)\n", + " * [Plotting Blueprints](#Plotting-Blueprints-with-Matplotlib)\n", + " * [Plotting Elements](#Plotting-Elements-with-Matplotlib)\n", + " * [Plotting Sequences](#Plotting-Sequences-with-Matplotlib)\n", + "* [Plotly Backend](#Plotly-Backend)\n", + " * [Plotting Blueprints](#Plotting-Blueprints-with-Plotly)\n", + " * [Plotting Elements](#Plotting-Elements-with-Plotly)\n", + " * [Plotting Sequences](#Plotting-Sequences-with-Plotly)\n", + "* [Comparing Backends](#Comparing-Backends)" + ] + }, + { + "cell_type": "markdown", + "id": "e7beb180", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, let's import the necessary modules." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39adffe2", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib as mpl\n", + "\n", + "import broadbean as bb\n", + "from broadbean.plotting import plotter\n", + "\n", + "# Configure matplotlib for better display\n", + "mpl.rcParams[\"figure.figsize\"] = (10, 4)\n", + "mpl.rcParams[\"figure.subplot.bottom\"] = 0.15" + ] + }, + { + "cell_type": "markdown", + "id": "0adf4265", + "metadata": {}, + "source": [ + "## Creating Sample Pulses\n", + "\n", + "Let's create some sample blueprints, elements, and sequences that we'll use for plotting demonstrations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bd68572", + "metadata": {}, + "outputs": [], + "source": [ + "# Get the built-in pulse atoms\n", + "ramp = bb.PulseAtoms.ramp # args: start, stop\n", + "sine = bb.PulseAtoms.sine # args: freq, ampl, off, phase\n", + "\n", + "# Create a simple blueprint with multiple segments\n", + "bp = bb.BluePrint()\n", + "bp.insertSegment(0, ramp, (0, 0.5e-3), name=\"ramp_up\", dur=2e-6)\n", + "bp.insertSegment(1, sine, (2e6, 0.5e-3, 0.5e-3, 0), name=\"oscillation\", dur=3e-6)\n", + "bp.insertSegment(2, ramp, (0.5e-3, 0), name=\"ramp_down\", dur=2e-6)\n", + "bp.setSR(1e9) # 1 GS/s sample rate\n", + "\n", + "# Add some markers for demonstration\n", + "bp.setSegmentMarker(\"ramp_up\", (0, 1e-6), 1) # Marker 1 during first half of ramp_up\n", + "bp.setSegmentMarker(\"oscillation\", (0.5e-6, 2e-6), 2) # Marker 2 during oscillation\n", + "\n", + "print(\"Blueprint created:\")\n", + "bp.showPrint()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79077cf1", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a second blueprint for multi-channel demonstration\n", + "bp2 = bb.BluePrint()\n", + "bp2.insertSegment(0, ramp, (0, -0.3e-3), name=\"init\", dur=2e-6)\n", + "bp2.insertSegment(1, ramp, (-0.3e-3, -0.3e-3), name=\"hold\", dur=3e-6)\n", + "bp2.insertSegment(2, ramp, (-0.3e-3, 0), name=\"release\", dur=2e-6)\n", + "bp2.setSR(1e9)\n", + "\n", + "print(\"Second blueprint created:\")\n", + "bp2.showPrint()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d1d73c6", + "metadata": {}, + "outputs": [], + "source": [ + "# Create an element with two channels\n", + "elem = bb.Element()\n", + "elem.addBluePrint(1, bp)\n", + "elem.addBluePrint(2, bp2)\n", + "\n", + "print(f\"Element created with {len(elem._data)} channels\")\n", + "print(f\"Total duration: {bp.duration * 1e6:.1f} µs\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e26ba32c", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a sequence with multiple positions\n", + "seq = bb.Sequence()\n", + "seq.setSR(1e9)\n", + "\n", + "# Add the element at position 1\n", + "seq.addElement(1, elem)\n", + "\n", + "# Create a variation of the element for position 2\n", + "bp3 = bp.copy()\n", + "bp3.changeArg(\"oscillation\", \"freq\", 4e6) # Double the frequency\n", + "\n", + "elem2 = bb.Element()\n", + "elem2.addBluePrint(1, bp3)\n", + "elem2.addBluePrint(2, bp2)\n", + "\n", + "seq.addElement(2, elem2)\n", + "\n", + "# Set AWG specs\n", + "seq.setChannelAmplitude(1, 2.5) # 2.5 V amplitude\n", + "seq.setChannelAmplitude(2, 2.5)\n", + "seq.setChannelOffset(1, 0)\n", + "seq.setChannelOffset(2, 0)\n", + "\n", + "print(f\"Sequence created with {seq.length_sequenceelements} positions\")" + ] + }, + { + "cell_type": "markdown", + "id": "767e4e32", + "metadata": {}, + "source": [ + "## Matplotlib Backend\n", + "\n", + "The matplotlib backend is the default plotting backend. It creates static plots that are great for documentation and publications." + ] + }, + { + "cell_type": "markdown", + "id": "45bab40f", + "metadata": {}, + "source": [ + "### Plotting Blueprints with Matplotlib" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e740341", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot a blueprint using matplotlib (default backend)\n", + "fig = plotter(bp)\n", + "\n", + "# The plotter returns the matplotlib figure, which can be customized further\n", + "fig.suptitle(\"Blueprint with Matplotlib\", fontsize=12, y=1.02)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a7899c1c", + "metadata": {}, + "outputs": [], + "source": [ + "# Explicitly specify the matplotlib backend\n", + "fig = plotter(bp, backend=\"matplotlib\")\n", + "fig.suptitle(\"Explicit Matplotlib Backend\", fontsize=12, y=1.02)" + ] + }, + { + "cell_type": "markdown", + "id": "3b0e0285", + "metadata": {}, + "source": [ + "### Plotting Elements with Matplotlib\n", + "\n", + "Elements show all channels side by side." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e39c96e", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the multi-channel element\n", + "fig = plotter(elem, backend=\"matplotlib\")\n", + "fig.suptitle(\"Two-Channel Element with Matplotlib\", fontsize=12, y=1.02)" + ] + }, + { + "cell_type": "markdown", + "id": "f71b7559", + "metadata": {}, + "source": [ + "### Plotting Sequences with Matplotlib\n", + "\n", + "Sequences show all positions and channels in a grid layout." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f801217", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the full sequence\n", + "fig = plotter(seq, backend=\"matplotlib\")\n", + "fig.suptitle(\"Sequence with Matplotlib (2 positions × 2 channels)\", fontsize=12, y=1.02)" + ] + }, + { + "cell_type": "markdown", + "id": "460faec5", + "metadata": {}, + "source": [ + "## Plotly Backend\n", + "\n", + "The plotly backend creates interactive plots that allow zooming, panning, and hovering over data points.\n", + "\n", + "**Note:** To use the plotly backend, you need to have plotly installed:\n", + "```bash\n", + "pip install plotly\n", + "```\n", + "or\n", + "```bash\n", + "pip install broadbean[plotly]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "ccca8d17", + "metadata": {}, + "source": [ + "### Plotting Blueprints with Plotly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad26a427", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot a blueprint using the plotly backend\n", + "fig = plotter(bp, backend=\"plotly\")\n", + "\n", + "# Update the layout title\n", + "fig.update_layout(title=\"Blueprint with Plotly - Interactive!\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "id": "2b717e77", + "metadata": {}, + "source": [ + "### Plotting Elements with Plotly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f644f74", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the multi-channel element with plotly\n", + "fig = plotter(elem, backend=\"plotly\")\n", + "fig.update_layout(\n", + " title=\"Two-Channel Element with Plotly\",\n", + " height=500, # Adjust height for better visibility\n", + ")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "id": "d11a495c", + "metadata": {}, + "source": [ + "### Plotting Sequences with Plotly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b1b7453", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the full sequence with plotly\n", + "fig = plotter(seq, backend=\"plotly\")\n", + "fig.update_layout(title=\"Sequence with Plotly (2 positions × 2 channels)\", height=600)\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "id": "1fb15e19", + "metadata": {}, + "source": [ + "## Comparing Backends\n", + "\n", + "Here's a quick summary of when to use each backend:\n", + "\n", + "| Feature | Matplotlib | Plotly |\n", + "|---------|------------|--------|\n", + "| Interactive zoom/pan | ❌ | ✅ |\n", + "| Hover information | ❌ | ✅ |\n", + "| Static export (PNG, PDF) | ✅ Excellent | ✅ Good |\n", + "| Publication quality | ✅ | ✅ |\n", + "| Jupyter notebook | ✅ | ✅ |\n", + "| No extra dependencies | ✅ | ❌ (requires plotly) |\n", + "\n", + "**Recommendations:**\n", + "- Use **matplotlib** for publications, reports, and when you need fine control over the figure styling.\n", + "- Use **plotly** for interactive exploration, debugging pulse shapes, and presentations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "677aeb46", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a more complex blueprint to showcase the backends\n", + "bp_complex = bb.BluePrint()\n", + "bp_complex.insertSegment(0, ramp, (0, 1e-3), name=\"charge\", dur=1e-6)\n", + "bp_complex.insertSegment(1, sine, (5e6, 0.2e-3, 1e-3, 0), name=\"pulse\", dur=2e-6)\n", + "bp_complex.insertSegment(2, ramp, (1e-3, 0.5e-3), name=\"measure\", dur=1e-6)\n", + "bp_complex.insertSegment(3, ramp, (0.5e-3, 0.5e-3), name=\"hold\", dur=2e-6)\n", + "bp_complex.insertSegment(4, ramp, (0.5e-3, 0), name=\"reset\", dur=1e-6)\n", + "bp_complex.setSR(1e9)\n", + "\n", + "# Add markers\n", + "bp_complex.setSegmentMarker(\"pulse\", (0, 2e-6), 1)\n", + "bp_complex.setSegmentMarker(\"measure\", (0, 1e-6), 2)\n", + "\n", + "print(\"Complex blueprint structure:\")\n", + "bp_complex.showPrint()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a7edcb9", + "metadata": {}, + "outputs": [], + "source": [ + "# Matplotlib version\n", + "print(\"Matplotlib output:\")\n", + "fig_mpl = plotter(bp_complex, backend=\"matplotlib\")\n", + "fig_mpl.suptitle(\"Complex Pulse - Matplotlib\", y=1.02)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf46b55b", + "metadata": {}, + "outputs": [], + "source": [ + "# Plotly version\n", + "print(\"Plotly output (interactive - try zooming!):\")\n", + "fig_plotly = plotter(bp_complex, backend=\"plotly\")\n", + "fig_plotly.update_layout(title=\"Complex Pulse - Plotly (Interactive)\")\n", + "fig_plotly.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".conda", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/index.rst b/docs/examples/index.rst index 691f64a9a..415e5529b 100644 --- a/docs/examples/index.rst +++ b/docs/examples/index.rst @@ -10,3 +10,4 @@ Broadbean Examples Example_Write_Read_JSON.ipynb Filter_compensation.ipynb Subsequences.ipynb + Plotting_Examples.ipynb diff --git a/pyproject.toml b/pyproject.toml index 642b24f24..fc1b94e6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,9 @@ docs = [ "sphinx-jsonschema", "ipykernel", ] - +plotly = [ + "plotly>=5.0.0", +] [tool.pytest] minversion = "9.0" @@ -68,7 +70,8 @@ warn_redundant_casts = true [[tool.mypy.overrides]] module = [ "matplotlib.*", - "schema" + "schema", + "plotly.*", ] ignore_missing_imports = true diff --git a/requirements.txt b/requirements.txt index fe01ae0e1..cd6026be5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -152,6 +152,8 @@ pillow==12.1.1 # via matplotlib platformdirs==4.5.1 # via jupyter-core +plotly==6.5.0 + # via broadbean (pyproject.toml) pluggy==1.6.0 # via # pytest diff --git a/src/broadbean/plotting.py b/src/broadbean/plotting.py index 9c5b9ebc4..0b9ce4051 100644 --- a/src/broadbean/plotting.py +++ b/src/broadbean/plotting.py @@ -127,51 +127,26 @@ def _plot_summariser(seq: dict[int, dict]) -> dict[int, dict[str, np.ndarray]]: return output -# the Grand Unified Plotter -def plotter(obj_to_plot: BBObject, **forger_kwargs) -> None: - """ - The one plot function to be called. Turns whatever it gets - into a sequence, forges it, and plots that. +def _plot_matplotlib( + obj_to_plot: BBObject, + seq: dict[int, dict], + chans: list, + seqlen: int, + chanminmax: list[tuple[float, float]], +): """ + Create a matplotlib plot of the forged sequence. - # TODO: Take axes as input - - # strategy: - # * Validate - # * Forge - # * Plot - - _plot_object_validator(obj_to_plot) - - seq = _plot_object_forger(obj_to_plot, **forger_kwargs) - - # Get the dimensions. - chans = seq[1]["content"][1]["data"].keys() - seqlen = len(seq.keys()) - - def update_minmax(chanminmax, wfmdata, chanind): - (thismin, thismax) = (wfmdata.min(), wfmdata.max()) - if thismin < chanminmax[chanind][0]: - chanminmax[chanind] = [thismin, chanminmax[chanind][1]] - if thismax > chanminmax[chanind][1]: - chanminmax[chanind] = [chanminmax[chanind][0], thismax] - return chanminmax - - # Then figure out the figure scalings - minf: float = -np.inf - inf: float = np.inf - chanminmax: list[tuple[float, float]] = [(inf, minf)] * len(chans) - for chanind, chan in enumerate(chans): - for pos in range(1, seqlen + 1): - if seq[pos]["type"] == "element": - wfmdata = seq[pos]["content"][1]["data"][chan]["wfm"] - chanminmax = update_minmax(chanminmax, wfmdata, chanind) - elif seq[pos]["type"] == "subsequence": - for pos2 in seq[pos]["content"].keys(): - elem = seq[pos]["content"][pos2]["data"] - wfmdata = elem[chan]["wfm"] - chanminmax = update_minmax(chanminmax, wfmdata, chanind) + Args: + obj_to_plot: The original object being plotted + seq: The forged sequence + chans: List of channel names + seqlen: Number of sequence positions + chanminmax: List of (min, max) tuples for each channel + Returns: + The matplotlib Figure object + """ fig, axs = plt.subplots(len(chans), seqlen, squeeze=False) # ...and do the plotting @@ -342,15 +317,427 @@ def update_minmax(chanminmax, wfmdata, chanind): if seq_info["twait"] == 1: # trigger wait titlestring += "T " if seq_info["nrep"] > 1: # nreps - titlestring += "\u21bb{} ".format(seq_info["nrep"]) + titlestring += "↻{} ".format(seq_info["nrep"]) if seq_info["nrep"] == 0: - titlestring += "\u221e " + titlestring += "∞ " if seq_info["jump_input"] != 0: if seq_info["jump_input"] == -1: - titlestring += "E\u2192 " + titlestring += "⚡ " else: - titlestring += "E{} ".format(seq_info["jump_input"]) + titlestring += "⚡{} ".format(seq_info["jump_input"]) if seq_info["goto"] > 0: - titlestring += "\u21b1{}".format(seq_info["goto"]) + titlestring += "→{}".format(seq_info["goto"]) ax.set_title(titlestring) + + return fig + + +def _plot_plotly( + obj_to_plot: BBObject, + seq: dict[int, dict], + chans: list, + seqlen: int, + chanminmax: list[tuple[float, float]], +): + """ + Create a plotly plot of the forged sequence. + + Args: + obj_to_plot: The original object being plotted + seq: The forged sequence + chans: List of channel names + seqlen: Number of sequence positions + chanminmax: List of (min, max) tuples for each channel + + Returns: + The plotly Figure object + """ + try: + from plotly import graph_objects as go + from plotly.subplots import make_subplots + except ImportError: + raise ImportError( + "plotly is required for the 'plotly' backend. " + "Install it with: pip install broadbean[plotly]" + ) + + # Create subplots + fig = make_subplots( + rows=len(chans), + cols=seqlen, + shared_yaxes="rows", + horizontal_spacing=0.0, + vertical_spacing=0.0, + ) + + # Convert RGB tuples to rgba strings + def rgba(rgb: tuple[float, float, float], alpha: float) -> str: + r, g, b = [int(c * 255) for c in rgb] + return f"rgba({r},{g},{b},{alpha})" + + # ...and do the plotting + for chanind, chan in enumerate(chans): + # figure out the channel voltage scaling + # The entire channel shares a y-axis + + minmax: tuple[float, float] = chanminmax[chanind] + + (voltagescaling, voltageprefix) = getSIScalingAndPrefix(minmax) + voltageunit = voltageprefix + "V" + + for pos in range(seqlen): + row = chanind + 1 + col = pos + 1 + + if seq[pos + 1]["type"] == "element": + content = seq[pos + 1]["content"][1]["data"][chan] + wfm = content["wfm"] + m1 = content.get("m1", np.zeros_like(wfm)) + m2 = content.get("m2", np.zeros_like(wfm)) + time = content["time"] + newdurs = content.get("newdurations", []) + + else: + arr_dict = _plot_summariser(seq[pos + 1]["content"]) + wfm = arr_dict[chan]["wfm"] + newdurs = [] + time = np.linspace(0, 1, 2) # needed for timeexponent + + # Figure out the axes' scaling + timeexponent = np.log10(time.max()) + timeunit = "s" + timescaling: float = 1.0 + if timeexponent < 0: + timeunit = "ms" + timescaling = 1e3 + if timeexponent < -3: + timeunit = "micro s" + timescaling = 1e6 + if timeexponent < -6: + timeunit = "ns" + timescaling = 1e9 + + # Calculate y-axis range + ymax = voltagescaling * chanminmax[chanind][1] + ymin = voltagescaling * chanminmax[chanind][0] + yrange = ymax - ymin + ylim_min = ymin - 0.05 * yrange + ylim_max = ymax + 0.2 * yrange + + # Plot waveform for elements + if seq[pos + 1]["type"] == "element": + fig.add_trace( + go.Scatter( + x=timescaling * time, + y=voltagescaling * wfm, + mode="lines", + line=dict(color=rgba((0.6, 0.4, 0.3), 0.4), width=3), + showlegend=False, + hoverinfo="skip", + ), + row=row, + col=col, + ) + + # marker1 (red, on top) + y_m1 = ymax + 0.15 * yrange + marker_on_mask = m1 != 0 + # Off state (background) + fig.add_trace( + go.Scatter( + x=timescaling * time, + y=np.ones_like(m1) * y_m1, + mode="lines", + line=dict(color=rgba((0.6, 0.1, 0.1), 0.2), width=2), + showlegend=False, + hoverinfo="skip", + ), + row=row, + col=col, + ) + # On state + if marker_on_mask.any(): + fig.add_trace( + go.Scatter( + x=(timescaling * time)[marker_on_mask], + y=(np.ones_like(m1) * y_m1)[marker_on_mask], + mode="lines", + line=dict(color=rgba((0.6, 0.1, 0.1), 0.6), width=2), + showlegend=False, + hoverinfo="skip", + ), + row=row, + col=col, + ) + + # marker 2 (blue, below the red) + y_m2 = ymax + 0.10 * yrange + marker_on_mask = m2 != 0 + # Off state (background) + fig.add_trace( + go.Scatter( + x=timescaling * time, + y=np.ones_like(m2) * y_m2, + mode="lines", + line=dict(color=rgba((0.1, 0.1, 0.6), 0.2), width=2), + showlegend=False, + hoverinfo="skip", + ), + row=row, + col=col, + ) + # On state + if marker_on_mask.any(): + fig.add_trace( + go.Scatter( + x=(timescaling * time)[marker_on_mask], + y=(np.ones_like(m2) * y_m2)[marker_on_mask], + mode="lines", + line=dict(color=rgba((0.1, 0.1, 0.6), 0.6), width=2), + showlegend=False, + hoverinfo="skip", + ), + row=row, + col=col, + ) + + # time step lines + for dur in np.cumsum(newdurs): + fig.add_trace( + go.Scatter( + x=[timescaling * dur, timescaling * dur], + y=[ylim_min, ylim_max], + mode="lines", + line=dict(color=rgba((0.312, 0.2, 0.33), 0.3), width=1), + showlegend=False, + hoverinfo="skip", + ), + row=row, + col=col, + ) + + # If subsequence, plot lines indicating min and max value + if seq[pos + 1]["type"] == "subsequence": + # min: + fig.add_trace( + go.Scatter( + x=time, + y=np.ones_like(time) * wfm[0], + mode="lines", + line=dict(color=rgba((0.12, 0.12, 0.12), 0.2), width=2), + showlegend=False, + hoverinfo="skip", + ), + row=row, + col=col, + ) + # max: + fig.add_trace( + go.Scatter( + x=time, + y=np.ones_like(time) * wfm[1], + mode="lines", + line=dict(color=rgba((0.12, 0.12, 0.12), 0.2), width=2), + showlegend=False, + hoverinfo="skip", + ), + row=row, + col=col, + ) + + # Add "SUBSEQ" annotation using paper coordinates + fig.add_annotation( + text="SUBSEQ", + xref="paper", + yref="paper", + x=(col - 0.5) / seqlen, # Normalized x position + y=1 - (row - 0.5) / len(chans), # Normalized y position + xanchor="center", + yanchor="middle", + showarrow=False, + row=row, + col=col, + ) + + # Update axes + xaxis_name = f"xaxis{(row - 1) * seqlen + col}" + yaxis_name = f"yaxis{(row - 1) * seqlen + col}" + + # Y-axis configuration + fig.layout[yaxis_name].update( + range=[ylim_min, ylim_max], + showgrid=False, + zeroline=False, + showline=True, + linewidth=1, + linecolor="black", + mirror=True, + ) + + # X-axis configuration + fig.layout[xaxis_name].update( + showgrid=False, + zeroline=False, + showline=True, + linewidth=1, + linecolor="black", + mirror=True, + ) + + # Axis labels + if pos == 0: + fig.layout[yaxis_name].update(title=f"({voltageunit})") + else: + fig.layout[yaxis_name].update(showticklabels=False) + + if chanind == len(chans) - 1: + if seq[pos + 1]["type"] == "subsequence": + fig.layout[xaxis_name].update( + title="Time N/A", showticklabels=False + ) + else: + fig.layout[xaxis_name].update(title=f"({timeunit})") + else: + fig.layout[xaxis_name].update(showticklabels=False) + + # Add channel label on the right for the last column + if pos == seqlen - 1 and not (isinstance(obj_to_plot, BluePrint)): + if isinstance(chan, int): + chan_label = f"Ch. {chan}" + elif isinstance(chan, str): + chan_label = chan + else: + chan_label = str(chan) + + fig.add_annotation( + text=chan_label, + xref="paper", + yref="paper", + x=1.02, + y=1 + - (row - 0.5) / len(chans), # Normalized y position for this row + xanchor="left", + yanchor="middle", + showarrow=False, + textangle=-90, + ) + + # display sequencer information as subplot title + if chanind == 0 and isinstance(obj_to_plot, Sequence): + seq_info = seq[pos + 1]["sequencing"] + titlestring = "" + if seq_info["twait"] == 1: # trigger wait + titlestring += "T " + if seq_info["nrep"] > 1: # nreps + titlestring += "↻{} ".format(seq_info["nrep"]) + if seq_info["nrep"] == 0: + titlestring += "∞ " + if seq_info["jump_input"] != 0: + if seq_info["jump_input"] == -1: + titlestring += "⚡ " + else: + titlestring += "⚡{} ".format(seq_info["jump_input"]) + if seq_info["goto"] > 0: + titlestring += "→{}".format(seq_info["goto"]) + + if titlestring.strip(): + # Add title annotation above this subplot + fig.add_annotation( + text=titlestring, + xref="paper", + yref="paper", + x=(col - 0.5) / seqlen, # Center of this column + y=1.02, # Just above the top + xanchor="center", + yanchor="bottom", + showarrow=False, + ) + + # Update overall layout + fig.update_layout( + plot_bgcolor="white", + paper_bgcolor="white", + margin=dict(l=50, r=50, t=30, b=50), + showlegend=False, + ) + + return fig + + +# the Grand Unified Plotter +def plotter( + obj_to_plot: BBObject, + backend: str = "matplotlib", + max_subsequences: int | None = None, + **forger_kwargs, +): + """ + The one plot function to be called. Turns whatever it gets + into a sequence, forges it, and plots that. + + Args: + obj_to_plot: The object to plot (Sequence, Element, or BluePrint) + backend: The plotting backend to use. Either "matplotlib" or "plotly". + Default is "matplotlib". + max_subsequences: If set, limits the number of subsequences plotted + to this number. + **forger_kwargs: Additional keyword arguments passed to the forge method + + Returns: + matplotlib.figure.Figure if backend is "matplotlib", + plotly.graph_objects.Figure if backend is "plotly" + """ + + # Validate backend parameter + if backend not in ("matplotlib", "plotly"): + raise ValueError( + f"Invalid backend '{backend}'. Must be either 'matplotlib' or 'plotly'." + ) + + # TODO: Take axes as input + + # strategy: + # * Validate + # * Forge + # * Plot + + _plot_object_validator(obj_to_plot) + + seq = _plot_object_forger(obj_to_plot, **forger_kwargs) + + # Get the dimensions. + chans = list(seq[1]["content"][1]["data"].keys()) + seqlen = len(seq.keys()) + + if max_subsequences is not None: + seqlen = min(seqlen, max_subsequences) + + def update_minmax(chanminmax, wfmdata, chanind): + (thismin, thismax) = (wfmdata.min(), wfmdata.max()) + if thismin < chanminmax[chanind][0]: + chanminmax[chanind] = [thismin, chanminmax[chanind][1]] + if thismax > chanminmax[chanind][1]: + chanminmax[chanind] = [chanminmax[chanind][0], thismax] + return chanminmax + + # Then figure out the figure scalings + minf: float = -np.inf + inf: float = np.inf + chanminmax: list[tuple[float, float]] = [(inf, minf)] * len(chans) + for chanind, chan in enumerate(chans): + for pos in range(1, seqlen + 1): + if seq[pos]["type"] == "element": + wfmdata = seq[pos]["content"][1]["data"][chan]["wfm"] + chanminmax = update_minmax(chanminmax, wfmdata, chanind) + elif seq[pos]["type"] == "subsequence": + for pos2 in seq[pos]["content"].keys(): + elem = seq[pos]["content"][pos2]["data"] + wfmdata = elem[chan]["wfm"] + chanminmax = update_minmax(chanminmax, wfmdata, chanind) + + # Route to appropriate backend + if backend == "matplotlib": + return _plot_matplotlib(obj_to_plot, seq, chans, seqlen, chanminmax) + else: # backend == "plotly" + return _plot_plotly(obj_to_plot, seq, chans, seqlen, chanminmax) diff --git a/tests/test_plotting.py b/tests/test_plotting.py new file mode 100644 index 000000000..df27fecfe --- /dev/null +++ b/tests/test_plotting.py @@ -0,0 +1,356 @@ +# Test suite for the plotting module of the broadbean package + +import matplotlib +import pytest + +matplotlib.use("Agg") # Use non-interactive backend for tests +import matplotlib.figure + +import broadbean as bb +from broadbean.plotting import plotter + + +@pytest.fixture +def simple_blueprint(): + """ + Create a simple blueprint for testing + """ + bp = bb.BluePrint() + bp.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="rise", dur=1e-6) + bp.insertSegment(1, bb.PulseAtoms.ramp, args=(1, 0), name="fall", dur=1e-6) + bp.setSR(1e9) + return bp + + +@pytest.fixture +def simple_element(): + """ + Create a simple element for testing + """ + bp = bb.BluePrint() + bp.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="ramp", dur=1e-6) + bp.setSR(1e9) + + elem = bb.Element() + elem.addBluePrint(1, bp) + return elem + + +@pytest.fixture +def simple_sequence(): + """ + Create a simple sequence for testing + """ + bp1 = bb.BluePrint() + bp1.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="rise", dur=1e-6) + bp1.setSR(1e9) + + bp2 = bb.BluePrint() + bp2.insertSegment(0, bb.PulseAtoms.ramp, args=(1, 0), name="fall", dur=1e-6) + bp2.setSR(1e9) + + elem1 = bb.Element() + elem1.addBluePrint(1, bp1) + + elem2 = bb.Element() + elem2.addBluePrint(1, bp2) + + seq = bb.Sequence() + seq.addElement(1, elem1) + seq.addElement(2, elem2) + seq.setSR(1e9) + + return seq + + +@pytest.fixture +def blueprint_with_markers(): + """ + Create a blueprint with markers for testing + """ + bp = bb.BluePrint() + bp.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="rise", dur=1e-6) + bp.insertSegment(1, bb.PulseAtoms.ramp, args=(1, 0), name="fall", dur=1e-6) + bp.setSR(1e9) + + # Add markers + bp.marker1 = [(0, 0.5e-6)] + bp.marker2 = [(1e-6, 0.5e-6)] + + return bp + + +################################################## +# TEST BACKEND PARAMETER VALIDATION + + +def test_plotter_invalid_backend(simple_blueprint): + """Test that invalid backend raises ValueError""" + with pytest.raises(ValueError, match="Invalid backend"): + plotter(simple_blueprint, backend="invalid") + + +def test_plotter_matplotlib_backend_default(simple_blueprint): + """Test that matplotlib is the default backend""" + fig = plotter(simple_blueprint) + assert isinstance(fig, matplotlib.figure.Figure) + + +def test_plotter_matplotlib_backend_explicit(simple_blueprint): + """Test explicit matplotlib backend selection""" + fig = plotter(simple_blueprint, backend="matplotlib") + assert isinstance(fig, matplotlib.figure.Figure) + + +################################################## +# TEST MATPLOTLIB BACKEND + + +def test_matplotlib_blueprint(simple_blueprint): + """Test matplotlib plotting of a blueprint""" + fig = plotter(simple_blueprint, backend="matplotlib") + assert isinstance(fig, matplotlib.figure.Figure) + assert len(fig.axes) == 1 # 1 channel, 1 position + + +def test_matplotlib_element(simple_element): + """Test matplotlib plotting of an element""" + fig = plotter(simple_element, backend="matplotlib") + assert isinstance(fig, matplotlib.figure.Figure) + # 1 channel subplot + 1 twin axis for channel label + assert len(fig.axes) >= 1 + + +def test_matplotlib_sequence(simple_sequence): + """Test matplotlib plotting of a sequence""" + fig = plotter(simple_sequence, backend="matplotlib") + assert isinstance(fig, matplotlib.figure.Figure) + # 1 channel, 2 positions + 1 twin axis for channel label + assert len(fig.axes) >= 2 + + +def test_matplotlib_with_markers(blueprint_with_markers): + """Test matplotlib plotting with markers""" + elem = bb.Element() + elem.addBluePrint(1, blueprint_with_markers) + + fig = plotter(elem, backend="matplotlib") + assert isinstance(fig, matplotlib.figure.Figure) + # Should have plotted the waveform and markers + ax = fig.axes[0] + assert len(ax.lines) > 1 # Waveform + marker lines + + +################################################## +# TEST PLOTLY BACKEND + + +def test_plotly_import_error(simple_blueprint, monkeypatch): + """Test that missing plotly raises helpful ImportError""" + + def mock_import(name, *args, **kwargs): + if "plotly" in name: + raise ImportError("No module named 'plotly'") + return __import__(name, *args, **kwargs) + + monkeypatch.setattr("builtins.__import__", mock_import) + + with pytest.raises(ImportError, match="plotly is required"): + plotter(simple_blueprint, backend="plotly") + + +@pytest.mark.skipif( + not pytest.importorskip("plotly", reason="plotly not installed"), + reason="plotly not available", +) +def test_plotly_blueprint(): + """Test plotly plotting of a blueprint""" + import plotly.graph_objects as go + + bp = bb.BluePrint() + bp.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="rise", dur=1e-6) + bp.setSR(1e9) + + fig = plotter(bp, backend="plotly") + assert isinstance(fig, go.Figure) + # Check that traces were added + assert len(fig.data) > 0 + + +@pytest.mark.skipif( + not pytest.importorskip("plotly", reason="plotly not installed"), + reason="plotly not available", +) +def test_plotly_element(): + """Test plotly plotting of an element""" + import plotly.graph_objects as go + + bp = bb.BluePrint() + bp.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="rise", dur=1e-6) + bp.setSR(1e9) + + elem = bb.Element() + elem.addBluePrint(1, bp) + + fig = plotter(elem, backend="plotly") + assert isinstance(fig, go.Figure) + assert len(fig.data) > 0 + + +@pytest.mark.skipif( + not pytest.importorskip("plotly", reason="plotly not installed"), + reason="plotly not available", +) +def test_plotly_sequence(): + """Test plotly plotting of a sequence""" + import plotly.graph_objects as go + + bp1 = bb.BluePrint() + bp1.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="rise", dur=1e-6) + bp1.setSR(1e9) + + bp2 = bb.BluePrint() + bp2.insertSegment(0, bb.PulseAtoms.ramp, args=(1, 0), name="fall", dur=1e-6) + bp2.setSR(1e9) + + elem1 = bb.Element() + elem1.addBluePrint(1, bp1) + + elem2 = bb.Element() + elem2.addBluePrint(1, bp2) + + seq = bb.Sequence() + seq.addElement(1, elem1) + seq.addElement(2, elem2) + seq.setSR(1e9) + + fig = plotter(seq, backend="plotly") + assert isinstance(fig, go.Figure) + assert len(fig.data) > 0 + + +@pytest.mark.skipif( + not pytest.importorskip("plotly", reason="plotly not installed"), + reason="plotly not available", +) +def test_plotly_with_markers(): + """Test plotly plotting with markers""" + import plotly.graph_objects as go + + bp = bb.BluePrint() + bp.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="rise", dur=1e-6) + bp.insertSegment(1, bb.PulseAtoms.ramp, args=(1, 0), name="fall", dur=1e-6) + bp.setSR(1e9) + + # Add markers + bp.marker1 = [(0, 0.5e-6)] + bp.marker2 = [(1e-6, 0.5e-6)] + + elem = bb.Element() + elem.addBluePrint(1, bp) + + fig = plotter(elem, backend="plotly") + assert isinstance(fig, go.Figure) + # Should have multiple traces (waveform + markers) + assert len(fig.data) > 1 + + +@pytest.mark.skipif( + not pytest.importorskip("plotly", reason="plotly not installed"), + reason="plotly not available", +) +def test_plotly_multichannel(): + """Test plotly plotting with multiple channels""" + import plotly.graph_objects as go + + bp1 = bb.BluePrint() + bp1.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="rise", dur=1e-6) + bp1.setSR(1e9) + + bp2 = bb.BluePrint() + bp2.insertSegment(0, bb.PulseAtoms.ramp, args=(1, 0), name="fall", dur=1e-6) + bp2.setSR(1e9) + + elem = bb.Element() + elem.addBluePrint(1, bp1) + elem.addBluePrint(2, bp2) + + fig = plotter(elem, backend="plotly") + assert isinstance(fig, go.Figure) + # Should have traces for both channels + assert len(fig.data) > 0 + + +@pytest.mark.skipif( + not pytest.importorskip("plotly", reason="plotly not installed"), + reason="plotly not available", +) +def test_plotly_layout_properties(): + """Test that plotly figure has correct layout properties""" + import plotly.graph_objects as go + + bp = bb.BluePrint() + bp.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="rise", dur=1e-6) + bp.setSR(1e9) + + fig = plotter(bp, backend="plotly") + assert isinstance(fig, go.Figure) + + # Check layout properties + assert fig.layout.plot_bgcolor == "white" + assert fig.layout.paper_bgcolor == "white" + assert fig.layout.showlegend is False + + +################################################## +# TEST BACKEND CONSISTENCY + + +@pytest.mark.skipif( + not pytest.importorskip("plotly", reason="plotly not installed"), + reason="plotly not available", +) +def test_backends_produce_output(): + """Test that both backends produce output without errors""" + bp = bb.BluePrint() + bp.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="rise", dur=1e-6) + bp.setSR(1e9) + + elem = bb.Element() + elem.addBluePrint(1, bp) + + # Both should complete without errors + fig_mpl = plotter(elem, backend="matplotlib") + fig_plotly = plotter(elem, backend="plotly") + + assert fig_mpl is not None + assert fig_plotly is not None + + +################################################## +# TEST FORGER KWARGS + + +def test_matplotlib_with_forger_kwargs(simple_element): + """Test that forger kwargs are passed through for matplotlib""" + fig = plotter(simple_element, backend="matplotlib", apply_delays=False) + assert isinstance(fig, matplotlib.figure.Figure) + + +@pytest.mark.skipif( + not pytest.importorskip("plotly", reason="plotly not installed"), + reason="plotly not available", +) +def test_plotly_with_forger_kwargs(): + """Test that forger kwargs are passed through for plotly""" + import plotly.graph_objects as go + + bp = bb.BluePrint() + bp.insertSegment(0, bb.PulseAtoms.ramp, args=(0, 1), name="rise", dur=1e-6) + bp.setSR(1e9) + + elem = bb.Element() + elem.addBluePrint(1, bp) + + fig = plotter(elem, backend="plotly", apply_delays=False) + assert isinstance(fig, go.Figure)