Skip to content
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ ipython_config.py
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
.conda
bootstrap_requirements.txt
environment.yml

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
Expand Down
314 changes: 126 additions & 188 deletions docs/examples/Pulse_Building_Tutorial.ipynb
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are there two duplicated(?) plots instead of one in this example? or is it github mis-rendering?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be working fine in VScode.

Large diffs are not rendered by default.

38 changes: 33 additions & 5 deletions src/broadbean/blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np

from .broadbean import PulseAtoms
from .func_serialization import deserialize_function, serialize_function


class SegmentDurationError(Exception):
Expand Down Expand Up @@ -83,7 +84,8 @@ def __init__(
# Infer names from signature if not given, i.e. allow for '' names
for ii, name in enumerate(namelist):
if isinstance(funlist[ii], str):
namelist[ii] = funlist[ii]
if name == "":
namelist[ii] = funlist[ii]
elif name == "":
namelist[ii] = funlist[ii].__name__

Expand All @@ -93,8 +95,8 @@ def __init__(
argslist[ii] = (args,)
self._argslist = argslist

self._namelist = namelist
namelist = self._make_names_unique(namelist)
self._namelist = namelist
Comment thread
bennthomsen marked this conversation as resolved.

# initialise markers
if marker1 is None:
Expand Down Expand Up @@ -265,6 +267,12 @@ def description(self):
desc[segkey]["durations"] = self._durslist[sn]
if desc[segkey]["function"] == "waituntil":
desc[segkey]["arguments"] = {"waittime": self._argslist[sn]}
elif desc[segkey]["function"] == "function PulseAtoms.arb_func":
# Special handling for arb_func serialization
func_obj, kwargs_dict = self._argslist[sn]
serialized = serialize_function(func_obj)
serialized["kwargs"] = kwargs_dict
desc[segkey]["arguments"] = serialized
else:
sig = signature(self._funlist[sn])
desc[segkey]["arguments"] = dict(
Expand All @@ -275,6 +283,7 @@ def description(self):
desc["marker2_abs"] = self.marker2
desc["marker1_rel"] = self._segmark1
desc["marker2_rel"] = self._segmark2
desc["SR"] = self._SR

return desc

Expand Down Expand Up @@ -312,7 +321,26 @@ def blueprint_from_description(cls, blue_dict):
if seg_dict["function"] == "waituntil":
arguments = blue_dict[seg]["arguments"].values()
arguments = (list(arguments)[0][0],)
bp_seg.insertSegment(i, "waituntil", arguments)
bp_seg.insertSegment(i, "waituntil", arguments, name=seg_dict["name"])
elif seg_dict["function"] == "function PulseAtoms.arb_func":
# Special handling for arb_func reconstruction
args_dict = blue_dict[seg]["arguments"]
kwargs_dict = args_dict["kwargs"]

if args_dict.get("func_type") in ("lambda", "named_function"):
func_obj = deserialize_function(args_dict)
arguments = (func_obj, kwargs_dict)
else:
# Legacy format or fallback
arguments = tuple(blue_dict[seg]["arguments"].values())

bp_seg.insertSegment(
i,
knowfunctions[seg_dict["function"]],
arguments,
name=re.sub(r"\d", "", seg_dict["name"]),
dur=seg_dict["durations"],
)
else:
arguments = tuple(blue_dict[seg]["arguments"].values())
bp_seg.insertSegment(
Expand All @@ -329,6 +357,8 @@ def blueprint_from_description(cls, blue_dict):
listmarker2 = blue_dict["marker2_rel"]
bp_sum._segmark1 = [tuple(mark) for mark in listmarker1]
bp_sum._segmark2 = [tuple(mark) for mark in listmarker2]
if "SR" in blue_dict:
bp_sum._SR = blue_dict["SR"]
return bp_sum

@classmethod
Expand Down Expand Up @@ -664,7 +694,6 @@ def insertSegment(self, pos, func, args=(), dur=None, name=None, durs=None):

if pos < -1:
raise ValueError("Position must be strictly larger than -1")

if name is None or name == "":
if func == "waituntil":
name = "waituntil"
Expand All @@ -674,7 +703,6 @@ def insertSegment(self, pos, func, args=(), dur=None, name=None, durs=None):
if len(name) > 0:
if name[-1].isdigit():
raise ValueError("Segment name must not end in a number")

if pos == -1:
self._namelist.append(name)
self._namelist = self._make_names_unique(self._namelist)
Expand Down
195 changes: 195 additions & 0 deletions src/broadbean/func_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Utilities for serializing and deserializing Python functions to/from strings.
#
# This module provides a safe, controlled way to store function expressions
# as strings (for JSON serialization) and reconstruct them later. It is
# designed for use with PulseAtoms.arb_func, where users define custom
# waveforms via lambda expressions or simple named functions.
#
# The serialization format stores a source expression string such as:
# "lambda t, ampl, freq: ampl * np.sin(2 * np.pi * freq * t)"
#
# On deserialization, the expression is evaluated in a restricted namespace
# containing only numpy (as ``np``) and Python builtins. Functions that rely
# on other imports are not supported — they must be self-sufficient.

import inspect
import logging
import re
import textwrap

import numpy as np

log = logging.getLogger(__name__)

# The namespace available when deserializing function expressions.
# Only numpy is exposed; no access to the caller's globals.
_SAFE_EVAL_NAMESPACE: dict = {"np": np, "numpy": np, "__builtins__": {}}


def serialize_function(func_obj) -> dict:
"""
Serialize a callable into a JSON-compatible dict.

The dict always contains:
- ``func_type``: either ``"lambda"`` or ``"named_function"``
- ``func_source``: the source expression as a string, or ``None``
if it could not be determined

For named functions, ``func_name`` is also included.

The source string must be *self-sufficient*: it may reference ``np``
(numpy) but must not depend on any other imports or global state.

If the callable carries a ``__func_source__`` attribute (e.g. set by
an external UI), that value is used directly and no introspection is
attempted.

Args:
func_obj: A callable (lambda or regular function).

Returns:
dict with serialization metadata.
"""
# Prefer an explicitly attached source string (from an external UI, etc.)
if hasattr(func_obj, "__func_source__"):
return {
"func_type": "lambda",
"func_source": func_obj.__func_source__,
}

is_lambda = not (hasattr(func_obj, "__name__") and func_obj.__name__ != "<lambda>")

if is_lambda:
func_source = _extract_lambda_source(func_obj)
return {
"func_type": "lambda",
"func_source": func_source,
}
else:
func_source = _extract_named_function_source(func_obj)
return {
"func_type": "named_function",
"func_name": func_obj.__name__,
"func_source": func_source,
}


def deserialize_function(serialized: dict):
"""
Reconstruct a callable from a serialization dict produced by
:func:`serialize_function`.

Evaluation is performed in a restricted namespace that only exposes
``numpy`` (as both ``np`` and ``numpy``). No caller globals are
available; the function expression must be self-sufficient.

Args:
serialized: A dict with at least ``func_type`` and ``func_source``.

Returns:
A callable reconstructed from the source expression.

Raises:
ValueError: If ``func_source`` is ``None`` (cannot reconstruct).
"""
func_source = serialized.get("func_source")
func_type = serialized.get("func_type")

if func_source is None:
raise ValueError(
"Cannot deserialize function: no source expression was stored. "
"Ensure the function was serialized with a valid source string."
)

if func_type == "lambda":
return _eval_expression(func_source)
elif func_type == "named_function":
func_name = serialized.get("func_name")
return _exec_named_function(func_source, func_name)
else:
raise ValueError(f"Unknown func_type: {func_type!r}")


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _extract_lambda_source(func_obj) -> str | None:
"""
Attempt to extract the lambda expression source from a callable.

Uses :func:`inspect.getsource` and a regex to isolate the lambda
expression. Returns ``None`` on failure.
"""
try:
raw_source = inspect.getsource(func_obj)
except (OSError, TypeError):
return None

# Extract the lambda expression from the (possibly multi-line) source.
# The regex matches "lambda <params>: <body>" up to the first unbalanced
# comma, semicolon, or newline.
match = re.search(r"lambda\s+[^:]*:\s*[^\n,;]+", raw_source)
if match:
return match.group(0).strip()
return None


def _extract_named_function_source(func_obj) -> str | None:
"""
Attempt to extract the full source of a named function.

Returns ``None`` on failure.
"""
try:
return inspect.getsource(func_obj)
except (OSError, TypeError):
return None


def _eval_expression(source: str):
"""
Evaluate a lambda expression string in the safe namespace.

Args:
source: e.g. ``"lambda t, ampl: ampl * t"``

Returns:
The resulting callable.
"""
try:
return eval(source, _SAFE_EVAL_NAMESPACE) # noqa: S307
except Exception as e:
log.warning("Could not reconstruct function from source %r: %s", source, e)
raise ValueError(f"Failed to evaluate function expression: {source!r}") from e


def _exec_named_function(source: str, func_name: str):
"""
Execute a named function definition in the safe namespace and return
the resulting callable.

Args:
source: The full function source code.
func_name: The expected function name to retrieve after exec.

Returns:
The reconstructed callable.
"""
local_ns: dict = {}
try:
exec(textwrap.dedent(source), _SAFE_EVAL_NAMESPACE, local_ns) # noqa: S102
except Exception as e:
log.warning(
"Could not reconstruct named function %r from source: %s",
func_name,
e,
)
raise ValueError(f"Failed to execute function source for {func_name!r}") from e

if func_name not in local_ns:
raise ValueError(
f"Function {func_name!r} not found after executing source code."
)
return local_ns[func_name]
Loading