From 8bde3e50ffccb12e551f2a58e7fd5857b6775b87 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 9 Mar 2026 16:20:29 +0100 Subject: [PATCH 01/21] feat: implement compile cache pruning --- pyproject.toml | 1 + python/nutpie/__init__.py | 3 +- python/nutpie/compile_stan.py | 219 ++++++++++++++++++++++++++++++---- 3 files changed, 196 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ad8a07c..2f4dc5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "pyarrow >= 12.0.0", "arro3-core >= 0.6.0", "pandas >= 2.0", + "platformdirs >= 3.0.0", "xarray >= 2025.01.2", "arviz >= 0.20.0,<1.0", "obstore >= 0.8.0", diff --git a/python/nutpie/__init__.py b/python/nutpie/__init__.py index d77063f..22f9164 100644 --- a/python/nutpie/__init__.py +++ b/python/nutpie/__init__.py @@ -1,7 +1,7 @@ from nutpie import _lib from nutpie._lib import store as zarr_store from nutpie.compile_pymc import compile_pymc_model -from nutpie.compile_stan import compile_stan_model +from nutpie.compile_stan import compile_stan_model, prune_stan_cache from nutpie.sample import sample ChainProgress = _lib.PyChainProgress @@ -12,6 +12,7 @@ "ChainProgress", "compile_pymc_model", "compile_stan_model", + "prune_stan_cache", "sample", "zarr_store", ] diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 14cc13a..7240e74 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -1,3 +1,7 @@ +import datetime +import hashlib +import json +import shutil import tempfile from dataclasses import dataclass, replace from importlib.util import find_spec @@ -144,6 +148,105 @@ def coords(self): return self._coords +def _stan_cache_key( + code: str, + extra_compile_args: Optional[list[str]], + extra_stanc_args: Optional[list[str]], +) -> str: + """Return a SHA-256 hex digest identifying a unique compilation job.""" + import bridgestan + + fingerprint = json.dumps( + { + "code": code, + "extra_compile_args": sorted(extra_compile_args or []), + "extra_stanc_args": sorted(extra_stanc_args or []), + "bridgestan_version": bridgestan.__version__, + }, + sort_keys=True, + ) + return hashlib.sha256(fingerprint.encode()).hexdigest() + + +def _stan_cache_dir() -> Path: + """Return (and create) the directory where compiled Stan models are cached.""" + import platformdirs + + cache_dir = Path(platformdirs.user_cache_dir("nutpie")) / "stan" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + +def prune_stan_cache( + max_entries: int = 16, + min_age: datetime.timedelta = datetime.timedelta(weeks=2), +) -> None: + """Remove old entries from the Stan compilation cache. + + Entries are only considered for removal if they are older than *min_age*. + Among those, the oldest ones are removed until at most *max_entries* + entries remain. + + Parameters + ---------- + max_entries: + Maximum number of cache entries to keep. Defaults to 16. + min_age: + Entries younger than this are never removed, regardless of how many + entries exist. Defaults to 2 weeks. + """ + cache_dir = _stan_cache_dir() + now = datetime.datetime.now(tz=datetime.timezone.utc) + + # Collect all valid (marker exists) entries with their mtime. + entries = [] + for entry_dir in cache_dir.iterdir(): + if not entry_dir.is_dir(): + continue + marker = entry_dir / "ok" + if not marker.exists(): + continue + mtime = datetime.datetime.fromtimestamp( + marker.stat().st_mtime, tz=datetime.timezone.utc + ) + entries.append((mtime, entry_dir)) + + if len(entries) <= max_entries: + return + + # Only entries older than min_age are candidates for eviction. + candidates = sorted( + [(mtime, d) for mtime, d in entries if (now - mtime) >= min_age] + ) + + n_to_remove = len(entries) - max_entries + for _, entry_dir in candidates[:n_to_remove]: + shutil.rmtree(entry_dir, ignore_errors=True) + + +def _compile_stan_model( + model_name: str, + code: str, + build_dir: Path, + make_args: list[str], + stanc_args: list[str], +) -> Path: + """Write *code* into *build_dir*, compile it, and return the path to the shared library.""" + import bridgestan + + model_path = ( + build_dir.joinpath("name") + .with_name(model_name) # This verifies that it is a valid filename + .with_suffix(".stan") + ) + model_path.write_text(code) + so_path = bridgestan.compile_model( + model_path, make_args=make_args, stanc_args=stanc_args + ) + bridgestan.compile.windows_dll_path_setup() + return so_path + + def compile_stan_model( *, code: Optional[str] = None, @@ -154,7 +257,46 @@ def compile_stan_model( coords: Optional[dict[str, Any]] = None, model_name: Optional[str] = None, cleanup: bool = True, + cache: bool = False, + prune_cache: bool = True, ) -> CompiledStanModel: + """Compile a Stan model and return a :class:`CompiledStanModel`. + + Parameters + ---------- + code: + Stan model source code as a string. + filename: + Path to a ``.stan`` file. Mutually exclusive with *code*. + extra_compile_args: + Extra arguments forwarded to the C++ compiler via BridgeStan's + ``make_args``. + extra_stanc_args: + Extra arguments forwarded to the Stan compiler (``stanc``). + dims: + Variable dimension names, e.g. ``{"alpha": ["county"]}``. + coords: + Coordinate labels for each dimension, e.g. + ``{"county": ["Hennepin", "Ramsey", ...]}``. + model_name: + Base name used for the ``.stan`` file. Defaults to ``"model"``. + cleanup: + Remove the temporary build directory after compilation. Has no + effect when *cache* is ``True`` (the build directory is the cache + entry and is never deleted). + cache: + When ``True``, compile the model into a persistent directory under + the user cache directory (``~/.cache/nutpie/stan`` on Linux/macOS, + ``%LOCALAPPDATA%\\nutpie\\stan`` on Windows) and reuse it on + subsequent calls with identical arguments and the same BridgeStan + version. A marker file ``ok`` is written only after a successful + build, so interrupted or failed compilations are never reused. + Defaults to ``False``. + prune_cache: + When ``True`` (the default), call :func:`prune_stan_cache` after + each new compilation to evict old cache entries. Has no effect + when *cache* is ``False``. + """ if find_spec("bridgestan") is None: raise ImportError( "BridgeStan is not installed in the current environment. " @@ -180,33 +322,58 @@ def compile_stan_model( if model_name is None: model_name = "model" - basedir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) - try: - model_path = ( - Path(basedir.name) - .joinpath("name") - .with_name(model_name) # This verifies that it is a valid filename - .with_suffix(".stan") - ) - model_path.write_text(code) - make_args = ["STAN_THREADS=true"] - if extra_compile_args: - make_args.extend(extra_compile_args) - stanc_args = [] - if extra_stanc_args: - stanc_args.extend(extra_stanc_args) - so_path = bridgestan.compile_model( - model_path, make_args=make_args, stanc_args=stanc_args - ) - # Set necessary library loading paths - bridgestan.compile.windows_dll_path_setup() - library = _lib.StanLibrary(so_path) - finally: + make_args = ["STAN_THREADS=true"] + if extra_compile_args: + make_args.extend(extra_compile_args) + stanc_args = [] + if extra_stanc_args: + stanc_args.extend(extra_stanc_args) + + if cache: + digest = _stan_cache_key(code, extra_compile_args, extra_stanc_args) + entry_dir = _stan_cache_dir() / digest + marker = entry_dir / "ok" + + so_path_file = entry_dir / "so_path.txt" + + if marker.exists(): + # Cache hit: touch the marker to record recent use, then load. + marker.touch() + so_path = Path(so_path_file.read_text()) + if not so_path.exists(): + raise FileNotFoundError( + f"Cached Stan library not found: {so_path}. " + "The cache entry may be corrupt; delete it and recompile." + ) + bridgestan.compile.windows_dll_path_setup() + library = _lib.StanLibrary(str(so_path)) + else: + # Cache miss: compile directly into the cache entry directory so + # that all relative loading paths inside the .so remain valid. + entry_dir.mkdir(parents=True, exist_ok=True) + so_path = _compile_stan_model( + model_name, code, entry_dir, make_args, stanc_args + ) + # Write the .so path before the marker so the marker is only + # ever present once so_path.txt is fully written. + so_path_file.write_text(str(so_path)) + marker.write_text("") + library = _lib.StanLibrary(str(so_path)) + if prune_cache: + prune_stan_cache() + else: + basedir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) try: - if cleanup: - basedir.cleanup() - except Exception: # noqa: BLE001 - pass + so_path = _compile_stan_model( + model_name, code, Path(basedir.name), make_args, stanc_args + ) + library = _lib.StanLibrary(str(so_path)) + finally: + try: + if cleanup: + basedir.cleanup() + except Exception: # noqa: BLE001 + pass return CompiledStanModel( code=code, From 7f13979ec780bcb2469f642ec576733819744281 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 9 Mar 2026 16:20:29 +0100 Subject: [PATCH 02/21] fix: some typing improvements --- python/nutpie/sample.py | 51 +++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index d52d93b..e7c30d9 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -8,7 +8,7 @@ import pandas as pd import pyarrow -from nutpie import _lib # type: ignore +from nutpie import _lib @dataclass(frozen=True) @@ -297,7 +297,7 @@ def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_va def in_marimo_notebook() -> bool: try: - import marimo as mo + import marimo as mo # ty:ignore[unresolved-import] return mo.running_in_notebook() except ImportError: @@ -306,17 +306,25 @@ def in_marimo_notebook() -> bool: def _mo_write_internal(cell_id, stream, value: object) -> None: """Write to marimo cell given cell_id and stream.""" - import marimo + import marimo # ty:ignore[unresolved-import] if marimo.__version__ < "0.19.0": # The old CellOp API is identical to new CellNotificationUtils - from marimo._messaging.ops import CellOp as CellNotificationUtils + from marimo._messaging.ops import ( # ty:ignore[unresolved-import] + CellOp as CellNotificationUtils, + ) else: - from marimo._messaging.notification_utils import CellNotificationUtils + from marimo._messaging.notification_utils import ( # ty:ignore[unresolved-import] + CellNotificationUtils, + ) - from marimo._messaging.cell_output import CellChannel - from marimo._messaging.tracebacks import write_traceback - from marimo._output import formatting + from marimo._messaging.cell_output import ( # ty:ignore[unresolved-import] + CellChannel, + ) + from marimo._messaging.tracebacks import ( # ty:ignore[unresolved-import] + write_traceback, + ) + from marimo._output import formatting # ty:ignore[unresolved-import] output = formatting.try_format(value) if output.traceback is not None: @@ -333,9 +341,11 @@ def _mo_write_internal(cell_id, stream, value: object) -> None: def _mo_create_replace(): """Create mo.output.replace with current context pinned.""" - from marimo._output import formatting - from marimo._runtime.context import get_context - from marimo._runtime.context.types import ContextNotInitializedError + from marimo._output import formatting # ty:ignore[unresolved-import] + from marimo._runtime.context import get_context # ty:ignore[unresolved-import] + from marimo._runtime.context.types import ( # ty:ignore[unresolved-import] + ContextNotInitializedError, + ) try: ctx = get_context() @@ -359,7 +369,7 @@ def in_notebook(): def in_colab(): "Check if the code is running in Google Colaboratory" try: - from google import colab # noqa: F401 + from google import colab # noqa: F401 # ty:ignore[unresolved-import] return True except ImportError: @@ -371,7 +381,7 @@ def in_colab(): shell = get_ipython().__class__.__name__ # type: ignore if shell == "ZMQInteractiveShell": # Jupyter notebook, Spyder or qtconsole try: - from IPython.display import ( + from IPython.display import ( # ty:ignore[unresolved-import] HTML, # noqa: F401 clear_output, # noqa: F401 display, # noqa: F401 @@ -457,7 +467,7 @@ def __init__( if progress_style is None: progress_style = _progress_style - import IPython + import IPython # ty:ignore[unresolved-import] self._html = "" @@ -483,7 +493,7 @@ def callback(formatted): progress_rate, progress_template, cores, callback ) elif in_marimo_notebook(): - import marimo as mo + import marimo as mo # ty:ignore[unresolved-import] if progress_template is None: progress_template = _progress_template @@ -548,7 +558,7 @@ def _extract(self, results): store = cls(*args, **kwargs) obj_store = ObjectStore(store, read_only=True) - ds = xr.open_datatree(obj_store, engine="zarr", consolidated=False) + ds = xr.open_datatree(obj_store, engine="zarr", consolidated=False) # ty:ignore[invalid-argument-type] return arviz.from_datatree(ds) elif results.is_arrow(): @@ -638,11 +648,13 @@ def sample( adaptation: Literal["diag", "draw_diag", "low_rank", "flow"] = "diag", init_mean: np.ndarray | None = None, return_raw_trace: bool = False, + blocking: Literal[True], progress_callback: Any | None = None, progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, + **kwargs, ) -> arviz.InferenceData: ... @@ -660,14 +672,14 @@ def sample( adaptation: Literal["diag", "draw_diag", "low_rank", "flow"] = "diag", init_mean: np.ndarray | None = None, return_raw_trace: bool = False, - blocking: Literal[True], + blocking: Literal[False], progress_callback: Any | None = None, progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, **kwargs, -) -> arviz.InferenceData: ... +) -> _BackgroundSampler: ... @overload @@ -684,14 +696,13 @@ def sample( adaptation: Literal["diag", "draw_diag", "low_rank", "flow"] = "diag", init_mean: np.ndarray | None = None, return_raw_trace: bool = False, - blocking: Literal[False], progress_callback: Any | None = None, progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, **kwargs, -) -> _BackgroundSampler: ... +) -> xr.DataTree: ... def sample( From 108fe15972bb49ae8e5547be289434862431b8d9 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 9 Mar 2026 16:20:29 +0100 Subject: [PATCH 03/21] feat: handle string values in trace sample stats --- python/nutpie/sample.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index e7c30d9..0f525b6 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -135,6 +135,8 @@ def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_va if name not in data_dict: if dtype in [np.float64, np.float32]: data = np.full(total_shape, np.nan, dtype=dtype) + elif dtype == np.dtype("O"): + data = np.full(total_shape, None, dtype=dtype) else: data = np.zeros(total_shape, dtype=dtype) data_dict[name] = data @@ -148,6 +150,8 @@ def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_va ) else: is_null = is_null.to_numpy(False) + if values.shape[0] == num_draws: + values = values[~is_null] data_dict[name][chain, :num_draws][~is_null] = values.reshape( ((~is_null).sum(),) + tuple(item_shape) ) @@ -579,8 +583,16 @@ def _extract(self, results): ], } + def _get_nested(settings, name, default): + parts = name.split(".") + for part in parts: + if part not in settings: + return default + settings = settings[part] + return settings + for setting, names in skips.items(): - if not getattr(self._settings, setting, False): + if not _get_nested(settings_dict["settings"], setting, False): skip_vars.extend(names) draw_batches, stat_batches = results.get_arrow_trace() From 078358a270d5e49b0a043111cb7e18c9dd254c4c Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 9 Mar 2026 16:20:29 +0100 Subject: [PATCH 04/21] feat: expose extra_doubling and exact_normal_trajectory --- python/nutpie/sample.py | 4 +++ src/pyfunc.rs | 18 ++++++++-- src/pymc.rs | 18 ++++++++-- src/stan.rs | 20 ++++++++--- src/wrapper.rs | 77 ++++++++++++++++++++++++++++++++++------- 5 files changed, 115 insertions(+), 22 deletions(-) diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 0f525b6..51965b4 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -581,6 +581,10 @@ def _extract(self, results): "divergence_momentum", "divergence_start_gradient", ], + "store_transformed": [ + "transformed_position", + "transformed_gradient", + ], } def _get_nested(settings, name, default): diff --git a/src/pyfunc.rs b/src/pyfunc.rs index 26ba41e..899381d 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -6,8 +6,8 @@ use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::{ exceptions::PyRuntimeError, pyclass, pymethods, - types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods}, - Bound, Py, PyAny, PyErr, Python, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods, PyNone}, + Bound, BoundObject, Py, PyAny, PyErr, Python, }; use rand::Rng; use rand_distr::{Distribution, Uniform}; @@ -477,7 +477,7 @@ impl CpuLogpFunc for PyDensity { Ok(()) } - fn new_transformation( + fn init_transformation( &mut self, rng: &mut R, untransformed_position: &[f64], @@ -492,6 +492,18 @@ impl CpuLogpFunc for PyDensity { Ok(trafo) } + fn new_transformation( + &mut self, + _rng: &mut R, + _dim: usize, + _chain: u64, + ) -> std::result::Result { + Python::attach(|py| { + let params = PyNone::get(py); + Ok(params.unbind().into()) + }) + } + fn transformation_id(&self, params: &Py) -> std::result::Result { let id = self .transform_adapter diff --git a/src/pymc.rs b/src/pymc.rs index 44b2e81..b61c5e3 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -6,8 +6,8 @@ use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::{ exceptions::PyRuntimeError, pyclass, pymethods, - types::{PyAnyMethods, PyDict, PyDictMethods}, - Py, PyAny, PyErr, PyResult, Python, + types::{PyAnyMethods, PyDict, PyDictMethods, PyNone}, + BoundObject, Py, PyAny, PyErr, PyResult, Python, }; use rand::Rng; @@ -371,7 +371,7 @@ impl CpuLogpFunc for PyMcModelRef<'_> { Ok(()) } - fn new_transformation( + fn init_transformation( &mut self, rng: &mut R, untransformed_position: &[f64], @@ -386,6 +386,18 @@ impl CpuLogpFunc for PyMcModelRef<'_> { Ok(trafo) } + fn new_transformation( + &mut self, + _rng: &mut R, + _dim: usize, + _chain: u64, + ) -> std::result::Result { + Python::attach(|py| { + let params = PyNone::get(py); + Ok(params.unbind().into()) + }) + } + fn transformation_id(&self, params: &Py) -> std::result::Result { let id = self .transform_adapter diff --git a/src/stan.rs b/src/stan.rs index f56fbe5..3c21354 100644 --- a/src/stan.rs +++ b/src/stan.rs @@ -7,9 +7,9 @@ use bridgestan::open_library; use itertools::Itertools; use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::exceptions::PyRuntimeError; -use pyo3::prelude::*; -use pyo3::types::{PyDict, PyTuple}; +use pyo3::types::{PyDict, PyNone, PyTuple}; use pyo3::{exceptions::PyValueError, pyclass, pymethods, PyResult}; +use pyo3::{prelude::*, BoundObject}; use rand::prelude::Distribution; use rand::{rng, Rng}; use rand_distr::StandardNormal; @@ -202,7 +202,7 @@ where let (mut shape, is_complex) = group .iter() - .map(|&(_, is_complex, ref idx)| (idx, is_complex)) + .map(|&(_, is_complex, idx)| (idx, is_complex)) .fold(None, |acc, (elem_index, &elem_is_complex)| { let (mut shape, is_complex) = acc.unwrap_or((elem_index.clone(), elem_is_complex)); assert!( @@ -630,7 +630,7 @@ impl<'model> CpuLogpFunc for StanDensity<'model> { Ok(()) } - fn new_transformation( + fn init_transformation( &mut self, rng: &mut R, untransformed_position: &[f64], @@ -646,6 +646,18 @@ impl<'model> CpuLogpFunc for StanDensity<'model> { Ok(trafo) } + fn new_transformation( + &mut self, + _rng: &mut R, + _dim: usize, + _chain: u64, + ) -> std::result::Result { + Python::attach(|py| { + let params = PyNone::get(py); + Ok(params.unbind().into()) + }) + } + fn transformation_id(&self, params: &Py) -> std::result::Result { let id = self .transform_adapter diff --git a/src/wrapper.rs b/src/wrapper.rs index 8ba7e50..bf35f88 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -405,6 +405,24 @@ impl PyNutsSettings { } } + #[getter] + fn store_transformed(&self) -> bool { + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_transformed, + Settings::LowRank(nuts_settings) => nuts_settings.store_transformed, + Settings::Transforming(nuts_settings) => nuts_settings.store_transformed, + } + } + + #[setter(store_transformed)] + fn set_store_transformed(&mut self, val: bool) { + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_transformed = val, + Settings::LowRank(nuts_settings) => nuts_settings.store_transformed = val, + Settings::Transforming(nuts_settings) => nuts_settings.store_transformed = val, + } + } + #[getter] fn store_divergences(&self) -> bool { match &self.inner { @@ -535,6 +553,42 @@ impl PyNutsSettings { Ok(()) } + #[getter] + fn exact_normal_trajectory(&self) -> bool { + match &self.inner { + Settings::LowRank(settings) => settings.exact_normal_trajectory, + Settings::Diag(settings) => settings.exact_normal_trajectory, + Settings::Transforming(settings) => settings.exact_normal_trajectory, + } + } + + #[setter(exact_normal_trajectory)] + fn set_exact_normal_trajectory(&mut self, val: bool) { + match &mut self.inner { + Settings::LowRank(settings) => settings.exact_normal_trajectory = val, + Settings::Diag(settings) => settings.exact_normal_trajectory = val, + Settings::Transforming(settings) => settings.exact_normal_trajectory = val, + } + } + + #[getter] + fn extra_doublings(&self) -> u64 { + match &self.inner { + Settings::LowRank(settings) => settings.extra_doublings, + Settings::Diag(settings) => settings.extra_doublings, + Settings::Transforming(settings) => settings.extra_doublings, + } + } + + #[setter(extra_doublings)] + fn set_extra_doublings(&mut self, val: u64) { + match &mut self.inner { + Settings::LowRank(settings) => settings.extra_doublings = val, + Settings::Diag(settings) => settings.extra_doublings = val, + Settings::Transforming(settings) => settings.extra_doublings = val, + } + } + #[getter] fn mass_matrix_switch_freq(&self) -> Result { match &self.inner { @@ -692,20 +746,19 @@ impl PyNutsSettings { #[setter(step_size_adapt_method)] fn set_step_size_adapt_method(&mut self, method: Py) -> Result<()> { - let method = Python::attach(|py| { - if let Ok(method) = method.extract::(py) { - match method.as_str() { - "dual_average" => Ok(StepSizeAdaptMethod::DualAverage), - "adam" => Ok(StepSizeAdaptMethod::Adam), - _ => { - if let Ok(step_size) = method.parse::() { - Ok(StepSizeAdaptMethod::Fixed(step_size)) - } else { - bail!("step_size_adapt_method must be a positive float when using fixed step size"); - } + let method = Python::attach(|py| match method.extract::(py) { + Ok(method) => match method.as_str() { + "dual_average" => Ok(StepSizeAdaptMethod::DualAverage), + "adam" => Ok(StepSizeAdaptMethod::Adam), + _ => { + if let Ok(step_size) = method.parse::() { + Ok(StepSizeAdaptMethod::Fixed(step_size)) + } else { + bail!("step_size_adapt_method must be a positive float when using fixed step size"); } } - } else { + }, + _ => { bail!("step_size_adapt_method must be a string"); } })?; From 68bfb825c4eab7ab4b25295041892d9782191ef2 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 9 Mar 2026 16:20:29 +0100 Subject: [PATCH 05/21] feat: support isokinetic solver --- src/wrapper.rs | 120 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 109 insertions(+), 11 deletions(-) diff --git a/src/wrapper.rs b/src/wrapper.rs index bf35f88..eb65466 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -16,9 +16,9 @@ use crate::{ use anyhow::{anyhow, bail, Context, Result}; use numpy::{PyArray1, PyReadonlyArray1}; use nuts_rs::{ - ArrowConfig, ArrowTrace, ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, Model, - ProgressCallback, Sampler, SamplerWaitResult, StepSizeAdaptMethod, TransformedNutsSettings, - ZarrAsyncConfig, + ArrowConfig, ArrowTrace, ChainProgress, DiagGradNutsSettings, KineticEnergyKind, + LowRankNutsSettings, Model, ProgressCallback, Sampler, SamplerWaitResult, StepSizeAdaptMethod, + TransformedNutsSettings, ZarrAsyncConfig, }; use pyo3::{ exceptions::{PyTimeoutError, PyValueError}, @@ -460,7 +460,7 @@ impl PyNutsSettings { } #[getter] - fn set_target_accept(&self) -> f64 { + fn target_accept(&self) -> f64 { match &self.inner { Settings::Diag(nuts_settings) => { nuts_settings.adapt_options.step_size_settings.target_accept @@ -475,7 +475,7 @@ impl PyNutsSettings { } #[setter(target_accept)] - fn target_accept(&mut self, val: f64) { + fn set_target_accept(&mut self, val: f64) { match &mut self.inner { Settings::Diag(nuts_settings) => { nuts_settings.adapt_options.step_size_settings.target_accept = val @@ -489,6 +489,66 @@ impl PyNutsSettings { } } + #[getter] + fn max_step_size(&self) -> f64 { + match &self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings + .adapt_options + .step_size_settings + .adapt_options + .dual_average + .max_step_size + } + Settings::LowRank(nuts_settings) => { + nuts_settings + .adapt_options + .step_size_settings + .adapt_options + .dual_average + .max_step_size + } + Settings::Transforming(nuts_settings) => { + nuts_settings + .adapt_options + .step_size_settings + .adapt_options + .dual_average + .max_step_size + } + } + } + + #[setter(max_step_size)] + fn set_max_step_size(&mut self, val: f64) { + match &mut self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings + .adapt_options + .step_size_settings + .adapt_options + .dual_average + .max_step_size = val + } + Settings::LowRank(nuts_settings) => { + nuts_settings + .adapt_options + .step_size_settings + .adapt_options + .dual_average + .max_step_size = val + } + Settings::Transforming(nuts_settings) => { + nuts_settings + .adapt_options + .step_size_settings + .adapt_options + .dual_average + .max_step_size = val + } + } + } + #[getter] fn store_mass_matrix(&self) -> Result { match &self.inner { @@ -553,21 +613,59 @@ impl PyNutsSettings { Ok(()) } + #[getter] + fn microcanonical_trajectory(&self) -> bool { + match &self.inner { + Settings::LowRank(settings) => { + settings.trajectory_kind == KineticEnergyKind::Microcanonical + } + Settings::Diag(settings) => { + settings.trajectory_kind == KineticEnergyKind::Microcanonical + } + Settings::Transforming(settings) => { + settings.trajectory_kind == KineticEnergyKind::Microcanonical + } + } + } + + #[setter(microcanonical_trajectory)] + fn set_microcanonical_trajectory(&mut self, val: bool) { + let kind = if val { + KineticEnergyKind::Microcanonical + } else { + KineticEnergyKind::Euclidean + }; + match &mut self.inner { + Settings::LowRank(settings) => settings.trajectory_kind = kind, + Settings::Diag(settings) => settings.trajectory_kind = kind, + Settings::Transforming(settings) => settings.trajectory_kind = kind, + } + } + #[getter] fn exact_normal_trajectory(&self) -> bool { match &self.inner { - Settings::LowRank(settings) => settings.exact_normal_trajectory, - Settings::Diag(settings) => settings.exact_normal_trajectory, - Settings::Transforming(settings) => settings.exact_normal_trajectory, + Settings::LowRank(settings) => { + settings.trajectory_kind == KineticEnergyKind::ExactNormal + } + Settings::Diag(settings) => settings.trajectory_kind == KineticEnergyKind::ExactNormal, + Settings::Transforming(settings) => { + settings.trajectory_kind == KineticEnergyKind::ExactNormal + } } } #[setter(exact_normal_trajectory)] fn set_exact_normal_trajectory(&mut self, val: bool) { + let kind = if val { + KineticEnergyKind::ExactNormal + } else { + KineticEnergyKind::Euclidean + }; match &mut self.inner { - Settings::LowRank(settings) => settings.exact_normal_trajectory = val, - Settings::Diag(settings) => settings.exact_normal_trajectory = val, - Settings::Transforming(settings) => settings.exact_normal_trajectory = val, + Settings::LowRank(settings) => settings.trajectory_kind = kind, + Settings::Diag(settings) => settings.trajectory_kind = kind, + Settings::Transforming(settings) => settings.trajectory_kind = kind, } } From b280ac4fbd02201b7eb9df50f31d23997d916fb7 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 9 Mar 2026 16:20:29 +0100 Subject: [PATCH 06/21] refactor: cleaner settings handling for nuts --- python/nutpie/sample.py | 10 +- src/wrapper.rs | 1137 ++++++++++++--------------------------- 2 files changed, 356 insertions(+), 791 deletions(-) diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 51965b4..2ba68ee 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -914,15 +914,15 @@ def sample( f"Expected one of: 'diag', 'draw_diag', 'low_rank', 'flow'." ) + updates = dict(kwargs) if tune is not None: - settings.num_tune = tune + updates["num_tune"] = tune if draws is not None: - settings.num_draws = draws + updates["num_draws"] = draws if chains is not None: - settings.num_chains = chains + updates["num_chains"] = chains - for name, val in kwargs.items(): - setattr(settings, name, val) + settings.update(updates) if cores is None: try: diff --git a/src/wrapper.rs b/src/wrapper.rs index eb65466..c5146f8 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -21,14 +21,16 @@ use nuts_rs::{ TransformedNutsSettings, ZarrAsyncConfig, }; use pyo3::{ - exceptions::{PyTimeoutError, PyValueError}, + exceptions::{PyAttributeError, PyTimeoutError, PyValueError}, intern, prelude::*, - types::PyList, + types::PyDict, }; use pyo3_arrow::PyRecordBatch; use pyo3_object_store::AnyObjectStore; +use pythonize::{depythonize_bound, pythonize}; use rand::{rng, Rng}; +use serde_json::Value as JsonValue; use tokio::runtime::Runtime; use zarrs_object_store::{object_store::limit::LimitStore, AsyncObjectStore}; @@ -113,6 +115,84 @@ enum Settings { Transforming(TransformedNutsSettings), } +macro_rules! with_all_settings_mut { + ($self:expr, $settings:ident => $body:block) => {{ + match &mut $self.inner { + Settings::Diag($settings) => $body, + Settings::LowRank($settings) => $body, + Settings::Transforming($settings) => $body, + } + }}; +} + +macro_rules! set_all_settings_field { + ($self:expr, $field:ident = $value:expr) => {{ + with_all_settings_mut!($self, settings => { + settings.$field = $value; + }); + }}; + ($self:expr, $field:ident $(. $rest:ident)+ = $value:expr) => {{ + with_all_settings_mut!($self, settings => { + settings.$field$(.$rest)+ = $value; + }); + }}; +} + +macro_rules! unsupported_option_error { + ($option:expr, $adaptation:expr) => { + PyValueError::new_err(format!( + "Option {} not available for {} adaptation", + $option, $adaptation + )) + }; +} + +macro_rules! with_diag_or_low_rank_settings_mut { + ($self:expr, $option:expr, $settings:ident => $body:block) => {{ + match &mut $self.inner { + Settings::Diag($settings) => $body, + Settings::LowRank($settings) => $body, + Settings::Transforming(_) => { + return Err(unsupported_option_error!($option, "transforming")) + } + } + }}; +} + +macro_rules! with_diag_settings_mut { + ($self:expr, $option:expr, $settings:ident => $body:block) => {{ + match &mut $self.inner { + Settings::Diag($settings) => $body, + Settings::LowRank(_) => return Err(unsupported_option_error!($option, "low-rank")), + Settings::Transforming(_) => { + return Err(unsupported_option_error!($option, "transforming")) + } + } + }}; +} + +macro_rules! with_low_rank_settings_mut { + ($self:expr, $option:expr, $settings:ident => $body:block) => {{ + match &mut $self.inner { + Settings::LowRank($settings) => $body, + Settings::Diag(_) => return Err(unsupported_option_error!($option, "diag")), + Settings::Transforming(_) => { + return Err(unsupported_option_error!($option, "transforming")) + } + } + }}; +} + +macro_rules! with_transform_settings_mut { + ($self:expr, $option:expr, $settings:ident => $body:block) => {{ + match &mut $self.inner { + Settings::Transforming($settings) => $body, + Settings::Diag(_) => return Err(unsupported_option_error!($option, "diag")), + Settings::LowRank(_) => return Err(unsupported_option_error!($option, "low-rank")), + } + }}; +} + impl PyNutsSettings { fn new_diag(seed: Option) -> Self { let seed = seed.unwrap_or_else(|| { @@ -158,838 +238,323 @@ impl PyNutsSettings { inner: Settings::Transforming(settings), } } -} - -// TODO switch to serde to expose all the options... -#[pymethods] -impl PyNutsSettings { - #[staticmethod] - #[allow(non_snake_case)] - #[pyo3(signature = (seed=None))] - fn Diag(seed: Option) -> Self { - PyNutsSettings::new_diag(seed) - } - - #[staticmethod] - #[allow(non_snake_case)] - #[pyo3(signature = (seed=None))] - fn LowRank(seed: Option) -> Self { - PyNutsSettings::new_low_rank(seed) - } - - #[staticmethod] - #[allow(non_snake_case)] - #[pyo3(signature = (seed=None))] - fn Transform(seed: Option) -> Self { - PyNutsSettings::new_tranform_adapt(seed) - } - - #[getter] - fn num_tune(&self) -> u64 { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_tune, - Settings::LowRank(nuts_settings) => nuts_settings.num_tune, - Settings::Transforming(nuts_settings) => nuts_settings.num_tune, - } - } - - #[setter(num_tune)] - fn set_num_tune(&mut self, val: u64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_tune = val, - Settings::LowRank(nuts_settings) => nuts_settings.num_tune = val, - Settings::Transforming(nuts_settings) => nuts_settings.num_tune = val, - } - } - - #[getter] - fn num_chains(&self) -> usize { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_chains, - Settings::LowRank(nuts_settings) => nuts_settings.num_chains, - Settings::Transforming(nuts_settings) => nuts_settings.num_chains, - } - } - - #[setter(num_chains)] - fn set_num_chains(&mut self, val: usize) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_chains = val, - Settings::LowRank(nuts_settings) => nuts_settings.num_chains = val, - Settings::Transforming(nuts_settings) => nuts_settings.num_chains = val, - } - } - - #[getter] - fn num_draws(&self) -> u64 { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_draws, - Settings::LowRank(nuts_settings) => nuts_settings.num_draws, - Settings::Transforming(nuts_settings) => nuts_settings.num_draws, - } - } - - #[setter(num_draws)] - fn set_num_draws(&mut self, val: u64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_draws = val, - Settings::LowRank(nuts_settings) => nuts_settings.num_draws = val, - Settings::Transforming(nuts_settings) => nuts_settings.num_draws = val, - } - } - - #[getter] - fn window_switch_freq(&self) -> Result { - match &self.inner { - Settings::Diag(nuts_settings) => { - Ok(nuts_settings.adapt_options.mass_matrix_switch_freq) - } - Settings::LowRank(nuts_settings) => { - Ok(nuts_settings.adapt_options.mass_matrix_switch_freq) - } - Settings::Transforming(nuts_settings) => { - Ok(nuts_settings.adapt_options.transform_update_freq) - } - } - } - - #[setter(window_switch_freq)] - fn set_window_switch_freq(&mut self, val: u64) -> Result<()> { - match &mut self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.mass_matrix_switch_freq = val; - Ok(()) - } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.mass_matrix_switch_freq = val; - Ok(()) - } - Settings::Transforming(nuts_settings) => { - nuts_settings.adapt_options.transform_update_freq = val; - Ok(()) - } - } - } - - #[getter] - fn early_window_switch_freq(&self) -> Result { - match &self.inner { - Settings::Diag(nuts_settings) => { - Ok(nuts_settings.adapt_options.early_mass_matrix_switch_freq) - } - Settings::LowRank(nuts_settings) => { - Ok(nuts_settings.adapt_options.early_mass_matrix_switch_freq) - } - Settings::Transforming(_) => { - bail!("Option early_window_switch_freq not availbale for transformation adaptation") - } - } - } - #[setter(early_window_switch_freq)] - fn set_early_window_switch_freq(&mut self, val: u64) -> Result<()> { + fn update_from_nested_dict(&mut self, value: &Bound<'_, PyAny>) -> PyResult<()> { match &mut self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.early_mass_matrix_switch_freq = val; - Ok(()) - } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.early_mass_matrix_switch_freq = val; - Ok(()) - } - Settings::Transforming(_) => { - bail!("Option early_window_switch_freq not availbale for transformation adaptation") - } - } - } - - #[getter] - fn initial_step(&self) -> f64 { - match &self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step - } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step - } - Settings::Transforming(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step - } - } - } - - #[setter(initial_step)] - fn set_initial_step(&mut self, val: f64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step = val; - } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step = val; - } - Settings::Transforming(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step = val; - } - } - } - - #[getter] - fn maxdepth(&self) -> u64 { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.maxdepth, - Settings::LowRank(nuts_settings) => nuts_settings.maxdepth, - Settings::Transforming(nuts_settings) => nuts_settings.maxdepth, - } - } - - #[setter(maxdepth)] - fn set_maxdepth(&mut self, val: u64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.maxdepth = val, - Settings::LowRank(nuts_settings) => nuts_settings.maxdepth = val, - Settings::Transforming(nuts_settings) => nuts_settings.maxdepth = val, - } - } - - #[getter] - fn mindepth(&self) -> u64 { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.mindepth, - Settings::LowRank(nuts_settings) => nuts_settings.mindepth, - Settings::Transforming(nuts_settings) => nuts_settings.mindepth, - } - } - - #[setter(mindepth)] - fn set_mindepth(&mut self, val: u64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.mindepth = val, - Settings::LowRank(nuts_settings) => nuts_settings.mindepth = val, - Settings::Transforming(nuts_settings) => nuts_settings.mindepth = val, - } - } - - #[getter] - fn store_gradient(&self) -> bool { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_gradient, - Settings::LowRank(nuts_settings) => nuts_settings.store_gradient, - Settings::Transforming(nuts_settings) => nuts_settings.store_gradient, - } - } - - #[setter(store_gradient)] - fn set_store_gradient(&mut self, val: bool) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_gradient = val, - Settings::LowRank(nuts_settings) => nuts_settings.store_gradient = val, - Settings::Transforming(nuts_settings) => nuts_settings.store_gradient = val, - } - } - - #[getter] - fn store_unconstrained(&self) -> bool { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_unconstrained, - Settings::LowRank(nuts_settings) => nuts_settings.store_unconstrained, - Settings::Transforming(nuts_settings) => nuts_settings.store_unconstrained, - } - } - - #[setter(store_unconstrained)] - fn set_store_unconstrained(&mut self, val: bool) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_unconstrained = val, - Settings::LowRank(nuts_settings) => nuts_settings.store_unconstrained = val, - Settings::Transforming(nuts_settings) => nuts_settings.store_unconstrained = val, - } - } - - #[getter] - fn store_transformed(&self) -> bool { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_transformed, - Settings::LowRank(nuts_settings) => nuts_settings.store_transformed, - Settings::Transforming(nuts_settings) => nuts_settings.store_transformed, - } - } - - #[setter(store_transformed)] - fn set_store_transformed(&mut self, val: bool) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_transformed = val, - Settings::LowRank(nuts_settings) => nuts_settings.store_transformed = val, - Settings::Transforming(nuts_settings) => nuts_settings.store_transformed = val, - } - } - - #[getter] - fn store_divergences(&self) -> bool { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_divergences, - Settings::LowRank(nuts_settings) => nuts_settings.store_divergences, - Settings::Transforming(nuts_settings) => nuts_settings.store_divergences, - } - } - - #[setter(store_divergences)] - fn set_store_divergences(&mut self, val: bool) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_divergences = val, - Settings::LowRank(nuts_settings) => nuts_settings.store_divergences = val, - Settings::Transforming(nuts_settings) => nuts_settings.store_divergences = val, - } - } - - #[getter] - fn max_energy_error(&self) -> f64 { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.max_energy_error, - Settings::LowRank(nuts_settings) => nuts_settings.max_energy_error, - Settings::Transforming(nuts_settings) => nuts_settings.max_energy_error, - } - } - - #[setter(max_energy_error)] - fn set_max_energy_error(&mut self, val: f64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.max_energy_error = val, - Settings::LowRank(nuts_settings) => nuts_settings.max_energy_error = val, - Settings::Transforming(nuts_settings) => nuts_settings.max_energy_error = val, - } - } - - #[getter] - fn target_accept(&self) -> f64 { - match &self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept - } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept - } - Settings::Transforming(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept - } - } - } - - #[setter(target_accept)] - fn set_target_accept(&mut self, val: f64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept = val - } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept = val - } - Settings::Transforming(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept = val - } - } - } - - #[getter] - fn max_step_size(&self) -> f64 { - match &self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings - .adapt_options - .step_size_settings - .adapt_options - .dual_average - .max_step_size + Settings::Diag(settings) => { + *settings = depythonize_bound(value) + .map_err(|err| PyValueError::new_err(err.to_string()))?; } - Settings::LowRank(nuts_settings) => { - nuts_settings - .adapt_options - .step_size_settings - .adapt_options - .dual_average - .max_step_size + Settings::LowRank(settings) => { + *settings = depythonize_bound(value) + .map_err(|err| PyValueError::new_err(err.to_string()))?; } - Settings::Transforming(nuts_settings) => { - nuts_settings - .adapt_options - .step_size_settings - .adapt_options - .dual_average - .max_step_size + Settings::Transforming(settings) => { + *settings = depythonize_bound(value) + .map_err(|err| PyValueError::new_err(err.to_string()))?; } } + Ok(()) } - #[setter(max_step_size)] - fn set_max_step_size(&mut self, val: f64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings - .adapt_options - .step_size_settings - .adapt_options - .dual_average - .max_step_size = val + fn apply_update(&mut self, name: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { + match name { + "num_tune" => { + let value: u64 = value.extract()?; + set_all_settings_field!(self, num_tune = value); } - Settings::LowRank(nuts_settings) => { - nuts_settings - .adapt_options - .step_size_settings - .adapt_options - .dual_average - .max_step_size = val + "num_chains" => { + let value: usize = value.extract()?; + set_all_settings_field!(self, num_chains = value); } - Settings::Transforming(nuts_settings) => { - nuts_settings - .adapt_options - .step_size_settings - .adapt_options - .dual_average - .max_step_size = val + "num_draws" => { + let value: u64 = value.extract()?; + set_all_settings_field!(self, num_draws = value); } - } - } - - #[getter] - fn store_mass_matrix(&self) -> Result { - match &self.inner { - Settings::LowRank(settings) => { - Ok(settings.adapt_options.mass_matrix_options.store_mass_matrix) - } - Settings::Diag(settings) => { - Ok(settings.adapt_options.mass_matrix_options.store_mass_matrix) - } - Settings::Transforming(_) => Ok(false), - } - } - - #[setter(store_mass_matrix)] - fn set_store_mass_matrix(&mut self, val: bool) -> Result<()> { - match &mut self.inner { - Settings::LowRank(settings) => { - settings.adapt_options.mass_matrix_options.store_mass_matrix = val; - Ok(()) + "window_switch_freq" => { + let value: u64 = value.extract()?; + match &mut self.inner { + Settings::Diag(settings) => { + settings.adapt_options.mass_matrix_switch_freq = value + } + Settings::LowRank(settings) => { + settings.adapt_options.mass_matrix_switch_freq = value + } + Settings::Transforming(settings) => { + settings.adapt_options.transform_update_freq = value + } + } } - Settings::Diag(settings) => { - settings.adapt_options.mass_matrix_options.store_mass_matrix = val; - Ok(()) + "early_window_switch_freq" => { + let value: u64 = value.extract()?; + with_diag_or_low_rank_settings_mut!( + self, + "early_window_switch_freq", + settings => { + settings.adapt_options.early_mass_matrix_switch_freq = value; + } + ); + } + "initial_step" => { + let value: f64 = value.extract()?; + set_all_settings_field!( + self, + adapt_options.step_size_settings.initial_step = value + ); + } + "maxdepth" => { + let value: u64 = value.extract()?; + set_all_settings_field!(self, maxdepth = value); + } + "mindepth" => { + let value: u64 = value.extract()?; + set_all_settings_field!(self, mindepth = value); + } + "store_gradient" => { + let value: bool = value.extract()?; + set_all_settings_field!(self, store_gradient = value); + } + "store_unconstrained" => { + let value: bool = value.extract()?; + set_all_settings_field!(self, store_unconstrained = value); + } + "store_transformed" => { + let value: bool = value.extract()?; + set_all_settings_field!(self, store_transformed = value); + } + "store_divergences" => { + let value: bool = value.extract()?; + set_all_settings_field!(self, store_divergences = value); + } + "max_energy_error" => { + let value: f64 = value.extract()?; + set_all_settings_field!(self, max_energy_error = value); + } + "target_accept" => { + let value: f64 = value.extract()?; + set_all_settings_field!( + self, + adapt_options.step_size_settings.target_accept = value + ); + } + "max_step_size" => { + let value: f64 = value.extract()?; + set_all_settings_field!( + self, + adapt_options + .step_size_settings + .adapt_options + .dual_average + .max_step_size = value + ); + } + "store_mass_matrix" => { + let value: bool = value.extract()?; + with_diag_or_low_rank_settings_mut!( + self, + "store_mass_matrix", + settings => { + settings.adapt_options.mass_matrix_options.store_mass_matrix = value; + } + ); + } + "use_grad_based_mass_matrix" => { + let value: bool = value.extract()?; + with_diag_settings_mut!( + self, + "use_grad_based_mass_matrix", + settings => { + settings.adapt_options.mass_matrix_options.use_grad_based_estimate = value; + } + ); + } + "microcanonical_trajectory" => { + let value: bool = value.extract()?; + if value { + set_all_settings_field!( + self, + trajectory_kind = KineticEnergyKind::Microcanonical + ); + } } - Settings::Transforming(_) => { - bail!("Option store_mass_matrix not availbale for transformation adaptation") + "exact_normal_trajectory" => { + let value: bool = value.extract()?; + if value { + set_all_settings_field!(self, trajectory_kind = KineticEnergyKind::ExactNormal); + } } - } - } - - #[getter] - fn use_grad_based_mass_matrix(&self) -> Result { - match &self.inner { - Settings::LowRank(_) => { - bail!("non-grad based mass matrix not available for low-rank adaptation") + "extra_doublings" => { + let value: u64 = value.extract()?; + set_all_settings_field!(self, extra_doublings = value); + } + "mass_matrix_switch_freq" => { + let value: u64 = value.extract()?; + with_diag_or_low_rank_settings_mut!( + self, + "mass_matrix_switch_freq", + settings => { + settings.adapt_options.mass_matrix_switch_freq = value; + } + ); + } + "mass_matrix_eigval_cutoff" => { + let value: Option = value.extract()?; + if let Some(value) = value { + with_low_rank_settings_mut!( + self, + "mass_matrix_eigval_cutoff", + settings => { + settings.adapt_options.mass_matrix_options.eigval_cutoff = value; + } + ); + } } - Settings::Transforming(_) => { - bail!("non-grad based mass matrix not available for transforming adaptation") + "mass_matrix_gamma" => { + let value: Option = value.extract()?; + if let Some(value) = value { + with_low_rank_settings_mut!( + self, + "mass_matrix_gamma", + settings => { + settings.adapt_options.mass_matrix_options.gamma = value; + } + ); + } } - Settings::Diag(diag) => Ok(diag - .adapt_options - .mass_matrix_options - .use_grad_based_estimate), - } - } + "train_on_orbit" => { + let value: bool = value.extract()?; + with_transform_settings_mut!( + self, + "train_on_orbit", + settings => { + settings.adapt_options.use_orbit_for_training = value; + } + ); + } + "check_turning" => { + let value: bool = value.extract()?; + set_all_settings_field!(self, check_turning = value); + } + "step_size_adapt_method" => { + let method = match value.extract::() { + Ok(method) => match method.as_str() { + "dual_average" => StepSizeAdaptMethod::DualAverage, + "adam" => StepSizeAdaptMethod::Adam, + _ => { + if let Ok(step_size) = method.parse::() { + StepSizeAdaptMethod::Fixed(step_size) + } else { + return Err(PyValueError::new_err( + "step_size_adapt_method must be a positive float when using fixed step size", + )); + } + } + }, + _ => { + return Err(PyValueError::new_err( + "step_size_adapt_method must be a string", + )); + } + }; - #[setter(use_grad_based_mass_matrix)] - fn set_use_grad_based_mass_matrix(&mut self, val: bool) -> Result<()> { - match &mut self.inner { - Settings::LowRank(_) => { - bail!("non-grad based mass matrix not available for low-rank adaptation"); + set_all_settings_field!( + self, + adapt_options.step_size_settings.adapt_options.method = method + ); + } + "step_size_adam_learning_rate" => { + let value: Option = value.extract()?; + if let Some(value) = value { + set_all_settings_field!( + self, + adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate = value + ); + } } - Settings::Transforming(_) => { - bail!("non-grad based mass matrix not available for transforming adaptation"); + "step_size_jitter" => { + let mut value: Option = value.extract()?; + if let Some(jitter) = value { + if jitter < 0.0 { + return Err(PyValueError::new_err("step_size_jitter must be positive")); + } + if jitter == 0.0 { + value = None; + } + } + set_all_settings_field!(self, adapt_options.step_size_settings.jitter = value); } - Settings::Diag(diag) => { - diag.adapt_options - .mass_matrix_options - .use_grad_based_estimate = val; + _ => { + return Err(PyAttributeError::new_err(format!( + "Unknown settings attribute: {name}", + ))); } } Ok(()) } - #[getter] - fn microcanonical_trajectory(&self) -> bool { + fn to_nested_json(&self) -> PyResult { match &self.inner { - Settings::LowRank(settings) => { - settings.trajectory_kind == KineticEnergyKind::Microcanonical - } Settings::Diag(settings) => { - settings.trajectory_kind == KineticEnergyKind::Microcanonical + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) } - Settings::Transforming(settings) => { - settings.trajectory_kind == KineticEnergyKind::Microcanonical - } - } - } - - #[setter(microcanonical_trajectory)] - fn set_microcanonical_trajectory(&mut self, val: bool) { - let kind = if val { - KineticEnergyKind::Microcanonical - } else { - KineticEnergyKind::Euclidean - }; - match &mut self.inner { - Settings::LowRank(settings) => settings.trajectory_kind = kind, - Settings::Diag(settings) => settings.trajectory_kind = kind, - Settings::Transforming(settings) => settings.trajectory_kind = kind, - } - } - - #[getter] - fn exact_normal_trajectory(&self) -> bool { - match &self.inner { Settings::LowRank(settings) => { - settings.trajectory_kind == KineticEnergyKind::ExactNormal + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) } - Settings::Diag(settings) => settings.trajectory_kind == KineticEnergyKind::ExactNormal, Settings::Transforming(settings) => { - settings.trajectory_kind == KineticEnergyKind::ExactNormal - } - } - } - - #[setter(exact_normal_trajectory)] - fn set_exact_normal_trajectory(&mut self, val: bool) { - let kind = if val { - KineticEnergyKind::ExactNormal - } else { - KineticEnergyKind::Euclidean - }; - match &mut self.inner { - Settings::LowRank(settings) => settings.trajectory_kind = kind, - Settings::Diag(settings) => settings.trajectory_kind = kind, - Settings::Transforming(settings) => settings.trajectory_kind = kind, - } - } - - #[getter] - fn extra_doublings(&self) -> u64 { - match &self.inner { - Settings::LowRank(settings) => settings.extra_doublings, - Settings::Diag(settings) => settings.extra_doublings, - Settings::Transforming(settings) => settings.extra_doublings, - } - } - - #[setter(extra_doublings)] - fn set_extra_doublings(&mut self, val: u64) { - match &mut self.inner { - Settings::LowRank(settings) => settings.extra_doublings = val, - Settings::Diag(settings) => settings.extra_doublings = val, - Settings::Transforming(settings) => settings.extra_doublings = val, - } - } - - #[getter] - fn mass_matrix_switch_freq(&self) -> Result { - match &self.inner { - Settings::Diag(settings) => Ok(settings.adapt_options.mass_matrix_switch_freq), - Settings::LowRank(settings) => Ok(settings.adapt_options.mass_matrix_switch_freq), - Settings::Transforming(_) => { - bail!("mass_matrix_switch_freq not available for transforming adaptation"); - } - } - } - - #[setter(mass_matrix_switch_freq)] - fn set_mass_matrix_switch_freq(&mut self, val: u64) -> Result<()> { - match &mut self.inner { - Settings::Diag(settings) => settings.adapt_options.mass_matrix_switch_freq = val, - Settings::LowRank(settings) => settings.adapt_options.mass_matrix_switch_freq = val, - Settings::Transforming(_) => { - bail!("mass_matrix_switch_freq not available for transforming adaptation"); - } - } - Ok(()) - } - - #[getter] - fn mass_matrix_eigval_cutoff(&self) -> Result { - match &self.inner { - Settings::LowRank(inner) => Ok(inner.adapt_options.mass_matrix_options.eigval_cutoff), - Settings::Diag(_) => { - bail!("eigenvalue cutoff not available for diag mass matrix adaptation"); - } - Settings::Transforming(_) => { - bail!("eigenvalue cutoff not available for transfor adaptation"); - } - } - } - - #[setter(mass_matrix_eigval_cutoff)] - fn set_mass_matrix_eigval_cutoff(&mut self, val: Option) -> Result<()> { - let Some(val) = val else { - return Ok(()); - }; - match &mut self.inner { - Settings::LowRank(inner) => inner.adapt_options.mass_matrix_options.eigval_cutoff = val, - Settings::Diag(_) => { - bail!("eigenvalue cutoff not available for diag mass matrix adaptation"); - } - Settings::Transforming(_) => { - bail!("eigenvalue cutoff not available for transfor adaptation"); - } - } - Ok(()) - } - - #[getter] - fn mass_matrix_gamma(&self) -> Result { - match &self.inner { - Settings::LowRank(inner) => Ok(inner.adapt_options.mass_matrix_options.gamma), - Settings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation"); - } - Settings::Transforming(_) => { - bail!("gamma not available for transform adaptation"); - } - } - } - - #[setter(mass_matrix_gamma)] - fn set_mass_matrix_gamma(&mut self, val: Option) -> Result<()> { - let Some(val) = val else { - return Ok(()); - }; - match &mut self.inner { - Settings::LowRank(inner) => { - inner.adapt_options.mass_matrix_options.gamma = val; - } - Settings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation"); - } - Settings::Transforming(_) => { - bail!("gamma not available for transform adaptation"); + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) } } - Ok(()) } +} - #[getter] - fn train_on_orbit(&self) -> Result { - match &self.inner { - Settings::LowRank(_) => { - bail!("gamma not available for low rank mass matrix adaptation"); - } - Settings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation"); - } - Settings::Transforming(inner) => Ok(inner.adapt_options.use_orbit_for_training), - } +// TODO switch to serde to expose all the options... +#[pymethods] +impl PyNutsSettings { + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn Diag(seed: Option) -> Self { + PyNutsSettings::new_diag(seed) } - #[setter(train_on_orbit)] - fn set_train_on_orbit(&mut self, val: bool) -> Result<()> { - match &mut self.inner { - Settings::LowRank(_) => { - bail!("gamma not available for low rank mass matrix adaptation"); - } - Settings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation"); - } - Settings::Transforming(inner) => inner.adapt_options.use_orbit_for_training = val, - } - Ok(()) + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn LowRank(seed: Option) -> Self { + PyNutsSettings::new_low_rank(seed) } - #[getter] - fn check_turning(&self) -> Result { - match &self.inner { - Settings::LowRank(inner) => Ok(inner.check_turning), - Settings::Diag(inner) => Ok(inner.check_turning), - Settings::Transforming(inner) => Ok(inner.check_turning), - } + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn Transform(seed: Option) -> Self { + PyNutsSettings::new_tranform_adapt(seed) } - #[setter(check_turning)] - fn set_check_turning(&mut self, val: bool) -> Result<()> { - match &mut self.inner { - Settings::LowRank(inner) => { - inner.check_turning = val; - } - Settings::Diag(inner) => { - inner.check_turning = val; - } - Settings::Transforming(inner) => { - inner.check_turning = val; - } + fn update(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> { + for (key, value) in kwargs.iter() { + let key: String = key.extract()?; + self.apply_update(&key, &value)?; } Ok(()) } - #[getter] - fn step_size_adapt_method(&self) -> String { - let method = match &self.inner { - Settings::LowRank(inner) => inner.adapt_options.step_size_settings.adapt_options.method, - Settings::Diag(inner) => inner.adapt_options.step_size_settings.adapt_options.method, - Settings::Transforming(inner) => { - inner.adapt_options.step_size_settings.adapt_options.method - } - }; - - match method { - nuts_rs::StepSizeAdaptMethod::DualAverage => "dual_average", - nuts_rs::StepSizeAdaptMethod::Adam => "adam", - nuts_rs::StepSizeAdaptMethod::Fixed(_) => "fixed", - } - .to_string() + fn __setattr__(&mut self, name: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { + self.apply_update(name, value) } - #[setter(step_size_adapt_method)] - fn set_step_size_adapt_method(&mut self, method: Py) -> Result<()> { - let method = Python::attach(|py| match method.extract::(py) { - Ok(method) => match method.as_str() { - "dual_average" => Ok(StepSizeAdaptMethod::DualAverage), - "adam" => Ok(StepSizeAdaptMethod::Adam), - _ => { - if let Ok(step_size) = method.parse::() { - Ok(StepSizeAdaptMethod::Fixed(step_size)) - } else { - bail!("step_size_adapt_method must be a positive float when using fixed step size"); - } - } - }, - _ => { - bail!("step_size_adapt_method must be a string"); - } - })?; - - match &mut self.inner { - Settings::LowRank(inner) => { - inner.adapt_options.step_size_settings.adapt_options.method = method - } - Settings::Diag(inner) => { - inner.adapt_options.step_size_settings.adapt_options.method = method - } - Settings::Transforming(inner) => { - inner.adapt_options.step_size_settings.adapt_options.method = method - } - }; - Ok(()) + fn update_settings(&mut self, settings: &Bound<'_, PyDict>) -> PyResult<()> { + self.update_from_nested_dict(settings.as_any()) } - #[getter] - fn step_size_adam_learning_rate(&self) -> Option { - match &self.inner { - Settings::LowRank(inner) => { - if let StepSizeAdaptMethod::Adam = - inner.adapt_options.step_size_settings.adapt_options.method - { - Some( - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate, - ) - } else { - None - } - } - Settings::Diag(inner) => { - if let StepSizeAdaptMethod::Adam = - inner.adapt_options.step_size_settings.adapt_options.method - { - Some( - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate, - ) - } else { - None - } - } - Settings::Transforming(inner) => { - if let StepSizeAdaptMethod::Adam = - inner.adapt_options.step_size_settings.adapt_options.method - { - Some( - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate, - ) - } else { - None - } - } - } - } - - #[setter(step_size_adam_learning_rate)] - fn set_step_size_adam_learning_rate(&mut self, val: Option) -> Result<()> { - let Some(val) = val else { - return Ok(()); - }; - match &mut self.inner { - Settings::LowRank(inner) => { - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate = val - } - Settings::Diag(inner) => { - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate = val - } - Settings::Transforming(inner) => { - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate = val - } - }; - Ok(()) - } - - #[getter(step_size_jitter)] - fn step_size_jitter(&self) -> Option { - match &self.inner { - Settings::LowRank(inner) => inner.adapt_options.step_size_settings.jitter, - Settings::Diag(inner) => inner.adapt_options.step_size_settings.jitter, - Settings::Transforming(inner) => inner.adapt_options.step_size_settings.jitter, - } - } - - #[setter(step_size_jitter)] - fn set_step_size_jitter(&mut self, mut val: Option) -> PyResult<()> { - if let Some(val) = val { - if val < 0.0 { - return Err(PyValueError::new_err("step_size_jitter must be positive")); - } - } - if let Some(jitter) = val { - if jitter == 0.0 { - val = None; - } - } - match &mut self.inner { - Settings::LowRank(inner) => inner.adapt_options.step_size_settings.jitter = val, - Settings::Diag(inner) => inner.adapt_options.step_size_settings.jitter = val, - Settings::Transforming(inner) => inner.adapt_options.step_size_settings.jitter = val, - } - Ok(()) + fn as_dict(&self, py: Python<'_>) -> PyResult> { + let value = self.to_nested_json()?; + let obj = pythonize(py, &value).map_err(|err| PyValueError::new_err(err.to_string()))?; + Ok(obj.unbind()) } } From 8652adea539706bdcf1efa389fd50c859f058cfc Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 9 Mar 2026 16:20:29 +0100 Subject: [PATCH 07/21] feat: support mclmc sampler --- Cargo.toml | 7 +- python/nutpie/sample.py | 60 ++- src/wrapper.rs | 860 +++++++++++++++++++++++++++------------- 3 files changed, 632 insertions(+), 295 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 985bf3a..70f2fdd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,8 +41,11 @@ pyo3-object_store = "0.9.0" zarrs = { version = "0.23.2", features = ["async"] } zarrs_object_store = "0.6.0" tokio = { version = "1.47.1", features = ["rt", "rt-multi-thread"] } -pyo3-arrow = "0.16.0" -arrow = { version = "57.0.0", features = ["json"] } +pyo3-arrow = "0.17.0" +arrow = { version = "58.0.0", features = ["json"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +pythonize = "0.28.0" [dependencies.pyo3] version = "0.28.0" diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 2ba68ee..b353df9 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -546,6 +546,7 @@ def wait(self, *, timeout=None): return self._extract(results) def _extract(self, results): + settings_dict = self._settings.as_dict() if self._return_raw_trace: return results else: @@ -570,7 +571,7 @@ def _extract(self, results): skips = { "store_gradient": ["gradient"], "store_unconstrained": ["unconstrained_draw"], - "store_mass_matrix": [ + "adapt_options.mass_matrix_options.store_mass_matrix": [ "mass_matrix_inv", "mass_matrix_eigvals", "mass_matrix_stds", @@ -584,6 +585,7 @@ def _extract(self, results): "store_transformed": [ "transformed_position", "transformed_gradient", + "transformed_mu", ], } @@ -731,6 +733,7 @@ def sample( seed: int | None = None, save_warmup: bool = True, progress_bar: bool = True, + sampler: Literal["nuts", "mclmc"] = "nuts", adaptation: Literal["diag", "draw_diag", "low_rank", "flow"] = "diag", init_mean: np.ndarray | None = None, return_raw_trace: bool = False, @@ -796,6 +799,12 @@ def sample( return_raw_trace: bool, default=False Return the raw trace object (an apache arrow structure) instead of converting to arviz. + sampler: str, default="nuts" + The sampler to use. One of: + + - ``"nuts"`` (default): No-U-Turn Sampler. + - ``"mclmc"``: Microcanonical Langevin Monte Carlo. + adaptation: str, default="diag" The mass matrix adaptation strategy to use. One of: @@ -900,18 +909,37 @@ def sample( stacklevel=2, ) - if adaptation == "low_rank": - settings = _lib.PyNutsSettings.LowRank(seed) - elif adaptation == "flow": - settings = _lib.PyNutsSettings.Transform(seed) - elif adaptation in ("diag", "draw_diag"): - settings = _lib.PyNutsSettings.Diag(seed) - if adaptation == "draw_diag" or _use_grad_based is False: - settings.use_grad_based_mass_matrix = False + if sampler == "nuts": + if adaptation == "low_rank": + settings = _lib.PyNutsSettings.LowRank(seed) + elif adaptation == "flow": + settings = _lib.PyNutsSettings.Flow(seed) + elif adaptation in ("diag", "draw_diag"): + settings = _lib.PyNutsSettings.Diag(seed) + if adaptation == "draw_diag" or _use_grad_based is False: + settings.use_grad_based_mass_matrix = False + else: + raise ValueError( + f"Unknown adaptation strategy '{adaptation}'. " + f"Expected one of: 'diag', 'draw_diag', 'low_rank', 'flow'." + ) + elif sampler == "mclmc": + if adaptation == "low_rank": + settings = _lib.PyMclmcSettings.LowRank(seed) + elif adaptation == "flow": + settings = _lib.PyMclmcSettings.Flow(seed) + elif adaptation in ("diag", "draw_diag"): + settings = _lib.PyMclmcSettings.Diag(seed) + if adaptation == "draw_diag" or _use_grad_based is False: + settings.use_grad_based_mass_matrix = False + else: + raise ValueError( + f"Unknown adaptation strategy '{adaptation}'. " + f"Expected one of: 'diag', 'draw_diag', 'low_rank', 'flow'." + ) else: raise ValueError( - f"Unknown adaptation strategy '{adaptation}'. " - f"Expected one of: 'diag', 'draw_diag', 'low_rank', 'flow'." + f"Unknown sampler '{sampler}'. Expected one of: 'nuts', 'mclmc'." ) updates = dict(kwargs) @@ -938,7 +966,7 @@ def sample( if init_mean is None: init_mean = np.zeros(compiled_model.n_dim) - sampler = _BackgroundSampler( + background_sampler = _BackgroundSampler( compiled_model, settings, init_mean, @@ -954,14 +982,14 @@ def sample( ) if not blocking: - return sampler + return background_sampler try: - result = sampler.wait() + result = background_sampler.wait() except KeyboardInterrupt: - result = sampler.abort() + result = background_sampler.abort() except: - sampler.cancel() + background_sampler.cancel() raise return result diff --git a/src/wrapper.rs b/src/wrapper.rs index c5146f8..cabebb6 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -16,19 +16,19 @@ use crate::{ use anyhow::{anyhow, bail, Context, Result}; use numpy::{PyArray1, PyReadonlyArray1}; use nuts_rs::{ - ArrowConfig, ArrowTrace, ChainProgress, DiagGradNutsSettings, KineticEnergyKind, - LowRankNutsSettings, Model, ProgressCallback, Sampler, SamplerWaitResult, StepSizeAdaptMethod, - TransformedNutsSettings, ZarrAsyncConfig, + ArrowConfig, ArrowTrace, ChainProgress, DiagMclmcSettings, DiagNutsSettings, FlowMclmcSettings, + FlowNutsSettings, KineticEnergyKind, LowRankMclmcSettings, LowRankNutsSettings, Model, + ProgressCallback, Sampler, SamplerWaitResult, StepSizeAdaptMethod, ZarrAsyncConfig, }; use pyo3::{ exceptions::{PyAttributeError, PyTimeoutError, PyValueError}, intern, prelude::*, - types::PyDict, + types::{PyDict, PyList}, }; use pyo3_arrow::PyRecordBatch; use pyo3_object_store::AnyObjectStore; -use pythonize::{depythonize_bound, pythonize}; +use pythonize::{depythonize, pythonize}; use rand::{rng, Rng}; use serde_json::Value as JsonValue; use tokio::runtime::Runtime; @@ -105,338 +105,246 @@ impl PyChainProgress { #[pyclass(from_py_object)] #[derive(Clone)] pub struct PyNutsSettings { - inner: Settings, + inner: NutsSettingsKind, +} + +#[derive(Clone, FromPyObject)] +enum PySamplerSettings { + Nuts(PyNutsSettings), + Mclmc(PyMclmcSettings), } #[derive(Clone, Debug)] -enum Settings { - Diag(DiagGradNutsSettings), +enum NutsSettingsKind { + Diag(DiagNutsSettings), LowRank(LowRankNutsSettings), - Transforming(TransformedNutsSettings), + Flow(FlowNutsSettings), +} + +#[pyclass(from_py_object)] +#[derive(Clone)] +pub struct PyMclmcSettings { + inner: MclmcSettingsKind, +} + +#[derive(Clone, Debug)] +enum MclmcSettingsKind { + Diag(DiagMclmcSettings), + LowRank(LowRankMclmcSettings), + Flow(FlowMclmcSettings), +} + +macro_rules! unsupported_option_error { + ($option:expr, $adaptation:expr) => { + PyValueError::new_err(format!( + "Option {} not available for {} adaptation", + $option, $adaptation + )) + }; } macro_rules! with_all_settings_mut { - ($self:expr, $settings:ident => $body:block) => {{ + ($self:expr, $enum_name:ident, $settings:ident => $body:block) => {{ match &mut $self.inner { - Settings::Diag($settings) => $body, - Settings::LowRank($settings) => $body, - Settings::Transforming($settings) => $body, + $enum_name::Diag($settings) => $body, + $enum_name::LowRank($settings) => $body, + $enum_name::Flow($settings) => $body, } }}; } macro_rules! set_all_settings_field { - ($self:expr, $field:ident = $value:expr) => {{ - with_all_settings_mut!($self, settings => { + ($self:expr, $enum_name:ident, $field:ident = $value:expr) => {{ + with_all_settings_mut!($self, $enum_name, settings => { settings.$field = $value; }); }}; - ($self:expr, $field:ident $(. $rest:ident)+ = $value:expr) => {{ - with_all_settings_mut!($self, settings => { + ($self:expr, $enum_name:ident, $field:ident $(. $rest:ident)+ = $value:expr) => {{ + with_all_settings_mut!($self, $enum_name, settings => { settings.$field$(.$rest)+ = $value; }); }}; } -macro_rules! unsupported_option_error { - ($option:expr, $adaptation:expr) => { - PyValueError::new_err(format!( - "Option {} not available for {} adaptation", - $option, $adaptation - )) - }; -} - macro_rules! with_diag_or_low_rank_settings_mut { - ($self:expr, $option:expr, $settings:ident => $body:block) => {{ + ($self:expr, $enum_name:ident, $option:expr, $settings:ident => $body:block) => {{ match &mut $self.inner { - Settings::Diag($settings) => $body, - Settings::LowRank($settings) => $body, - Settings::Transforming(_) => { - return Err(unsupported_option_error!($option, "transforming")) - } + $enum_name::Diag($settings) => $body, + $enum_name::LowRank($settings) => $body, + $enum_name::Flow(_) => return Err(unsupported_option_error!($option, "flow")), } }}; } macro_rules! with_diag_settings_mut { - ($self:expr, $option:expr, $settings:ident => $body:block) => {{ + ($self:expr, $enum_name:ident, $option:expr, $settings:ident => $body:block) => {{ match &mut $self.inner { - Settings::Diag($settings) => $body, - Settings::LowRank(_) => return Err(unsupported_option_error!($option, "low-rank")), - Settings::Transforming(_) => { - return Err(unsupported_option_error!($option, "transforming")) - } + $enum_name::Diag($settings) => $body, + $enum_name::LowRank(_) => return Err(unsupported_option_error!($option, "low-rank")), + $enum_name::Flow(_) => return Err(unsupported_option_error!($option, "flow")), } }}; } macro_rules! with_low_rank_settings_mut { - ($self:expr, $option:expr, $settings:ident => $body:block) => {{ + ($self:expr, $enum_name:ident, $option:expr, $settings:ident => $body:block) => {{ match &mut $self.inner { - Settings::LowRank($settings) => $body, - Settings::Diag(_) => return Err(unsupported_option_error!($option, "diag")), - Settings::Transforming(_) => { - return Err(unsupported_option_error!($option, "transforming")) - } + $enum_name::LowRank($settings) => $body, + $enum_name::Diag(_) => return Err(unsupported_option_error!($option, "diag")), + $enum_name::Flow(_) => return Err(unsupported_option_error!($option, "flow")), } }}; } -macro_rules! with_transform_settings_mut { - ($self:expr, $option:expr, $settings:ident => $body:block) => {{ +macro_rules! with_flow_settings_mut { + ($self:expr, $enum_name:ident, $option:expr, $settings:ident => $body:block) => {{ match &mut $self.inner { - Settings::Transforming($settings) => $body, - Settings::Diag(_) => return Err(unsupported_option_error!($option, "diag")), - Settings::LowRank(_) => return Err(unsupported_option_error!($option, "low-rank")), + $enum_name::Flow($settings) => $body, + $enum_name::Diag(_) => return Err(unsupported_option_error!($option, "diag")), + $enum_name::LowRank(_) => return Err(unsupported_option_error!($option, "low-rank")), } }}; } -impl PyNutsSettings { - fn new_diag(seed: Option) -> Self { - let seed = seed.unwrap_or_else(|| { - let mut rng = rng(); - rng.next_u64() - }); - let settings = DiagGradNutsSettings { - seed, - ..Default::default() - }; - - Self { - inner: Settings::Diag(settings), - } - } - - fn new_low_rank(seed: Option) -> Self { - let seed = seed.unwrap_or_else(|| { - let mut rng = rng(); - rng.next_u64() - }); - let settings = LowRankNutsSettings { - seed, - ..Default::default() - }; - - Self { - inner: Settings::LowRank(settings), - } - } - - fn new_tranform_adapt(seed: Option) -> Self { - let seed = seed.unwrap_or_else(|| { - let mut rng = rng(); - rng.next_u64() - }); - let settings = TransformedNutsSettings { - seed, - ..Default::default() - }; - - Self { - inner: Settings::Transforming(settings), - } - } - - fn update_from_nested_dict(&mut self, value: &Bound<'_, PyAny>) -> PyResult<()> { - match &mut self.inner { - Settings::Diag(settings) => { - *settings = depythonize_bound(value) - .map_err(|err| PyValueError::new_err(err.to_string()))?; - } - Settings::LowRank(settings) => { - *settings = depythonize_bound(value) - .map_err(|err| PyValueError::new_err(err.to_string()))?; - } - Settings::Transforming(settings) => { - *settings = depythonize_bound(value) - .map_err(|err| PyValueError::new_err(err.to_string()))?; - } - } - Ok(()) - } - - fn apply_update(&mut self, name: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { - match name { - "num_tune" => { - let value: u64 = value.extract()?; - set_all_settings_field!(self, num_tune = value); - } - "num_chains" => { - let value: usize = value.extract()?; - set_all_settings_field!(self, num_chains = value); - } - "num_draws" => { - let value: u64 = value.extract()?; - set_all_settings_field!(self, num_draws = value); - } +macro_rules! try_shared_euclidean_adapt_update { + ($self:expr, $enum_name:ident, $name:expr, $value:expr) => {{ + match $name { "window_switch_freq" => { - let value: u64 = value.extract()?; - match &mut self.inner { - Settings::Diag(settings) => { + let value: u64 = $value.extract()?; + match &mut $self.inner { + $enum_name::Diag(settings) => { settings.adapt_options.mass_matrix_switch_freq = value } - Settings::LowRank(settings) => { + $enum_name::LowRank(settings) => { settings.adapt_options.mass_matrix_switch_freq = value } - Settings::Transforming(settings) => { + $enum_name::Flow(settings) => { settings.adapt_options.transform_update_freq = value } } + true } "early_window_switch_freq" => { - let value: u64 = value.extract()?; + let value: u64 = $value.extract()?; with_diag_or_low_rank_settings_mut!( - self, + $self, + $enum_name, "early_window_switch_freq", settings => { settings.adapt_options.early_mass_matrix_switch_freq = value; } ); + true } "initial_step" => { - let value: f64 = value.extract()?; + let value: f64 = $value.extract()?; set_all_settings_field!( - self, + $self, + $enum_name, adapt_options.step_size_settings.initial_step = value ); - } - "maxdepth" => { - let value: u64 = value.extract()?; - set_all_settings_field!(self, maxdepth = value); - } - "mindepth" => { - let value: u64 = value.extract()?; - set_all_settings_field!(self, mindepth = value); - } - "store_gradient" => { - let value: bool = value.extract()?; - set_all_settings_field!(self, store_gradient = value); - } - "store_unconstrained" => { - let value: bool = value.extract()?; - set_all_settings_field!(self, store_unconstrained = value); - } - "store_transformed" => { - let value: bool = value.extract()?; - set_all_settings_field!(self, store_transformed = value); - } - "store_divergences" => { - let value: bool = value.extract()?; - set_all_settings_field!(self, store_divergences = value); - } - "max_energy_error" => { - let value: f64 = value.extract()?; - set_all_settings_field!(self, max_energy_error = value); + true } "target_accept" => { - let value: f64 = value.extract()?; + let value: f64 = $value.extract()?; set_all_settings_field!( - self, + $self, + $enum_name, adapt_options.step_size_settings.target_accept = value ); + true } "max_step_size" => { - let value: f64 = value.extract()?; + let value: f64 = $value.extract()?; set_all_settings_field!( - self, + $self, + $enum_name, adapt_options .step_size_settings .adapt_options .dual_average .max_step_size = value ); + true } "store_mass_matrix" => { - let value: bool = value.extract()?; + let value: bool = $value.extract()?; with_diag_or_low_rank_settings_mut!( - self, + $self, + $enum_name, "store_mass_matrix", settings => { settings.adapt_options.mass_matrix_options.store_mass_matrix = value; } ); + true } "use_grad_based_mass_matrix" => { - let value: bool = value.extract()?; + let value: bool = $value.extract()?; with_diag_settings_mut!( - self, + $self, + $enum_name, "use_grad_based_mass_matrix", settings => { settings.adapt_options.mass_matrix_options.use_grad_based_estimate = value; } ); - } - "microcanonical_trajectory" => { - let value: bool = value.extract()?; - if value { - set_all_settings_field!( - self, - trajectory_kind = KineticEnergyKind::Microcanonical - ); - } - } - "exact_normal_trajectory" => { - let value: bool = value.extract()?; - if value { - set_all_settings_field!(self, trajectory_kind = KineticEnergyKind::ExactNormal); - } - } - "extra_doublings" => { - let value: u64 = value.extract()?; - set_all_settings_field!(self, extra_doublings = value); + true } "mass_matrix_switch_freq" => { - let value: u64 = value.extract()?; + let value: u64 = $value.extract()?; with_diag_or_low_rank_settings_mut!( - self, + $self, + $enum_name, "mass_matrix_switch_freq", settings => { settings.adapt_options.mass_matrix_switch_freq = value; } ); + true } "mass_matrix_eigval_cutoff" => { - let value: Option = value.extract()?; + let value: Option = $value.extract()?; if let Some(value) = value { with_low_rank_settings_mut!( - self, + $self, + $enum_name, "mass_matrix_eigval_cutoff", settings => { settings.adapt_options.mass_matrix_options.eigval_cutoff = value; } ); } + true } "mass_matrix_gamma" => { - let value: Option = value.extract()?; + let value: Option = $value.extract()?; if let Some(value) = value { with_low_rank_settings_mut!( - self, + $self, + $enum_name, "mass_matrix_gamma", settings => { settings.adapt_options.mass_matrix_options.gamma = value; } ); } + true } "train_on_orbit" => { - let value: bool = value.extract()?; - with_transform_settings_mut!( - self, + let value: bool = $value.extract()?; + with_flow_settings_mut!( + $self, + $enum_name, "train_on_orbit", settings => { settings.adapt_options.use_orbit_for_training = value; } ); - } - "check_turning" => { - let value: bool = value.extract()?; - set_all_settings_field!(self, check_turning = value); + true } "step_size_adapt_method" => { - let method = match value.extract::() { + let method = match $value.extract::() { Ok(method) => match method.as_str() { "dual_average" => StepSizeAdaptMethod::DualAverage, "adam" => StepSizeAdaptMethod::Adam, @@ -458,15 +366,18 @@ impl PyNutsSettings { }; set_all_settings_field!( - self, + $self, + $enum_name, adapt_options.step_size_settings.adapt_options.method = method ); + true } "step_size_adam_learning_rate" => { - let value: Option = value.extract()?; + let value: Option = $value.extract()?; if let Some(value) = value { set_all_settings_field!( - self, + $self, + $enum_name, adapt_options .step_size_settings .adapt_options @@ -474,9 +385,10 @@ impl PyNutsSettings { .learning_rate = value ); } + true } "step_size_jitter" => { - let mut value: Option = value.extract()?; + let mut value: Option = $value.extract()?; if let Some(jitter) = value { if jitter < 0.0 { return Err(PyValueError::new_err("step_size_jitter must be positive")); @@ -485,29 +397,300 @@ impl PyNutsSettings { value = None; } } - set_all_settings_field!(self, adapt_options.step_size_settings.jitter = value); + set_all_settings_field!( + $self, + $enum_name, + adapt_options.step_size_settings.jitter = value + ); + true + } + "store_unconstrained" => { + let value: bool = $value.extract()?; + set_all_settings_field!($self, $enum_name, store_unconstrained = value); + true + } + "store_gradient" => { + let value: bool = $value.extract()?; + set_all_settings_field!($self, $enum_name, store_gradient = value); + true + } + "num_tune" => { + let value: u64 = $value.extract()?; + set_all_settings_field!($self, $enum_name, num_tune = value); + true + } + "num_chains" => { + let value: usize = $value.extract()?; + set_all_settings_field!($self, $enum_name, num_chains = value); + true + } + "num_draws" => { + let value: u64 = $value.extract()?; + set_all_settings_field!($self, $enum_name, num_draws = value); + true + } + "store_transformed" => { + let value: bool = $value.extract()?; + set_all_settings_field!($self, $enum_name, store_transformed = value); + true + } + "store_divergences" => { + let value: bool = $value.extract()?; + set_all_settings_field!($self, $enum_name, store_divergences = value); + true + } + "max_energy_error" => { + let value: f64 = $value.extract()?; + set_all_settings_field!($self, $enum_name, max_energy_error = value); + true + } + _ => false, + } + }}; +} + +fn random_seed(seed: Option) -> u64 { + seed.unwrap_or_else(|| { + let mut rng = rng(); + rng.next_u64() + }) +} + +fn update_nuts_from_nested_dict( + inner: &mut NutsSettingsKind, + value: &Bound<'_, PyAny>, +) -> PyResult<()> { + match inner { + NutsSettingsKind::Diag(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; + } + NutsSettingsKind::LowRank(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; + } + NutsSettingsKind::Flow(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; + } + } + Ok(()) +} + +fn update_mclmc_from_nested_dict( + inner: &mut MclmcSettingsKind, + value: &Bound<'_, PyAny>, +) -> PyResult<()> { + match inner { + MclmcSettingsKind::Diag(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; + } + MclmcSettingsKind::LowRank(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; + } + MclmcSettingsKind::Flow(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; + } + } + Ok(()) +} + +fn nuts_to_nested_json(inner: &NutsSettingsKind) -> PyResult { + match inner { + NutsSettingsKind::Diag(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) + } + NutsSettingsKind::LowRank(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) + } + NutsSettingsKind::Flow(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) + } + } +} + +fn mclmc_to_nested_json(inner: &MclmcSettingsKind) -> PyResult { + match inner { + MclmcSettingsKind::Diag(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) + } + MclmcSettingsKind::LowRank(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) + } + MclmcSettingsKind::Flow(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) + } + } +} + +impl PyNutsSettings { + fn new_diag(seed: Option) -> Self { + let settings = DiagNutsSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: NutsSettingsKind::Diag(settings), + } + } + + fn new_low_rank(seed: Option) -> Self { + let settings = LowRankNutsSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: NutsSettingsKind::LowRank(settings), + } + } + + fn new_flow(seed: Option) -> Self { + let settings = FlowNutsSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: NutsSettingsKind::Flow(settings), + } + } + + fn update_from_nested_dict(&mut self, value: &Bound<'_, PyAny>) -> PyResult<()> { + update_nuts_from_nested_dict(&mut self.inner, value) + } + + fn to_nested_json(&self) -> PyResult { + nuts_to_nested_json(&self.inner) + } + + fn apply_update(&mut self, name: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { + match name { + "maxdepth" => { + let value: u64 = value.extract()?; + set_all_settings_field!(self, NutsSettingsKind, maxdepth = value); + } + "mindepth" => { + let value: u64 = value.extract()?; + set_all_settings_field!(self, NutsSettingsKind, mindepth = value); + } + "check_turning" => { + let value: bool = value.extract()?; + set_all_settings_field!(self, NutsSettingsKind, check_turning = value); + } + "target_integration_time" => { + let value: Option = value.extract()?; + set_all_settings_field!(self, NutsSettingsKind, target_integration_time = value); + } + "extra_doublings" => { + let value: u64 = value.extract()?; + set_all_settings_field!(self, NutsSettingsKind, extra_doublings = value); } _ => { - return Err(PyAttributeError::new_err(format!( - "Unknown settings attribute: {name}", - ))); + if try_shared_euclidean_adapt_update!(self, NutsSettingsKind, name, value) { + // handled above + } else { + match name { + "microcanonical_trajectory" => { + let value: bool = value.extract()?; + if value { + set_all_settings_field!( + self, + NutsSettingsKind, + trajectory_kind = KineticEnergyKind::Microcanonical + ); + } + } + "exact_normal_trajectory" => { + let value: bool = value.extract()?; + if value { + set_all_settings_field!( + self, + NutsSettingsKind, + trajectory_kind = KineticEnergyKind::ExactNormal + ); + } + } + _ => { + return Err(PyAttributeError::new_err(format!( + "Unknown settings attribute: {name}", + ))); + } + } + } } } Ok(()) } +} + +impl PyMclmcSettings { + fn new_diag(seed: Option) -> Self { + let settings = DiagMclmcSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: MclmcSettingsKind::Diag(settings), + } + } + + fn new_low_rank(seed: Option) -> Self { + let settings = LowRankMclmcSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: MclmcSettingsKind::LowRank(settings), + } + } + + fn new_flow(seed: Option) -> Self { + let settings = FlowMclmcSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: MclmcSettingsKind::Flow(settings), + } + } + + fn update_from_nested_dict(&mut self, value: &Bound<'_, PyAny>) -> PyResult<()> { + update_mclmc_from_nested_dict(&mut self.inner, value) + } fn to_nested_json(&self) -> PyResult { - match &self.inner { - Settings::Diag(settings) => { - serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) + mclmc_to_nested_json(&self.inner) + } + + fn apply_update(&mut self, name: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { + match name { + "step_size" => { + let value: f64 = value.extract()?; + set_all_settings_field!(self, MclmcSettingsKind, step_size = value); + } + "momentum_decoherence_length" => { + let value: f64 = value.extract()?; + set_all_settings_field!( + self, + MclmcSettingsKind, + momentum_decoherence_length = value + ); } - Settings::LowRank(settings) => { - serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) + "subsample_frequency" => { + let value: f64 = value.extract()?; + set_all_settings_field!(self, MclmcSettingsKind, subsample_frequency = value); } - Settings::Transforming(settings) => { - serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) + "dynamic_step_size" => { + let value: bool = value.extract()?; + set_all_settings_field!(self, MclmcSettingsKind, dynamic_step_size = value); + } + _ => { + if try_shared_euclidean_adapt_update!(self, MclmcSettingsKind, name, value) { + // handled above + } else { + return Err(PyAttributeError::new_err(format!( + "Unknown settings attribute: {name}", + ))); + } } } + Ok(()) } } @@ -531,8 +714,8 @@ impl PyNutsSettings { #[staticmethod] #[allow(non_snake_case)] #[pyo3(signature = (seed=None))] - fn Transform(seed: Option) -> Self { - PyNutsSettings::new_tranform_adapt(seed) + fn Flow(seed: Option) -> Self { + PyNutsSettings::new_flow(seed) } fn update(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> { @@ -552,7 +735,73 @@ impl PyNutsSettings { } fn as_dict(&self, py: Python<'_>) -> PyResult> { - let value = self.to_nested_json()?; + let settings = self.to_nested_json()?; + let adaptation = match self.inner { + NutsSettingsKind::Diag(_) => "diag", + NutsSettingsKind::LowRank(_) => "low_rank", + NutsSettingsKind::Flow(_) => "flow", + }; + let value = serde_json::json!({ + "sampler": "nuts", + "adaptation": adaptation, + "settings": settings, + }); + let obj = pythonize(py, &value).map_err(|err| PyValueError::new_err(err.to_string()))?; + Ok(obj.unbind()) + } +} + +#[pymethods] +impl PyMclmcSettings { + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn Diag(seed: Option) -> Self { + PyMclmcSettings::new_diag(seed) + } + + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn LowRank(seed: Option) -> Self { + PyMclmcSettings::new_low_rank(seed) + } + + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn Flow(seed: Option) -> Self { + PyMclmcSettings::new_flow(seed) + } + + fn update(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> { + for (key, value) in kwargs.iter() { + let key: String = key.extract()?; + self.apply_update(&key, &value)?; + } + Ok(()) + } + + fn __setattr__(&mut self, name: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { + self.apply_update(name, value) + } + + fn update_settings(&mut self, settings: &Bound<'_, PyDict>) -> PyResult<()> { + self.update_from_nested_dict(settings.as_any()) + } + + fn as_dict(&self, py: Python<'_>) -> PyResult> { + let settings = self.to_nested_json()?; + let adaptation = match self.inner { + MclmcSettingsKind::Diag(_) => "diag", + MclmcSettingsKind::LowRank(_) => "low_rank", + MclmcSettingsKind::Flow(_) => "flow", + }; + let value = serde_json::json!({ + "sampler": "mclmc", + "adaptation": adaptation, + "settings": settings, + }); let obj = pythonize(py, &value).map_err(|err| PyValueError::new_err(err.to_string()))?; Ok(obj.unbind()) } @@ -688,7 +937,7 @@ struct PySampler(Mutex<(SamplerState, Runtime)>); impl PySampler { fn new( - settings: PyNutsSettings, + settings: PySamplerSettings, cores: usize, model: M, progress_type: ProgressType, @@ -703,31 +952,59 @@ impl PySampler { match &mut store.0 { InnerPyStorage::Arrow => { let storage_config = ArrowConfig::new(); - match settings.inner { - Settings::LowRank(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningArrow(sampler).into(), - tokio_rt, - )))) - } - Settings::Diag(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningArrow(sampler).into(), - tokio_rt, - )))) - } - Settings::Transforming(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningArrow(sampler).into(), - tokio_rt, - )))) - } + match settings { + PySamplerSettings::Nuts(settings) => match settings.inner { + NutsSettingsKind::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + NutsSettingsKind::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + NutsSettingsKind::Flow(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + }, + PySamplerSettings::Mclmc(settings) => match settings.inner { + MclmcSettingsKind::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + MclmcSettingsKind::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + MclmcSettingsKind::Flow(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + }, } } InnerPyStorage::Zarr(store) => { @@ -741,31 +1018,59 @@ impl PySampler { let store = Arc::new(store); let storage_config = ZarrAsyncConfig::new(tokio_rt.handle().clone(), store); let storage_config = storage_config.with_chunk_size(16); - match settings.inner { - Settings::LowRank(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningZarr(sampler).into(), - tokio_rt, - )))) - } - Settings::Diag(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningZarr(sampler).into(), - tokio_rt, - )))) - } - Settings::Transforming(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningZarr(sampler).into(), - tokio_rt, - )))) - } + match settings { + PySamplerSettings::Nuts(settings) => match settings.inner { + NutsSettingsKind::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + NutsSettingsKind::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + NutsSettingsKind::Flow(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + }, + PySamplerSettings::Mclmc(settings) => match settings.inner { + MclmcSettingsKind::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + MclmcSettingsKind::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + MclmcSettingsKind::Flow(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + }, } } } @@ -865,7 +1170,7 @@ impl PySampler { impl PySampler { #[staticmethod] fn from_pymc( - settings: PyNutsSettings, + settings: PySamplerSettings, cores: usize, model: PyMcModel, progress_type: ProgressType, @@ -886,7 +1191,7 @@ impl PySampler { #[staticmethod] fn from_stan( - settings: PyNutsSettings, + settings: PySamplerSettings, cores: usize, model: StanModel, progress_type: ProgressType, @@ -907,7 +1212,7 @@ impl PySampler { #[staticmethod] fn from_pyfunc( - settings: PyNutsSettings, + settings: PySamplerSettings, cores: usize, model: PyModel, progress_type: ProgressType, @@ -1421,6 +1726,7 @@ pub fn _lib(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; From 1adf237043bd5416a9d0ba6578c69eccc71c1271 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 9 Mar 2026 16:20:29 +0100 Subject: [PATCH 08/21] feat: expose trajectory option --- src/wrapper.rs | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/wrapper.rs b/src/wrapper.rs index cabebb6..4dd577f 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -17,8 +17,9 @@ use anyhow::{anyhow, bail, Context, Result}; use numpy::{PyArray1, PyReadonlyArray1}; use nuts_rs::{ ArrowConfig, ArrowTrace, ChainProgress, DiagMclmcSettings, DiagNutsSettings, FlowMclmcSettings, - FlowNutsSettings, KineticEnergyKind, LowRankMclmcSettings, LowRankNutsSettings, Model, - ProgressCallback, Sampler, SamplerWaitResult, StepSizeAdaptMethod, ZarrAsyncConfig, + FlowNutsSettings, KineticEnergyKind, LowRankMclmcSettings, LowRankNutsSettings, + MclmcTrajectoryKind, Model, ProgressCallback, Sampler, SamplerWaitResult, StepSizeAdaptMethod, + ZarrAsyncConfig, }; use pyo3::{ exceptions::{PyAttributeError, PyTimeoutError, PyValueError}, @@ -680,6 +681,23 @@ impl PyMclmcSettings { let value: bool = value.extract()?; set_all_settings_field!(self, MclmcSettingsKind, dynamic_step_size = value); } + "trajectory" => { + let value: String = value.extract()?; + let value = match value.as_str() { + "microcanonical" => MclmcTrajectoryKind::Microcanonical, + "euclidean" => MclmcTrajectoryKind::Euclidean, + "euclidean_then_microcanonical" => { + MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical + } + _ => { + return Err(PyValueError::new_err(format!( + "Unknown trajectory: {}", + value + ))) + } + }; + set_all_settings_field!(self, MclmcSettingsKind, trajectory_kind = value); + } _ => { if try_shared_euclidean_adapt_update!(self, MclmcSettingsKind, name, value) { // handled above From 874ea0a927e95caae3caa04f019f46af8e2baadb Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 20 Apr 2026 10:20:24 +0200 Subject: [PATCH 09/21] chore: update dependencies --- Cargo.lock | 518 +++++++++++++++++++++++++++++------------------------ Cargo.toml | 8 +- 2 files changed, 285 insertions(+), 241 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 83b9544..3a8711f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -74,9 +74,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anyhow" @@ -86,9 +86,9 @@ checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "arrow" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4754a624e5ae42081f464514be454b39711daae0458906dacde5f4c632f33a8" +checksum = "d441fdda254b65f3e9025910eb2c2066b6295d9c8ed409522b8d2ace1ff8574c" dependencies = [ "arrow-arith", "arrow-array", @@ -107,9 +107,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7b3141e0ec5145a22d8694ea8b6d6f69305971c4fa1c1a13ef0195aef2d678b" +checksum = "ced5406f8b720cc0bc3aa9cf5758f93e8593cda5490677aa194e4b4b383f9a59" dependencies = [ "arrow-array", "arrow-buffer", @@ -121,9 +121,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8955af33b25f3b175ee10af580577280b4bd01f7e823d94c7cdef7cf8c9aef" +checksum = "772bd34cacdda8baec9418d80d23d0fb4d50ef0735685bd45158b83dfeb6e62d" dependencies = [ "ahash", "arrow-buffer", @@ -140,9 +140,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c697ddca96183182f35b3a18e50b9110b11e916d7b7799cbfd4d34662f2c56c2" +checksum = "898f4cf1e9598fdb77f356fdf2134feedfd0ee8d5a4e0a5f573e7d0aec16baa4" dependencies = [ "bytes", "half", @@ -152,9 +152,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "646bbb821e86fd57189c10b4fcdaa941deaf4181924917b0daa92735baa6ada5" +checksum = "b0127816c96533d20fc938729f48c52d3e48f99717e7a0b5ade77d742510736d" dependencies = [ "arrow-array", "arrow-buffer", @@ -174,9 +174,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8da746f4180004e3ce7b83c977daf6394d768332349d3d913998b10a120b790a" +checksum = "ca025bd0f38eeecb57c2153c0123b960494138e6a957bbda10da2b25415209fe" dependencies = [ "arrow-array", "arrow-cast", @@ -189,9 +189,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fdd994a9d28e6365aa78e15da3f3950c0fdcea6b963a12fa1c391afb637b304" +checksum = "42d10beeab2b1c3bb0b53a00f7c944a178b622173a5c7bcabc3cb45d90238df4" dependencies = [ "arrow-buffer", "arrow-schema", @@ -202,9 +202,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abf7df950701ab528bf7c0cf7eeadc0445d03ef5d6ffc151eaae6b38a58feff1" +checksum = "609a441080e338147a84e8e6904b6da482cefb957c5cdc0f3398872f69a315d0" dependencies = [ "arrow-array", "arrow-buffer", @@ -216,9 +216,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ff8357658bedc49792b13e2e862b80df908171275f8e6e075c460da5ee4bf86" +checksum = "6ead0914e4861a531be48fe05858265cf854a4880b9ed12618b1d08cba9bebc8" dependencies = [ "arrow-array", "arrow-buffer", @@ -240,9 +240,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7d8f1870e03d4cbed632959498bcc84083b5a24bded52905ae1695bd29da45b" +checksum = "763a7ba279b20b52dad300e68cfc37c17efa65e68623169076855b3a9e941ca5" dependencies = [ "arrow-array", "arrow-buffer", @@ -253,9 +253,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18228633bad92bff92a95746bbeb16e5fc318e8382b75619dec26db79e4de4c0" +checksum = "e14fe367802f16d7668163ff647830258e6e0aeea9a4d79aaedf273af3bdcd3e" dependencies = [ "arrow-array", "arrow-buffer", @@ -266,9 +266,9 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c872d36b7bf2a6a6a2b40de9156265f0242910791db366a2c17476ba8330d68" +checksum = "c30a1365d7a7dc50cc847e54154e6af49e4c4b0fddc9f607b687f29212082743" dependencies = [ "bitflags", "serde_core", @@ -277,9 +277,9 @@ dependencies = [ [[package]] name = "arrow-select" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68bf3e3efbd1278f770d67e5dc410257300b161b93baedb3aae836144edcaf4b" +checksum = "78694888660a9e8ac949853db393af2a8b8fc82c19ce333132dfa2e72cc1a7fe" dependencies = [ "ahash", "arrow-array", @@ -291,9 +291,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "57.3.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85e968097061b3c0e9fe3079cf2e703e487890700546b5b0647f60fca1b5a8d8" +checksum = "61e04a01f8bb73ce54437514c5fd3ee2aa3e8abe4c777ee5cc55853b1652f79e" dependencies = [ "arrow-array", "arrow-buffer", @@ -405,9 +405,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" [[package]] name = "block-buffer" @@ -431,6 +431,19 @@ dependencies = [ "zstd-sys", ] +[[package]] +name = "blusc" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4e0c17eaa785d2673fe58c22fc817946c2330ed47f3d9f79835d65950d32a45" +dependencies = [ + "flate2", + "lz4_flex", + "pkg-config", + "snap", + "zstd 0.13.3", +] + [[package]] name = "bridgestan" version = "2.7.0" @@ -510,9 +523,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.56" +version = "1.2.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +checksum = "43c5703da9466b66a946814e1adf53ea2c90f10063b86290cc9eb67ce3478a20" dependencies = [ "find-msvc-tools", "jobserver", @@ -549,7 +562,7 @@ checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" dependencies = [ "cfg-if", "cpufeatures 0.3.0", - "rand_core 0.10.0", + "rand_core 0.10.1", ] [[package]] @@ -626,18 +639,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.60" +version = "4.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +checksum = "1ddb117e43bbf7dacf0a4190fef4d345b9bad68dfc649cb349e7d17d28428e51" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ "anstyle", "clap_lex", @@ -645,9 +658,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "comfy-table" @@ -670,13 +683,12 @@ dependencies = [ [[package]] name = "console" -version = "0.16.2" +version = "0.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03e45a4a8926227e4197636ba97a9fc9b00477e9f4bd711395687c5f0734bec4" +checksum = "d64e8af5551369d19cf50138de61f1c42074ab970f74e99be916646777f8fc87" dependencies = [ "encode_unicode", "libc", - "once_cell", "unicode-width", "windows-sys 0.61.2", ] @@ -1345,7 +1357,7 @@ dependencies = [ "cfg-if", "libc", "r-efi 6.0.0", - "rand_core 0.10.0", + "rand_core 0.10.1", "wasip2", "wasip3", ] @@ -1408,6 +1420,12 @@ dependencies = [ "foldhash 0.2.0", ] +[[package]] +name = "hashbrown" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" + [[package]] name = "heck" version = "0.5.0" @@ -1470,9 +1488,9 @@ checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" [[package]] name = "hyper" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" dependencies = [ "atomic-waker", "bytes", @@ -1484,7 +1502,6 @@ dependencies = [ "httparse", "itoa", "pin-project-lite", - "pin-utils", "smallvec", "tokio", "want", @@ -1492,16 +1509,15 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.7" +version = "0.27.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" dependencies = [ "http", "hyper", "hyper-util", "rustls", "rustls-native-certs", - "rustls-pki-types", "tokio", "tokio-rustls", "tower-service", @@ -1556,12 +1572,13 @@ dependencies = [ [[package]] name = "icu_collections" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" dependencies = [ "displaydoc", "potential_utf", + "utf8_iter", "yoke", "zerofrom", "zerovec", @@ -1569,9 +1586,9 @@ dependencies = [ [[package]] name = "icu_locale_core" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" dependencies = [ "displaydoc", "litemap", @@ -1582,9 +1599,9 @@ dependencies = [ [[package]] name = "icu_normalizer" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" dependencies = [ "icu_collections", "icu_normalizer_data", @@ -1596,15 +1613,15 @@ dependencies = [ [[package]] name = "icu_normalizer_data" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" [[package]] name = "icu_properties" -version = "2.1.2" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" +checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" dependencies = [ "icu_collections", "icu_locale_core", @@ -1616,15 +1633,15 @@ dependencies = [ [[package]] name = "icu_properties_data" -version = "2.1.2" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" [[package]] name = "icu_provider" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" dependencies = [ "displaydoc", "icu_locale_core", @@ -1664,12 +1681,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.13.0" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown 0.16.1", + "hashbrown 0.17.0", "serde", "serde_core", ] @@ -1698,9 +1715,9 @@ dependencies = [ [[package]] name = "inventory" -version = "0.3.22" +version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "009ae045c87e7082cb72dab0ccd01ae075dd00141ddc108f43a0ea150a9e7227" +checksum = "a4f0c30c76f2f4ccee3fe55a2435f691ca00c0e4bd87abe4f4a851b1d4dac39b" dependencies = [ "rustversion", ] @@ -1713,9 +1730,9 @@ checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] name = "iri-string" -version = "0.7.10" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" dependencies = [ "memchr", "serde", @@ -1741,9 +1758,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jobserver" @@ -1757,10 +1774,12 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.91" +version = "0.3.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" +checksum = "2964e92d1d9dc3364cae4d718d93f227e3abb088e747d92e0395bfdedf1c12ca" dependencies = [ + "cfg-if", + "futures-util", "once_cell", "wasm-bindgen", ] @@ -1836,9 +1855,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.182" +version = "0.2.185" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" +checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f" [[package]] name = "libloading" @@ -1858,9 +1877,9 @@ checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libz-sys" -version = "1.1.24" +version = "1.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4735e9cbde5aac84a5ce588f6b23a90b9b0b528f6c5a8db8a4aff300463a0839" +checksum = "fc3a226e576f50782b3305c5ccf458698f92798987f551c6a02efe8276721e22" dependencies = [ "cc", "libc", @@ -1879,9 +1898,9 @@ dependencies = [ [[package]] name = "litemap" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" [[package]] name = "lock_api" @@ -1900,9 +1919,9 @@ checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "lru" -version = "0.16.3" +version = "0.16.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1dc47f592c06f33f8e3aea9591776ec7c9f9e4124778ff8a3c3b87159f7e593" +checksum = "7f66e8d5d03f609abc3a39e6f08e4164ebf1447a732906d39eb9b99b7919ef39" dependencies = [ "hashbrown 0.16.1", ] @@ -1923,6 +1942,15 @@ dependencies = [ "libc", ] +[[package]] +name = "lz4_flex" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98c23545df7ecf1b16c303910a69b079e8e251d60f7dd2cc9b4177f2afaf1746" +dependencies = [ + "twox-hash", +] + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -1967,9 +1995,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.1.1" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" dependencies = [ "libc", "wasi", @@ -1978,9 +2006,9 @@ dependencies = [ [[package]] name = "moka" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b" +checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046" dependencies = [ "async-lock", "crossbeam-channel", @@ -2164,9 +2192,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" [[package]] name = "num-integer" @@ -2241,10 +2269,13 @@ dependencies = [ "pyo3", "pyo3-arrow", "pyo3-object_store", - "rand 0.10.0", + "pythonize", + "rand 0.10.1", "rand_chacha 0.10.0", "rand_distr", "rayon", + "serde", + "serde_json", "smallvec", "tch", "thiserror 2.0.18", @@ -2257,11 +2288,11 @@ dependencies = [ [[package]] name = "nuts-derive" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64eac5046c75ced9bdaede15ebc30c4ce982a13e75032ae8d5c1312d1e05d82e" +checksum = "a8cce587b5f36bc6bfa54cbf2eaf31fe5a5e0d73e96d31fde5de87a701689363" dependencies = [ - "nuts-storable 0.1.0", + "nuts-storable", "proc-macro2", "quote", "syn 1.0.109", @@ -2269,9 +2300,9 @@ dependencies = [ [[package]] name = "nuts-rs" -version = "0.17.4" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bd441b25a4ff7c222e396d073f36ec03ab3be4264a59a81264bd4501616df31" +checksum = "ed964f329c9de6147920b5bbc55cdeb8ad69c7909f68539b592889dfc07b04de" dependencies = [ "anyhow", "arrow", @@ -2279,9 +2310,9 @@ dependencies = [ "faer", "itertools 0.14.0", "nuts-derive", - "nuts-storable 0.2.0", + "nuts-storable", "pulp", - "rand 0.10.0", + "rand 0.10.1", "rand_distr", "rayon", "serde", @@ -2293,28 +2324,22 @@ dependencies = [ [[package]] name = "nuts-storable" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb6cf9fc84ca313648ddb112f8728eb2f9531f2e4533959dd01127eb34290b5b" - -[[package]] -name = "nuts-storable" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcdee8fb53ac39f042885e732348a5bce2f28a08adc2798314a3c66a5f199293" +version = "0.3.0" [[package]] name = "object_store" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2858065e55c148d294a9f3aae3b0fa9458edadb41a108397094566f4e3c0dfb" +checksum = "622acbc9100d3c10e2ee15804b0caa40e55c933d5aa53814cd520805b7958a49" dependencies = [ "async-trait", "base64", "bytes", "chrono", "form_urlencoded", - "futures", + "futures-channel", + "futures-core", + "futures-util", "http", "http-body-util", "httparse", @@ -2325,7 +2350,7 @@ dependencies = [ "parking_lot", "percent-encoding", "quick-xml", - "rand 0.9.2", + "rand 0.10.1", "reqwest", "ring", "rustls-pki-types", @@ -2343,9 +2368,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.3" +version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "oorandom" @@ -2481,17 +2506,11 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - [[package]] name = "pkg-config" -version = "0.3.32" +version = "0.3.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" [[package]] name = "plotters" @@ -2529,9 +2548,9 @@ checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" [[package]] name = "portable-atomic-util" -version = "0.2.5" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a9db96d7fa8782dd8c15ce32ffe8680bbd1e978a43bf51a34d39483540495f5" +checksum = "c2a106d1259c23fac8e543272398ae0e3c0b8d33c88ed73d0cc71b0f1d902618" dependencies = [ "portable-atomic", ] @@ -2549,9 +2568,9 @@ dependencies = [ [[package]] name = "potential_utf" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" dependencies = [ "zerovec", ] @@ -2615,9 +2634,9 @@ checksum = "40e24eee682d89fb193496edf918a7f407d30175b2e785fe057e4392dfd182e0" [[package]] name = "pyo3" -version = "0.28.2" +version = "0.28.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf85e27e86080aafd5a22eae58a162e133a589551542b3e5cee4beb27e54f8e1" +checksum = "91fd8e38a3b50ed1167fb981cd6fd60147e091784c427b8f7183a7ee32c31c12" dependencies = [ "anyhow", "chrono", @@ -2633,9 +2652,9 @@ dependencies = [ [[package]] name = "pyo3-arrow" -version = "0.16.1" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a08a376a6bbcb28f19122b901ca9ecb9fb7d8d677e886e0e13cd0101d99be4" +checksum = "0360400036dda3db3d69102ef7e9646e4cd946c75a2d1d41fb8fd39879312636" dependencies = [ "arrow-array", "arrow-buffer", @@ -2668,18 +2687,18 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.28.2" +version = "0.28.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bf94ee265674bf76c09fa430b0e99c26e319c945d96ca0d5a8215f31bf81cf7" +checksum = "e368e7ddfdeb98c9bca7f8383be1648fd84ab466bf2bc015e94008db6d35611e" dependencies = [ "target-lexicon", ] [[package]] name = "pyo3-ffi" -version = "0.28.2" +version = "0.28.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "491aa5fc66d8059dd44a75f4580a2962c1862a1c2945359db36f6c2818b748dc" +checksum = "7f29e10af80b1f7ccaf7f69eace800a03ecd13e883acfacc1e5d0988605f651e" dependencies = [ "libc", "pyo3-build-config", @@ -2687,9 +2706,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.28.2" +version = "0.28.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5d671734e9d7a43449f8480f8b38115df67bef8d21f76837fa75ee7aaa5e52e" +checksum = "df6e520eff47c45997d2fc7dd8214b25dd1310918bbb2642156ef66a67f29813" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -2699,9 +2718,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.28.2" +version = "0.28.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22faaa1ce6c430a1f71658760497291065e6450d7b5dc2bcf254d49f66ee700a" +checksum = "c4cdc218d835738f81c2338f822078af45b4afdf8b2e33cbb5916f108b813acb" dependencies = [ "heck", "proc-macro2", @@ -2733,6 +2752,16 @@ dependencies = [ "url", ] +[[package]] +name = "pythonize" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b79f670c9626c8b651c0581011b57b6ba6970bb69faf01a7c4c0cfc81c43f95" +dependencies = [ + "pyo3", + "serde", +] + [[package]] name = "qd" version = "0.8.0" @@ -2747,9 +2776,9 @@ dependencies = [ [[package]] name = "quick-xml" -version = "0.38.4" +version = "0.39.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b66c2058c55a409d601666cffe35f04333cf1013010882cec174a7467cd4e21c" +checksum = "958f21e8e7ceb5a1aa7fa87fab28e7c75976e0bfe7e23ff069e0a260f894067d" dependencies = [ "memchr", "serde", @@ -2757,9 +2786,9 @@ dependencies = [ [[package]] name = "quick_cache" -version = "0.6.18" +version = "0.6.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ada44a88ef953a3294f6eb55d2007ba44646015e18613d2f213016379203ef3" +checksum = "5a70b1b8b47e31d0498ecbc3c5470bb931399a8bfed1fd79d1717a61ce7f96e3" dependencies = [ "ahash", "equivalent", @@ -2796,7 +2825,7 @@ dependencies = [ "bytes", "getrandom 0.3.4", "lru-slab", - "rand 0.9.2", + "rand 0.9.4", "ring", "rustc-hash", "rustls", @@ -2824,9 +2853,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.44" +version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] @@ -2845,9 +2874,9 @@ checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" [[package]] name = "rand" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" dependencies = [ "libc", "rand_chacha 0.3.1", @@ -2856,9 +2885,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.2" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.5", @@ -2866,13 +2895,13 @@ dependencies = [ [[package]] name = "rand" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" +checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" dependencies = [ "chacha20", "getrandom 0.4.2", - "rand_core 0.10.0", + "rand_core 0.10.1", ] [[package]] @@ -2902,7 +2931,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e6af7f3e25ded52c41df4e0b1af2d047e45896c2f3281792ed68a1c243daedb" dependencies = [ "ppv-lite86", - "rand_core 0.10.0", + "rand_core 0.10.1", ] [[package]] @@ -2925,9 +2954,9 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" +checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" [[package]] name = "rand_distr" @@ -2936,7 +2965,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d431c2703ccf129de4d45253c03f49ebb22b97d6ad79ee3ecfc7e3f4862c1d8" dependencies = [ "num-traits", - "rand 0.10.0", + "rand 0.10.1", ] [[package]] @@ -2956,9 +2985,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" dependencies = [ "either", "rayon-core", @@ -3085,9 +3114,9 @@ dependencies = [ [[package]] name = "rustc-hash" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" [[package]] name = "rustc_version" @@ -3100,9 +3129,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.37" +version = "0.23.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +checksum = "69f9466fb2c14ea04357e91413efb882e2a6d4a406e625449bc0a5d360d53a21" dependencies = [ "once_cell", "ring", @@ -3178,9 +3207,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.28" +version = "0.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" dependencies = [ "windows-sys 0.61.2", ] @@ -3216,9 +3245,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" [[package]] name = "seq-macro" @@ -3323,9 +3352,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "simd-adler32" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" [[package]] name = "simdutf8" @@ -3351,6 +3380,12 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + [[package]] name = "snappy_src" version = "0.2.5+snappy.1.2.2" @@ -3363,12 +3398,12 @@ dependencies = [ [[package]] name = "socket2" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -3439,15 +3474,15 @@ checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca" [[package]] name = "tch" -version = "0.23.0" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1470a4780300b4a62a097ed50dff2fc7c9cf6576c1947560a1d69fab3b258b" +checksum = "0d3f84a069d8ba16dbf720b61e8bf131d90ffb8e958a664eae8e4993c5c2fa6f" dependencies = [ "half", "lazy_static", "libc", "ndarray 0.16.1", - "rand 0.8.5", + "rand 0.8.6", "safetensors", "thiserror 1.0.69", "torch-sys", @@ -3539,9 +3574,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" dependencies = [ "displaydoc", "zerovec", @@ -3559,9 +3594,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" dependencies = [ "tinyvec_macros", ] @@ -3574,9 +3609,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.50.0" +version = "1.52.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" dependencies = [ "bytes", "libc", @@ -3589,9 +3624,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.6.1" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" dependencies = [ "proc-macro2", "quote", @@ -3623,9 +3658,9 @@ dependencies = [ [[package]] name = "torch-sys" -version = "0.23.0" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7767541e63ff99c36950f2a54cc277e41a1e5a24ab1646aae19ef7f7d3dedf7" +checksum = "f4ba78777379cf09aaa79708c63e477cf0f95e021d04360c6821f1a9f56173f7" dependencies = [ "anyhow", "cc", @@ -3715,11 +3750,17 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "twox-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" + [[package]] name = "typenum" -version = "1.19.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" [[package]] name = "unicode-ident" @@ -3729,9 +3770,9 @@ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-segmentation" -version = "1.12.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" [[package]] name = "unicode-width" @@ -3789,9 +3830,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.21.0" +version = "1.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b672338555252d43fd2240c714dc444b8c6fb0a5c5335e65a07bba7742735ddb" +checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" dependencies = [ "getrandom 0.4.2", "js-sys", @@ -3837,11 +3878,11 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasip2" -version = "1.0.2+wasi-0.2.9" +version = "1.0.3+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.57.1", ] [[package]] @@ -3850,14 +3891,14 @@ version = "0.4.0+wasi-0.3.0-rc-2026-01-06" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.51.0", ] [[package]] name = "wasm-bindgen" -version = "0.2.114" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" +checksum = "0bf938a0bacb0469e83c1e148908bd7d5a6010354cf4fb73279b7447422e3a89" dependencies = [ "cfg-if", "once_cell", @@ -3868,23 +3909,19 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.64" +version = "0.4.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9c5522b3a28661442748e09d40924dfb9ca614b21c00d3fd135720e48b67db8" +checksum = "f371d383f2fb139252e0bfac3b81b265689bf45b6874af544ffa4c975ac1ebf8" dependencies = [ - "cfg-if", - "futures-util", "js-sys", - "once_cell", "wasm-bindgen", - "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.114" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" +checksum = "eeff24f84126c0ec2db7a449f0c2ec963c6a49efe0698c4242929da037ca28ed" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3892,9 +3929,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.114" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" +checksum = "9d08065faf983b2b80a79fd87d8254c409281cf7de75fc4b773019824196c904" dependencies = [ "bumpalo", "proc-macro2", @@ -3905,9 +3942,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.114" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" +checksum = "5fd04d9e306f1907bd13c6361b5c6bfc7b3b3c095ed3f8a9246390f8dbdee129" dependencies = [ "unicode-ident", ] @@ -3961,9 +3998,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.91" +version = "0.3.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" +checksum = "4f2dfbb17949fa2088e5d39408c48368947b86f7834484e87b73de55bc14d97d" dependencies = [ "js-sys", "wasm-bindgen", @@ -4234,6 +4271,12 @@ dependencies = [ "wit-bindgen-rust-macro", ] +[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + [[package]] name = "wit-bindgen-core" version = "0.51.0" @@ -4315,15 +4358,15 @@ dependencies = [ [[package]] name = "writeable" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" [[package]] name = "yoke" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" dependencies = [ "stable_deref_trait", "yoke-derive", @@ -4332,9 +4375,9 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" dependencies = [ "proc-macro2", "quote", @@ -4344,15 +4387,16 @@ dependencies = [ [[package]] name = "zarrs" -version = "0.23.5" +version = "0.23.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ede406d3af180ae03d5707b9430376a979048fd5a178f3bd86237ce20562fe0e" +checksum = "22a4e12a1dfa1eea6e3eff4d85cf758134b882fffb6d0b51bff76c0b17ac240b" dependencies = [ "async-generic", "async-lock", "async-trait", "base64", "blosc-src", + "blusc", "bytemuck", "bytes", "crc32c", @@ -4395,9 +4439,9 @@ dependencies = [ [[package]] name = "zarrs_chunk_grid" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7c4c887fa4dee415c94fe9a27c971ba564cb73c77ddfae3490465c8b44d041b" +checksum = "1cf67386fd96a0336cd3e5ab5ca6cb14e0e05aee80f1acae8c4d3cf562a8bb65" dependencies = [ "derive_more", "inventory", @@ -4424,9 +4468,9 @@ dependencies = [ [[package]] name = "zarrs_codec" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0464bb05bd143af1d316971ae1cb08d402c75195e7e81bb341f9a5b7b49bfaee" +checksum = "383a129a6a0cbb2c80cdba23809e5cab85159756464b7d0f112468a495c128da" dependencies = [ "async-trait", "bytemuck", @@ -4496,9 +4540,9 @@ dependencies = [ [[package]] name = "zarrs_metadata_ext" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d05697d0c807d7192942b13a018404b98b0bb284541a650ebfa9be4d7cd8303" +checksum = "a96819f29a4fbd489be05184e28201b7d95c3c6de01c663abb7b7d694be48c6e" dependencies = [ "derive_more", "monostate", @@ -4552,18 +4596,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.40" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.40" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" dependencies = [ "proc-macro2", "quote", @@ -4572,18 +4616,18 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" dependencies = [ "proc-macro2", "quote", @@ -4599,9 +4643,9 @@ checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" [[package]] name = "zerotrie" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" dependencies = [ "displaydoc", "yoke", @@ -4610,9 +4654,9 @@ dependencies = [ [[package]] name = "zerovec" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" dependencies = [ "yoke", "zerofrom", @@ -4621,9 +4665,9 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 70f2fdd..27639b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ name = "_lib" crate-type = ["cdylib"] [dependencies] -nuts-rs = { version = "0.17.3", features = ["zarr", "arrow"] } +nuts-rs = { version = "0.18.0", features = ["zarr", "arrow"] } numpy = "0.28.0" rand = "0.10.0" thiserror = "2.0.3" @@ -35,14 +35,14 @@ smallvec = "1.15.0" upon = { version = "0.10.0", default-features = false, features = [] } time-humanize = { version = "0.1.3", default-features = false } indicatif = "0.18.0" -tch = { version = "0.23.0", optional = true } +tch = { version = "0.24.0", optional = true } pyo3-object_store = "0.9.0" # Keep zarrs crates in sync with nuts-rs requirements zarrs = { version = "0.23.2", features = ["async"] } zarrs_object_store = "0.6.0" -tokio = { version = "1.47.1", features = ["rt", "rt-multi-thread"] } pyo3-arrow = "0.17.0" -arrow = { version = "58.0.0", features = ["json"] } +arrow = { version = "58.1.0", features = ["json"] } +tokio = { version = "1.47.1", features = ["rt", "rt-multi-thread"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" pythonize = "0.28.0" From 1d626dc1300c2e402678eb85f1e6860bfa574b6d Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 20 Apr 2026 10:54:18 +0200 Subject: [PATCH 10/21] ci: render docs during ci --- .github/workflows/ci.yml | 33 ++++++++++++++++++++++++++++++++- pyproject.toml | 13 +++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 718bc42..81841bd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -245,6 +245,37 @@ jobs: uv pip install 'nutpie[all]' --find-links dist --force-reinstall pytest -m flow --arraydiff fi + docs: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + - name: Install uv + uses: astral-sh/setup-uv@v7 + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y clang libclang-dev + - name: Install Quarto + uses: quarto-dev/quarto-actions/setup@v2 + - name: Build and install nutpie + run: | + python3 -m venv .venv + source .venv/bin/activate + uv pip install maturin + maturin build --release --out dist + - name: Install doc dependencies + run: | + source .venv/bin/activate + uv pip install "nutpie[docs]" + - name: Render docs + env: + TBB_CXX_TYPE: clang + run: | + source .venv/bin/activate + cd docs + quarto render + sdist: runs-on: ubuntu-latest steps: @@ -264,7 +295,7 @@ jobs: name: Release runs-on: ubuntu-latest if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }} - needs: [linux, windows, macos, sdist] + needs: [linux, windows, macos, sdist, docs] environment: name: pypi permissions: diff --git a/pyproject.toml b/pyproject.toml index 2f4dc5e..b603026 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,19 @@ all = [ "flowjax >= 17.1.0", "equinox >= 0.11.12", ] +docs = [ + "bridgestan >= 2.7.0", + "stanio >= 0.5.1", + "pymc >= 5.20.1", + "numba >= 0.60.0", + "jax >= 0.4.27", + "flowjax >= 17.1.0", + "equinox >= 0.11.12", + "cmdstanpy >= 1.2.0", + "matplotlib >= 3.8.0", + "seaborn >= 0.13.0", + "jupyter", +] [tool.ruff] line-length = 88 From 1f349573096edced1ffd7c9ea29904ee2a2a1c96 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 20 Apr 2026 10:54:18 +0200 Subject: [PATCH 11/21] docs: add citation instructions and link to paper --- CITATION.cff | 18 ++++++++++++++++++ docs/about.qmd | 27 +++++++++++++++++++++++++++ docs/index.qmd | 5 +++++ 3 files changed, 50 insertions(+) diff --git a/CITATION.cff b/CITATION.cff index d3ba1d8..53fda64 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -22,3 +22,21 @@ keywords: - Bayesian inference - MCMC license: MIT +preferred-citation: + type: article + title: >- + Preconditioning Hamiltonian Monte Carlo by minimizing Fisher Divergence + authors: + - given-names: Adrian + family-names: Seyboldt + - given-names: Eliot L. + family-names: Carlson + - given-names: Bob + family-names: Carpenter + year: 2026 + identifiers: + - type: other + value: 'arXiv:2603.18845' + description: arXiv preprint + url: 'https://arxiv.org/abs/2603.18845' + doi: '10.48550/arXiv.2603.18845' diff --git a/docs/about.qmd b/docs/about.qmd index 16fc02f..2bef075 100644 --- a/docs/about.qmd +++ b/docs/about.qmd @@ -15,3 +15,30 @@ For more information about the PyMC organization, visit the following links: - [PyMC Website](https://www.pymc.io) - [PyMC GitHub Organization](https://github.com/pymc-devs) + +## Paper + +The algorithms behind nutpie's mass matrix adaptation are described in the +following paper: + +> Adrian Seyboldt, Eliot L. Carlson, Bob Carpenter (2026). +> **Preconditioning Hamiltonian Monte Carlo by minimizing Fisher Divergence.** +> arXiv:2603.18845 [stat.CO]. +> [https://arxiv.org/abs/2603.18845](https://arxiv.org/abs/2603.18845) + +## Citation + +If you use nutpie in your research, please cite the following paper: + +```bibtex +@article{seyboldt2026preconditioning, + title = {Preconditioning {Hamiltonian Monte Carlo} by minimizing + {Fisher} Divergence}, + author = {Adrian Seyboldt and Eliot L. Carlson and Bob Carpenter}, + year = {2026}, + eprint = {2603.18845}, + archivePrefix = {arXiv}, + primaryClass = {stat.CO}, + url = {https://arxiv.org/abs/2603.18845} +} +``` diff --git a/docs/index.qmd b/docs/index.qmd index 6de4796..dcb6b59 100644 --- a/docs/index.qmd +++ b/docs/index.qmd @@ -14,6 +14,11 @@ likelihoods with gradient are coming soon). - *Experimental* normalizing flow adaptation for more efficient sampling of difficult posteriors. +For more details on the algorithms used in nutpie, see the paper +[Preconditioning Hamiltonian Monte Carlo by minimizing Fisher +Divergence](https://arxiv.org/abs/2603.18845). If you use nutpie in your +research, please see the [citation instructions](about.qmd#citation). + ## Quickstart: PyMC Install `nutpie` with pip, uv, pixi, or conda: From 0f4c13931a90a22ca151f702b1878daa9fe0adf2 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 20 Apr 2026 10:54:18 +0200 Subject: [PATCH 12/21] docs: add deployment workflow for docs --- .github/workflows/ci.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 81841bd..267582a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,6 +12,8 @@ on: permissions: contents: read + pages: write + id-token: write jobs: linux: @@ -275,6 +277,27 @@ jobs: source .venv/bin/activate cd docs quarto render + - name: Upload docs artifact + if: ${{ github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/') }} + uses: actions/upload-pages-artifact@v3 + with: + path: docs/_site + + deploy-docs: + name: Deploy Docs + runs-on: ubuntu-latest + if: ${{ github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/') }} + needs: [docs] + permissions: + pages: write + id-token: write + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 sdist: runs-on: ubuntu-latest From 1e3021dd8046df96f69a33d480ecd3c4afe9808f Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 20 Apr 2026 11:27:35 +0200 Subject: [PATCH 13/21] docs: add note that mclmc is experimental --- python/nutpie/sample.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index b353df9..1ba0f12 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -804,6 +804,8 @@ def sample( - ``"nuts"`` (default): No-U-Turn Sampler. - ``"mclmc"``: Microcanonical Langevin Monte Carlo. + mclmc is **experimental** and might change or disapear + in a future release. It might also eat your homework. adaptation: str, default="diag" The mass matrix adaptation strategy to use. One of: From 4e6019cd266c3769c9e31608c8429175f7a5d4ad Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 20 Apr 2026 12:15:52 +0200 Subject: [PATCH 14/21] ci: update reference draws due to change in window lengths --- .../test_deterministic_sampling_jax.txt | 400 +++++++++--------- .../test_deterministic_sampling_numba.txt | 400 +++++++++--------- .../test_deterministic_sampling_stan.txt | 4 +- 3 files changed, 402 insertions(+), 402 deletions(-) diff --git a/tests/reference/test_deterministic_sampling_jax.txt b/tests/reference/test_deterministic_sampling_jax.txt index 0c6237a..8c4c60c 100644 --- a/tests/reference/test_deterministic_sampling_jax.txt +++ b/tests/reference/test_deterministic_sampling_jax.txt @@ -1,200 +1,200 @@ -0.00293185 -0.00896809 -0.00768812 -0.00190588 -0.00320792 -0.00740496 -0.0958083 -0.109615 -0.0105327 -0.0192266 -0.0214682 -0.0218331 -0.0585783 -0.0251717 -0.0280682 -0.0729859 -0.425133 -0.457443 -2.52983 -1.15484 -1.15484 -1.18416 -0.104273 -1.20247 -1.1064 -1.67005 -1.05586 -2.55089 -1.83339 -0.971751 -0.470398 -0.284519 -0.253759 -2.29193 -1.29672 -1.29672 -0.432495 -0.411462 -1.10822 -1.10822 -0.698466 -1.01384 -0.422528 -0.471828 -0.354965 -0.370006 -0.932942 -0.924415 -0.821473 -2.34528 -1.8362 -0.329965 -0.427145 -0.995745 -1.17653 -0.937676 -0.937676 -0.71568 -0.916428 -1.05491 -0.479239 -0.488732 -1.07755 -1.05904 -0.269731 -0.197423 -0.303258 -0.0738098 -0.0535444 -0.0704248 -0.083286 -0.158385 -0.149845 -0.416708 -0.349628 -0.31117 -0.304837 -0.0724371 -1.5569 -1.20564 -2.12525 -0.303531 -0.712031 -0.844468 -0.434198 -0.277141 -0.593882 -0.648409 -1.02533 -0.692478 -0.367875 -0.316403 -0.351662 -0.117319 -1.85435 -0.413934 -0.409025 -0.661536 -0.650092 -0.766712 -0.594595 -0.501872 -0.515377 -0.236945 -0.689338 -2.99054 -0.172018 -0.0528735 -0.0579658 -0.0581689 -0.0497977 -0.063146 -0.311101 -0.347411 -0.763051 -0.734721 -1.17926 -1.02504 -1.02504 -0.645771 -0.970169 -1.20163 -1.1179 -0.385697 -0.410691 -0.471671 -0.540587 -0.250604 -0.254267 -0.220907 -0.673968 -0.265055 -0.766607 -1.50436 -1.58131 -0.719291 -0.958127 -0.546963 -1.60432 -1.60432 -1.45897 -0.717682 -0.668208 -0.71339 -0.276479 -0.255967 -0.799242 -1.32658 -0.724295 -0.36085 -0.217894 -0.254816 -0.125993 -1.31909 -1.56969 -0.750499 -1.11993 -1.87465 -1.472 -0.950422 -0.754906 -0.270587 -0.231469 -1.19634 -1.19634 -1.19634 -1.51182 -1.34804 -1.42657 -0.544703 -1.66443 -1.66443 -1.14928 -1.10046 -1.16557 -1.5537 -0.629914 -0.880496 -0.525169 -0.312335 -0.797038 -0.733363 -1.6496 -0.0602699 -0.0840557 -0.107319 -0.0324205 -0.0929894 -0.226149 -0.202803 -0.217807 -0.366175 -0.158146 -0.160235 -0.175013 -0.148804 -0.526506 -0.785313 -1.23336 -0.733001 +0.165094 +0.139115 +1.50037 +1.50037 +1.61 +0.821612 +1.03082 +0.821668 +0.775115 +0.543092 +0.0861342 +0.0744322 +0.128651 +0.076151 +0.0260737 +0.0275716 +0.440052 +0.493376 +1.32419 +1.41206 +1.41206 +0.202022 +0.320102 +0.831304 +0.534264 +1.74811 +0.355012 +0.346681 +0.289411 +0.253201 +0.586474 +1.32659 +1.32659 +0.829193 +1.10147 +1.10147 +1.87636 +1.19348 +1.16959 +0.176166 +0.226725 +0.102413 +0.0500521 +0.00269899 +0.00273883 +0.00274408 +0.00123599 +0.00354132 +0.000358106 +0.0003337 +0.00028981 +0.000330851 +0.000826946 +0.00941884 +0.00932276 +0.00841899 +0.00848633 +0.0317742 +0.0303507 +0.0304411 +0.0343137 +0.0390462 +0.0342251 +0.0941789 +0.114046 +0.216451 +0.187088 +0.296988 +0.303359 +0.141562 +0.205275 +0.144005 +0.025403 +0.02575 +0.0276218 +0.147767 +0.311706 +0.36242 +0.32092 +2.46225 +0.276948 +0.302493 +0.138013 +0.13045 +0.25569 +1.62811 +0.610429 +0.974553 +0.974578 +0.97197 +0.278033 +0.469987 +0.534978 +0.534978 +0.505755 +1.38751 +0.316202 +0.565196 +0.394078 +0.442038 +0.529314 +0.753325 +1.23431 +1.03144 +0.717102 +0.405574 +0.447388 +0.673347 +0.330674 +0.322181 +0.322054 +0.268087 +0.257871 +0.251059 +1.0939 +1.0939 +1.09372 +0.640143 +0.422795 +0.358448 +0.300344 +0.325204 +0.215501 +0.725504 +0.763198 +0.991093 +0.420991 +0.957573 +1.31859 +2.13073 +1.50451 +0.464927 +0.169009 +0.0372169 +0.017995 +0.0210875 +0.0114971 +0.0152972 +0.0223257 +0.0263839 +0.048458 +0.665941 +0.657921 +0.490182 +0.252479 +0.472718 +0.945177 +0.658253 +0.131824 +0.369641 +0.715825 +0.331916 +0.370504 +0.347724 +0.0414442 +0.339334 +0.354162 +0.698269 +1.50379 +0.0395252 +0.0383166 +0.415532 +0.705242 +0.983896 +0.14184 +0.491699 +1.40669 +1.72872 +0.744003 +0.744003 +0.520324 +0.505936 +0.570554 +0.937554 +0.716306 +0.872266 +0.730075 +0.673381 +1.85416 +1.44276 +0.379716 +0.226359 +0.228589 +0.268951 +1.66516 +2.42012 +0.87476 +1.1232 +1.15261 +0.06469 +0.114697 +1.30398 +1.25665 +1.27469 +1.15608 +1.12404 +0.736236 +0.321845 +0.136899 +0.141005 diff --git a/tests/reference/test_deterministic_sampling_numba.txt b/tests/reference/test_deterministic_sampling_numba.txt index 5bea297..8c4c60c 100644 --- a/tests/reference/test_deterministic_sampling_numba.txt +++ b/tests/reference/test_deterministic_sampling_numba.txt @@ -1,200 +1,200 @@ -0.00293185 -0.00896808 -0.00768811 -0.00190587 -0.00320792 -0.00740495 -0.0958081 -0.109615 -0.0105327 -0.0192265 -0.0214681 -0.021833 -0.0585782 -0.0251717 -0.0280681 -0.0729857 -0.425132 -0.457442 -2.52983 -1.15483 -1.15483 -1.18416 -0.104274 -1.20247 -1.1064 -1.67005 -1.05586 -2.55089 -1.8334 -0.971751 -0.470398 -0.284519 -0.25376 -2.29193 -1.29672 -1.29672 -0.432495 -0.411463 -1.10822 -1.10822 -0.698467 -1.01384 -0.422528 -0.471828 -0.354965 -0.370006 -0.932942 -0.924415 -0.821473 -2.34528 -1.8362 -0.329965 -0.427145 -0.995744 -1.17653 -0.937677 -0.937677 -0.71568 -0.916428 -1.05491 -0.479239 -0.488732 -1.07755 -1.05904 -0.269731 -0.197423 -0.303257 -0.0738098 -0.0535443 -0.0704248 -0.083286 -0.158385 -0.149844 -0.416707 -0.349628 -0.31117 -0.304836 -0.072437 -1.5569 -1.20564 -2.12525 -0.303531 -0.712031 -0.844469 -0.434198 -0.277141 -0.593882 -0.648409 -1.02533 -0.692478 -0.367875 -0.316403 -0.351662 -0.117319 -1.85435 -0.413932 -0.409023 -0.661534 -0.650092 -0.766712 -0.594595 -0.501872 -0.515377 -0.236945 -0.689338 -2.99054 -0.172018 -0.0528735 -0.0579658 -0.0581689 -0.0497977 -0.063146 -0.311101 -0.347411 -0.763051 -0.734721 -1.17926 -1.02504 -1.02504 -0.645771 -0.970169 -1.20163 -1.1179 -0.385697 -0.410691 -0.471671 -0.540587 -0.250604 -0.254267 -0.220907 -0.673968 -0.265055 -0.766607 -1.50436 -1.58131 -0.719291 -0.958127 -0.546963 -1.60432 -1.60432 -1.45897 -0.717682 -0.668208 -0.71339 -0.276479 -0.255967 -0.799242 -1.32658 -0.724295 -0.36085 -0.217894 -0.254816 -0.125993 -1.31909 -1.56969 -0.750499 -1.11993 -1.87465 -1.472 -0.950422 -0.754906 -0.270587 -0.231469 -1.19634 -1.19634 -1.19634 -1.51182 -1.34804 -1.42657 -0.544703 -1.66443 -1.66443 -1.14928 -1.10046 -1.16557 -1.5537 -0.629914 -0.880496 -0.525169 -0.312335 -0.797038 -0.733363 -1.6496 -0.0602699 -0.0840557 -0.107319 -0.0324205 -0.0929894 -0.226149 -0.202803 -0.217807 -0.366175 -0.158146 -0.160235 -0.175013 -0.148804 -0.526506 -0.785313 -1.23336 -0.733001 +0.165094 +0.139115 +1.50037 +1.50037 +1.61 +0.821612 +1.03082 +0.821668 +0.775115 +0.543092 +0.0861342 +0.0744322 +0.128651 +0.076151 +0.0260737 +0.0275716 +0.440052 +0.493376 +1.32419 +1.41206 +1.41206 +0.202022 +0.320102 +0.831304 +0.534264 +1.74811 +0.355012 +0.346681 +0.289411 +0.253201 +0.586474 +1.32659 +1.32659 +0.829193 +1.10147 +1.10147 +1.87636 +1.19348 +1.16959 +0.176166 +0.226725 +0.102413 +0.0500521 +0.00269899 +0.00273883 +0.00274408 +0.00123599 +0.00354132 +0.000358106 +0.0003337 +0.00028981 +0.000330851 +0.000826946 +0.00941884 +0.00932276 +0.00841899 +0.00848633 +0.0317742 +0.0303507 +0.0304411 +0.0343137 +0.0390462 +0.0342251 +0.0941789 +0.114046 +0.216451 +0.187088 +0.296988 +0.303359 +0.141562 +0.205275 +0.144005 +0.025403 +0.02575 +0.0276218 +0.147767 +0.311706 +0.36242 +0.32092 +2.46225 +0.276948 +0.302493 +0.138013 +0.13045 +0.25569 +1.62811 +0.610429 +0.974553 +0.974578 +0.97197 +0.278033 +0.469987 +0.534978 +0.534978 +0.505755 +1.38751 +0.316202 +0.565196 +0.394078 +0.442038 +0.529314 +0.753325 +1.23431 +1.03144 +0.717102 +0.405574 +0.447388 +0.673347 +0.330674 +0.322181 +0.322054 +0.268087 +0.257871 +0.251059 +1.0939 +1.0939 +1.09372 +0.640143 +0.422795 +0.358448 +0.300344 +0.325204 +0.215501 +0.725504 +0.763198 +0.991093 +0.420991 +0.957573 +1.31859 +2.13073 +1.50451 +0.464927 +0.169009 +0.0372169 +0.017995 +0.0210875 +0.0114971 +0.0152972 +0.0223257 +0.0263839 +0.048458 +0.665941 +0.657921 +0.490182 +0.252479 +0.472718 +0.945177 +0.658253 +0.131824 +0.369641 +0.715825 +0.331916 +0.370504 +0.347724 +0.0414442 +0.339334 +0.354162 +0.698269 +1.50379 +0.0395252 +0.0383166 +0.415532 +0.705242 +0.983896 +0.14184 +0.491699 +1.40669 +1.72872 +0.744003 +0.744003 +0.520324 +0.505936 +0.570554 +0.937554 +0.716306 +0.872266 +0.730075 +0.673381 +1.85416 +1.44276 +0.379716 +0.226359 +0.228589 +0.268951 +1.66516 +2.42012 +0.87476 +1.1232 +1.15261 +0.06469 +0.114697 +1.30398 +1.25665 +1.27469 +1.15608 +1.12404 +0.736236 +0.321845 +0.136899 +0.141005 diff --git a/tests/reference/test_deterministic_sampling_stan.txt b/tests/reference/test_deterministic_sampling_stan.txt index dd85d53..f60ab00 100644 --- a/tests/reference/test_deterministic_sampling_stan.txt +++ b/tests/reference/test_deterministic_sampling_stan.txt @@ -1,2 +1,2 @@ -0.754944 0.746804 0.687211 1.56984 2.15413 2.15413 0.186138 1.19976 1.19976 0.818806 -0.185979 1.20179 0.236474 0.240597 0.416886 0.529295 0.574728 0.59912 1.02193 0.902788 +1.74005 1.56502 0.907101 0.30108 0.202376 0.373732 0.359263 0.364645 0.645229 1.04912 +0.900432 0.900432 0.889217 0.889217 0.859291 1.42722 1.36349 1.36349 0.240703 0.187522 From d057a4dc2ef56935d4bcb7bf6a3a4e0919f20c52 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 12 Apr 2026 19:59:33 +0200 Subject: [PATCH 15/21] fix: Compat with newer PyTensor and Arviz --- docs/pymc-usage.qmd | 2 +- pyproject.toml | 2 +- python/nutpie/compile_pymc.py | 11 +++++---- python/nutpie/sample.py | 45 ++++++++++++++++++++++++++--------- tests/test_pymc.py | 4 ++-- 5 files changed, 44 insertions(+), 20 deletions(-) diff --git a/docs/pymc-usage.qmd b/docs/pymc-usage.qmd index a045490..9e6d7c0 100644 --- a/docs/pymc-usage.qmd +++ b/docs/pymc-usage.qmd @@ -110,7 +110,7 @@ az.ess(trace) and take a look at a trace plot: ```{python} -az.plot_trace(trace); +az.plot_trace_dist(trace); ``` ### Choosing the backend diff --git a/pyproject.toml b/pyproject.toml index b603026..0122cda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "pandas >= 2.0", "platformdirs >= 3.0.0", "xarray >= 2025.01.2", - "arviz >= 0.20.0,<1.0", + "arviz >= 0.23.0,<2.0", "obstore >= 0.8.0", "zarr >= 3.1.0", ] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 9fe6594..a7b750a 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -686,7 +686,8 @@ def _make_functions( correspond to the variables in the flat array, and the third list contains the shapes of the variables. """ - import pytensor + from pytensor.graph import rewrite_graph, clone_replace + import pytensor.tensor as pt from pymc.pytensorf import compile as compile_pymc @@ -703,10 +704,10 @@ def _make_functions( if not model.check_bounds: rewrites.append("local_remove_check_parameter") - logp = pytensor.graph.rewrite_graph(logp, include=rewrites) + logp = rewrite_graph(logp, include=rewrites) if compute_grad: - grads = pytensor.gradient.grad(logp, value_vars) + grads = pt.grad(logp, value_vars) grad = pt.concatenate( [ pt.as_tensor(grad, allow_xtensor_conversion=True).ravel() @@ -759,11 +760,11 @@ def _make_functions( } if compute_grad: - (logp, grad) = pytensor.clone_replace([logp, grad], replacements) + (logp, grad) = clone_replace([logp, grad], replacements) with model: logp_fn_pt = compile_pymc((joined,), (logp, grad), mode=mode) else: - (logp,) = pytensor.clone_replace([logp], replacements) + (logp,) = clone_replace([logp], replacements) with model: logp_fn_pt = compile_pymc((joined,), (logp,), mode=mode) diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 1ba0f12..65b5cc8 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd import pyarrow +import xarray as xr from nutpie import _lib @@ -98,10 +99,12 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): ) return arviz.from_dict( - data_posterior, - sample_stats=stats_posterior, - warmup_posterior=data_tune, - warmup_sample_stats=stats_tune, + { + "posterior": data_posterior, + "sample_stats": stats_posterior, + "warmup_posterior": data_tune, + "warmup_sample_stats": stats_tune, + }, dims=dims, **kwargs, ) @@ -552,7 +555,6 @@ def _extract(self, results): else: if results.is_zarr(): import obstore - import xarray as xr from zarr.storage import ObjectStore assert self._zarr_store is not None @@ -563,8 +565,7 @@ def _extract(self, results): store = cls(*args, **kwargs) obj_store = ObjectStore(store, read_only=True) - ds = xr.open_datatree(obj_store, engine="zarr", consolidated=False) # ty:ignore[invalid-argument-type] - return arviz.from_datatree(ds) + return xr.open_datatree(obj_store, engine="zarr", consolidated=False) # ty:ignore[invalid-argument-type] elif results.is_arrow(): skip_vars = [] @@ -652,6 +653,28 @@ def _repr_html_(self): return self._html +@overload +def sample( + compiled_model: CompiledModel, + *, + draws: int | None = None, + tune: int | None = None, + chains: int | None = None, + cores: int | None = None, + seed: int | None = None, + save_warmup: bool = True, + progress_bar: bool = True, + adaptation: Literal["diag", "draw_diag", "low_rank", "flow"] = "diag", + init_mean: np.ndarray | None = None, + return_raw_trace: bool = False, + progress_callback: Any | None = None, + progress_template: str | None = None, + progress_style: str | None = None, + progress_rate: int = 100, + zarr_store: _ZarrStoreType | None = None, +) -> xr.DataTree: ... + + @overload def sample( compiled_model: CompiledModel, @@ -673,7 +696,7 @@ def sample( progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, **kwargs, -) -> arviz.InferenceData: ... +) -> xr.DataTree: ... @overload @@ -744,7 +767,7 @@ def sample( progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, **kwargs, -) -> arviz.InferenceData | _BackgroundSampler: +) -> xr.DataTree | _BackgroundSampler: """Sample the posterior distribution for a compiled model. Parameters @@ -871,8 +894,8 @@ def sample( Returns ------- - trace : arviz.InferenceData - An ArviZ ``InferenceData`` object that contains the samples. + trace : xr.DataTree: + An Xarray ``DataTree`` object that contains the samples. """ # Backward-compatible deprecated keyword arguments. diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 4fa041d..692eb5f 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -356,7 +356,7 @@ def test_det(backend, gradient_backend): @parameterize_backends def test_non_identifier_names(backend, gradient_backend): with pm.Model() as model: - a = pm.Uniform("a/b", shape=2) + a = pm.Uniform("a::b", shape=2) with pm.Model("foo"): c = pm.Data("c", np.array([2.0, 3.0])) pm.Deterministic("b", c * a) @@ -365,7 +365,7 @@ def test_non_identifier_names(backend, gradient_backend): model, backend=backend, gradient_backend=gradient_backend ) trace = nutpie.sample(compiled, chains=1) - assert trace.posterior["a/b"].shape[-1] == 2 + assert trace.posterior["a::b"].shape[-1] == 2 assert trace.posterior["foo::b"].shape[-1] == 2 From 43d938f72c14a428e1fbbb89e6158171f4d0e35d Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 17 Apr 2026 08:29:15 +0300 Subject: [PATCH 16/21] ci: test against upstream pytensor and pymc for now --- .github/workflows/ci.yml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 267582a..dcc3b0f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -58,13 +58,13 @@ jobs: set -e python3 -m venv .venv source .venv/bin/activate - uv pip install 'nutpie[stan]' --find-links dist --force-reinstall + uv pip install 'nutpie[stan]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall uv pip install pytest pytest-timeout pytest-arraydiff pytest -m "stan and not flow" --arraydiff - uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall + uv pip install 'nutpie[pymc]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall uv pip install jax pytest -m "pymc and not flow" --arraydiff - uv pip install 'nutpie[all]' --find-links dist --force-reinstall + uv pip install 'nutpie[all]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall pytest -m flow --arraydiff # pyarrow doesn't currently seem to work on musllinux @@ -186,13 +186,13 @@ jobs: set -e python3 -m venv .venv source .venv/Scripts/activate - uv pip install "nutpie[stan]" --find-links dist --force-reinstall + uv pip install "nutpie[stan]" --constraint .github/uv-constraints.txt --find-links dist --force-reinstall uv pip install pytest pytest-timeout pytest-arraydiff pytest -m "stan and not flow" --arraydiff - uv pip install "nutpie[pymc]" --find-links dist --force-reinstall + uv pip install "nutpie[pymc]" --constraint .github/uv-constraints.txt --find-links dist --force-reinstall uv pip install jax pytest -m "pymc and not flow" --arraydiff - uv pip install "nutpie[all]" --find-links dist --force-reinstall + uv pip install "nutpie[all]" --constraint .github/uv-constraints.txt --find-links dist --force-reinstall pytest -m flow --arraydiff macos: @@ -236,15 +236,15 @@ jobs: set -e python3 -m venv .venv source .venv/bin/activate - uv pip install 'nutpie[stan]' --find-links dist --force-reinstall + uv pip install 'nutpie[stan]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall uv pip install pytest pytest-timeout pytest-arraydiff pytest -m "stan and not flow" --arraydiff # Skip on x86_64 due to lack of numba support if [ "${{ matrix.platform.target }}" != "x86_64" ]; then - uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall + uv pip install 'nutpie[pymc]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall uv pip install jax pytest -m "pymc and not flow" --arraydiff - uv pip install 'nutpie[all]' --find-links dist --force-reinstall + uv pip install 'nutpie[all]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall pytest -m flow --arraydiff fi docs: @@ -269,7 +269,7 @@ jobs: - name: Install doc dependencies run: | source .venv/bin/activate - uv pip install "nutpie[docs]" + uv pip install --constraint .github/uv-constraints.txt "nutpie[docs]" --find-links dist --force-reinstall - name: Render docs env: TBB_CXX_TYPE: clang From ce36263c5acfdbf17946417a6ac703f33bff4d1a Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 22 Apr 2026 09:37:33 +0200 Subject: [PATCH 17/21] docs: add fisher divergence section --- docs/sample-stats.qmd | 75 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 69 insertions(+), 6 deletions(-) diff --git a/docs/sample-stats.qmd b/docs/sample-stats.qmd index 75798bd..2e9961f 100644 --- a/docs/sample-stats.qmd +++ b/docs/sample-stats.qmd @@ -1,4 +1,4 @@ -# Understanding Sampler Statistics in Nutpie +# Sampler Statistics in Nutpie This guide explains the various statistics that nutpie collects during sampling. We'll use Neal's funnel distribution as an example, as it's a challenging model that demonstrates many important sampling concepts. @@ -29,6 +29,7 @@ trace = nutpie.sample( store_gradient=True, store_unconstrained=True, store_divergences=True, + store_transformed=True, seed=42, ) ``` @@ -149,6 +150,7 @@ trace = nutpie.sample( tune=1000, store_gradient=True, store_unconstrained=True, + store_transformed=True, store_mass_matrix=True, seed=42, ) @@ -158,12 +160,10 @@ Now we can compute eigenvalues of the covariance matrix of the gradient and draws (using the singular value decomposition to avoid quadratic cost): ```{python} -def covariance_eigenvalues(x, mass_matrix): +def covariance_eigenvalues(x): assert x.dims == ("chain", "draw", "unconstrained_parameter") x = x.stack(sample=["draw", "chain"]) - x = (x - x.mean("sample")) / np.sqrt(mass_matrix) u, s, v = np.linalg.svd(x.T / np.sqrt(x.shape[1]), full_matrices=False) - print(u.shape, s.shape, v.shape) s = xr.DataArray( s, dims=["eigenvalue"], @@ -180,8 +180,8 @@ def covariance_eigenvalues(x, mass_matrix): return s ** 2, v mass_matrix = trace.sample_stats.mass_matrix_inv.isel(draw=-1, chain=0) -draws_eigs, draws_eigv = covariance_eigenvalues(trace.sample_stats.unconstrained_draw, mass_matrix) -grads_eigs, grads_eigv = covariance_eigenvalues(trace.sample_stats.gradient, 1 / mass_matrix) +draws_eigs, draws_eigv = covariance_eigenvalues(trace.sample_stats.transformed_position) +grads_eigs, grads_eigv = covariance_eigenvalues(trace.sample_stats.transformed_gradient) draws_eigs.plot.line(x="eigenvalue", yscale="log") grads_eigs.plot.line(x="eigenvalue", yscale="log") @@ -219,3 +219,66 @@ the correlations: .plot.bar(x="unconstrained_parameter") ) ``` + +# Fisher divergence + +We can measure how well the mass matrix adaptation could approximate the posterior +by checking the fisher divergence between the transformed draws and a standard normal +distribution. (The transformed position and scores are only saved with `store_transformed=True`). + +```{python} +fisher_divergence_warmup = ( + trace.warmup_sample_stats.transformed_position + + trace.warmup_sample_stats.transformed_gradient +) ** 2 + +fisher_divergence = ( + trace.sample_stats.transformed_position + + trace.sample_stats.transformed_gradient +) ** 2 + +fisher_divergence_warmup.sum("unconstrained_parameter").plot.line(x="draw") +plt.ylim(1e-3, None) +plt.yscale("log"); +``` + +The relatively large divergence value of more than 1000 tells us that the mass matrix can not adapt well to the posterior. + +We can investigate this on a per-variable basis to get a good indication about +which variables are involved in the problematic region: + +```{python} +fisher_divergence.mean(["draw", "chain"]).to_pandas().sort_values().tail() +``` + +Sampling with low rank mass matrix adaptation improves the fit of the mass matrix, and increases sampler efficiency. + +```{python} +compiled = nutpie.compile_pymc_model(model) +trace = nutpie.sample( + compiled, + tune=1000, + store_gradient=True, + store_unconstrained=True, + store_transformed=True, + store_mass_matrix=True, + seed=42, + adaptation="low_rank", +) +``` + +```{python} +fisher_divergence_warmup = ( + trace.warmup_sample_stats.transformed_position + + trace.warmup_sample_stats.transformed_gradient +) ** 2 + +fisher_divergence = ( + trace.sample_stats.transformed_position + + trace.sample_stats.transformed_gradient +) ** 2 + +fisher_divergence_warmup.sum("unconstrained_parameter").plot.line(x="draw") +plt.ylim(1e-3, None) +plt.yscale("log"); +``` From 4f191bb2626ec2401d668e27492f364214ee026e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 21 Apr 2026 17:07:51 +0200 Subject: [PATCH 18/21] ci: split CI into build + per-suite test jobs Each platform now has a build-only job that uploads the wheel, and test jobs (stan/pymc/flow) that download the artifact instead of rebuilding. Job names drop the redundant runner slug (e.g. "linux (x86_64)" instead of "linux (ubuntu-22.04, x86_64)"). pymc_dev reuses the linux x86_64 wheel. Release is gated on builds, stan tests, and pymc_dev. pymc/flow suites run against released pymc/pytensor with continue-on-error so their expected failures don't block release. --- .github/workflows/ci.yml | 212 ++++++++++++++++++++++++++++++++------- 1 file changed, 173 insertions(+), 39 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dcc3b0f..9f2d7a6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,8 @@ permissions: id-token: write jobs: - linux: + build_linux: + name: build linux (${{ matrix.platform.target }}) runs-on: ${{ matrix.platform.runner }} strategy: fail-fast: false @@ -52,20 +53,57 @@ jobs: with: name: wheels-linux-${{ matrix.platform.target }} path: dist + + test_linux: + name: test linux (${{ matrix.target }}, ${{ matrix.python-version }}, ${{ matrix.suite }}) + runs-on: ${{ matrix.runner }} + needs: build_linux + # stan failures block release; pymc/flow are expected red against + # released pymc, so allow them to fail without gating. + continue-on-error: ${{ matrix.suite != 'stan' }} + strategy: + fail-fast: false + matrix: + target: [x86_64, aarch64] + suite: [stan, pymc, flow] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + include: + - target: x86_64 + runner: ubuntu-22.04 + - target: aarch64 + runner: ubuntu-22.04-arm + steps: + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 + - uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + - uses: actions/download-artifact@v8 + with: + name: wheels-linux-${{ matrix.target }} + path: dist - name: pytest shell: bash run: | set -e python3 -m venv .venv source .venv/bin/activate - uv pip install 'nutpie[stan]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall uv pip install pytest pytest-timeout pytest-arraydiff - pytest -m "stan and not flow" --arraydiff - uv pip install 'nutpie[pymc]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall - uv pip install jax - pytest -m "pymc and not flow" --arraydiff - uv pip install 'nutpie[all]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall - pytest -m flow --arraydiff + case "${{ matrix.suite }}" in + stan) + uv pip install 'nutpie[stan]' --find-links dist --force-reinstall + pytest -m "stan and not flow" --arraydiff + ;; + pymc) + uv pip install 'nutpie[pymc]' jax --find-links dist --force-reinstall + pytest -m "pymc and not flow" --arraydiff + ;; + flow) + uv pip install 'nutpie[all]' --find-links dist --force-reinstall + pytest -m flow --arraydiff + ;; + esac # pyarrow doesn't currently seem to work on musllinux #musllinux: @@ -137,13 +175,9 @@ jobs: # uv pip install 'nutpie[stan]' --find-links dist --force-reinstall # pytest - windows: - runs-on: ${{ matrix.platform.runner }} - strategy: - matrix: - platform: - - runner: windows-latest - target: x64 + build_windows: + name: build windows + runs-on: windows-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-python@v6 @@ -154,7 +188,7 @@ jobs: 3.12 3.13 3.14 - architecture: ${{ matrix.platform.target }} + architecture: x64 - name: Install uv uses: astral-sh/setup-uv@v7 - name: Install LLVM and Clang @@ -163,7 +197,6 @@ jobs: version: "15.0" directory: ${{ runner.temp }}/llvm - name: Set up TBB - if: matrix.os == 'windows-latest' run: | Add-Content $env:GITHUB_PATH "$(pwd)/stan/lib/stan_math/lib/tbb" - name: Build wheels @@ -171,31 +204,60 @@ jobs: env: LIBCLANG_PATH: ${{ runner.temp }}/llvm/lib with: - target: ${{ matrix.platform.target }} + target: x64 args: --release --out dist --find-interpreter sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} - name: Upload wheels uses: actions/upload-artifact@v7 with: - name: wheels-windows-${{ matrix.platform.target }} + name: wheels-windows-x64 + path: dist + + test_windows: + name: test windows (${{ matrix.python-version }}, ${{ matrix.suite }}) + runs-on: windows-latest + needs: build_windows + continue-on-error: ${{ matrix.suite != 'stan' }} + strategy: + fail-fast: false + matrix: + suite: [stan, pymc, flow] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + - name: Install uv + uses: astral-sh/setup-uv@v7 + - uses: actions/download-artifact@v8 + with: + name: wheels-windows-x64 path: dist - name: pytest - if: ${{ !startsWith(matrix.platform.target, 'aarch64') }} shell: bash run: | set -e python3 -m venv .venv source .venv/Scripts/activate - uv pip install "nutpie[stan]" --constraint .github/uv-constraints.txt --find-links dist --force-reinstall uv pip install pytest pytest-timeout pytest-arraydiff - pytest -m "stan and not flow" --arraydiff - uv pip install "nutpie[pymc]" --constraint .github/uv-constraints.txt --find-links dist --force-reinstall - uv pip install jax - pytest -m "pymc and not flow" --arraydiff - uv pip install "nutpie[all]" --constraint .github/uv-constraints.txt --find-links dist --force-reinstall - pytest -m flow --arraydiff + case "${{ matrix.suite }}" in + stan) + uv pip install "nutpie[stan]" --find-links dist --force-reinstall + pytest -m "stan and not flow" --arraydiff + ;; + pymc) + uv pip install "nutpie[pymc]" jax --find-links dist --force-reinstall + pytest -m "pymc and not flow" --arraydiff + ;; + flow) + uv pip install "nutpie[all]" --find-links dist --force-reinstall + pytest -m flow --arraydiff + ;; + esac - macos: + build_macos: + name: build macos (${{ matrix.platform.target }}) runs-on: ${{ matrix.platform.runner }} strategy: fail-fast: false @@ -231,22 +293,83 @@ jobs: with: name: wheels-macos-${{ matrix.platform.target }} path: dist + + test_macos: + name: test macos (${{ matrix.target }}, ${{ matrix.python-version }}, ${{ matrix.suite }}) + runs-on: ${{ matrix.runner }} + needs: build_macos + continue-on-error: ${{ matrix.suite != 'stan' }} + strategy: + fail-fast: false + matrix: + target: [x86_64, aarch64] + suite: [stan, pymc, flow] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + include: + - target: x86_64 + runner: macos-15-intel + - target: aarch64 + runner: macos-14 + exclude: + # numba isn't available on macOS x86_64 + - target: x86_64 + suite: pymc + - target: x86_64 + suite: flow + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + - name: Install uv + uses: astral-sh/setup-uv@v7 + - uses: actions/download-artifact@v8 + with: + name: wheels-macos-${{ matrix.target }} + path: dist - name: pytest run: | set -e python3 -m venv .venv source .venv/bin/activate - uv pip install 'nutpie[stan]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall uv pip install pytest pytest-timeout pytest-arraydiff - pytest -m "stan and not flow" --arraydiff - # Skip on x86_64 due to lack of numba support - if [ "${{ matrix.platform.target }}" != "x86_64" ]; then - uv pip install 'nutpie[pymc]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall - uv pip install jax - pytest -m "pymc and not flow" --arraydiff - uv pip install 'nutpie[all]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall - pytest -m flow --arraydiff - fi + case "${{ matrix.suite }}" in + stan) + uv pip install 'nutpie[stan]' --find-links dist --force-reinstall + pytest -m "stan and not flow" --arraydiff + ;; + pymc) + uv pip install 'nutpie[pymc]' jax --find-links dist --force-reinstall + pytest -m "pymc and not flow" --arraydiff + ;; + flow) + uv pip install 'nutpie[all]' --find-links dist --force-reinstall + pytest -m flow --arraydiff + ;; + esac + + test_pymc_dev: + name: pymc (dev) + runs-on: ubuntu-22.04 + needs: build_linux + steps: + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + - uses: actions/download-artifact@v8 + with: + name: wheels-linux-x86_64 + path: dist + - name: pytest + run: | + python3 -m venv .venv + source .venv/bin/activate + uv pip install 'nutpie[dev]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall + pytest -m "pymc and not flow" --arraydiff + docs: runs-on: ubuntu-22.04 steps: @@ -318,7 +441,18 @@ jobs: name: Release runs-on: ubuntu-latest if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }} - needs: [linux, windows, macos, sdist, docs] + needs: + [ + build_linux, + build_windows, + build_macos, + sdist, + test_linux, + test_windows, + test_macos, + test_pymc_dev, + docs, + ] environment: name: pypi permissions: From 79b464dfe63740ad15372641e1e65c57ef911336 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 22 Apr 2026 11:59:20 +0200 Subject: [PATCH 19/21] ci: make sure to use the local wheel to test nutpie --- .github/workflows/ci.yml | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9f2d7a6..9c35494 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -90,17 +90,18 @@ jobs: python3 -m venv .venv source .venv/bin/activate uv pip install pytest pytest-timeout pytest-arraydiff + uv pip install nutpie --no-deps --find-links dist --no-index case "${{ matrix.suite }}" in stan) - uv pip install 'nutpie[stan]' --find-links dist --force-reinstall + uv pip install 'nutpie[stan]' pytest -m "stan and not flow" --arraydiff ;; pymc) - uv pip install 'nutpie[pymc]' jax --find-links dist --force-reinstall + uv pip install 'nutpie[pymc]' jax --constraint .github/uv-constraints-main.txt pytest -m "pymc and not flow" --arraydiff ;; flow) - uv pip install 'nutpie[all]' --find-links dist --force-reinstall + uv pip install 'nutpie[all]' --constraint .github/uv-constraints-main.txt pytest -m flow --arraydiff ;; esac @@ -241,17 +242,18 @@ jobs: python3 -m venv .venv source .venv/Scripts/activate uv pip install pytest pytest-timeout pytest-arraydiff + uv pip install nutpie --no-deps --find-links dist --no-index case "${{ matrix.suite }}" in stan) - uv pip install "nutpie[stan]" --find-links dist --force-reinstall + uv pip install "nutpie[stan]" pytest -m "stan and not flow" --arraydiff ;; pymc) - uv pip install "nutpie[pymc]" jax --find-links dist --force-reinstall + uv pip install "nutpie[pymc]" jax --constraint .github/uv-constraints-main.txt pytest -m "pymc and not flow" --arraydiff ;; flow) - uv pip install "nutpie[all]" --find-links dist --force-reinstall + uv pip install "nutpie[all]" --constraint .github/uv-constraints-main.txt pytest -m flow --arraydiff ;; esac @@ -333,17 +335,18 @@ jobs: python3 -m venv .venv source .venv/bin/activate uv pip install pytest pytest-timeout pytest-arraydiff + uv pip install nutpie --no-deps --find-links dist --no-index case "${{ matrix.suite }}" in stan) - uv pip install 'nutpie[stan]' --find-links dist --force-reinstall + uv pip install 'nutpie[stan]' pytest -m "stan and not flow" --arraydiff ;; pymc) - uv pip install 'nutpie[pymc]' jax --find-links dist --force-reinstall + uv pip install 'nutpie[pymc]' jax --constraint .github/uv-constraints-main.txt pytest -m "pymc and not flow" --arraydiff ;; flow) - uv pip install 'nutpie[all]' --find-links dist --force-reinstall + uv pip install 'nutpie[all]' --constraint .github/uv-constraints-main.txt pytest -m flow --arraydiff ;; esac @@ -367,8 +370,9 @@ jobs: run: | python3 -m venv .venv source .venv/bin/activate - uv pip install 'nutpie[dev]' --constraint .github/uv-constraints.txt --find-links dist --force-reinstall - pytest -m "pymc and not flow" --arraydiff + uv pip install nutpie --no-deps --find-links dist --no-index + uv pip install 'nutpie[dev]' --constraint .github/uv-constraints-dev.txt + pytest -m "pymc" --arraydiff docs: runs-on: ubuntu-22.04 @@ -392,7 +396,8 @@ jobs: - name: Install doc dependencies run: | source .venv/bin/activate - uv pip install --constraint .github/uv-constraints.txt "nutpie[docs]" --find-links dist --force-reinstall + uv pip install nutpie --no-deps --find-links dist --no-index + uv pip install --constraint .github/uv-constraints-dev.txt "nutpie[docs]" - name: Render docs env: TBB_CXX_TYPE: clang From f53f710bece278adb969cdb5274f98b891328f5a Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 22 Apr 2026 12:52:32 +0200 Subject: [PATCH 20/21] chore: require python 3.12 --- .github/workflows/ci.yml | 18 ++++++------------ pyproject.toml | 4 ++-- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9c35494..b4c9c8c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,8 +34,6 @@ jobs: - uses: actions/setup-python@v6 with: python-version: | - 3.10 - 3.11 3.12 3.13 3.14 @@ -43,7 +41,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} - args: --release --out dist --interpreter 3.10 3.11 3.12 3.13 3.14 --zig + args: --release --out dist --interpreter 3.12 3.13 3.14 --zig sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} manylinux: auto before-script-linux: | @@ -66,7 +64,7 @@ jobs: matrix: target: [x86_64, aarch64] suite: [stan, pymc, flow] - python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + python-version: ["3.12", "3.13", "3.14"] include: - target: x86_64 runner: ubuntu-22.04 @@ -184,8 +182,6 @@ jobs: - uses: actions/setup-python@v6 with: python-version: | - 3.10 - 3.11 3.12 3.13 3.14 @@ -206,7 +202,7 @@ jobs: LIBCLANG_PATH: ${{ runner.temp }}/llvm/lib with: target: x64 - args: --release --out dist --find-interpreter + args: --release --out dist --interpreter 3.12 3.13 3.14 sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} - name: Upload wheels uses: actions/upload-artifact@v7 @@ -223,7 +219,7 @@ jobs: fail-fast: false matrix: suite: [stan, pymc, flow] - python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + python-version: ["3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v6 - uses: actions/setup-python@v6 @@ -274,8 +270,6 @@ jobs: - uses: actions/setup-python@v6 with: python-version: | - 3.10 - 3.11 3.12 3.13 3.14 @@ -288,7 +282,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} - args: --release --out dist --find-interpreter + args: --release --out dist --interpreter 3.12 3.13 3.14 --zig sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} - name: Upload wheels uses: actions/upload-artifact@v7 @@ -306,7 +300,7 @@ jobs: matrix: target: [x86_64, aarch64] suite: [stan, pymc, flow] - python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + python-version: ["3.12", "3.13", "3.14"] include: - target: x86_64 runner: macos-15-intel diff --git a/pyproject.toml b/pyproject.toml index 0122cda..b4be7e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "nutpie" description = "Sample Stan or PyMC models" authors = [{ name = "PyMC Developers", email = "pymc.devs@gmail.com" }] readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.12" license = { text = "MIT" } classifiers = [ "Programming Language :: Rust", @@ -72,7 +72,7 @@ docs = [ [tool.ruff] line-length = 88 -target-version = "py310" +target-version = "py312" show-fixes = true output-format = "full" From d88a33f18e18a94137cdaa571a549a9eb3a22e6e Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 22 Apr 2026 12:52:32 +0200 Subject: [PATCH 21/21] chore: support both arviz 1.0 and older --- .github/uv-constraints-dev.txt | 2 ++ .github/uv-constraints-main.txt | 1 + Cargo.lock | 4 ---- docs/pymc-usage.qmd | 2 +- python/nutpie/sample.py | 35 +++++++++++++++++++++++---------- 5 files changed, 29 insertions(+), 15 deletions(-) create mode 100644 .github/uv-constraints-dev.txt create mode 100644 .github/uv-constraints-main.txt diff --git a/.github/uv-constraints-dev.txt b/.github/uv-constraints-dev.txt new file mode 100644 index 0000000..d1e6c1a --- /dev/null +++ b/.github/uv-constraints-dev.txt @@ -0,0 +1,2 @@ +pytensor @ git+https://github.com/pymc-devs/pytensor.git@v3 +pymc @ git+https://github.com/pymc-devs/pymc.git@v6 diff --git a/.github/uv-constraints-main.txt b/.github/uv-constraints-main.txt new file mode 100644 index 0000000..7b50d73 --- /dev/null +++ b/.github/uv-constraints-main.txt @@ -0,0 +1 @@ +arviz < 1.0.0 diff --git a/Cargo.lock b/Cargo.lock index 3a8711f..d3df10d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2289,8 +2289,6 @@ dependencies = [ [[package]] name = "nuts-derive" version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8cce587b5f36bc6bfa54cbf2eaf31fe5a5e0d73e96d31fde5de87a701689363" dependencies = [ "nuts-storable", "proc-macro2", @@ -2301,8 +2299,6 @@ dependencies = [ [[package]] name = "nuts-rs" version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed964f329c9de6147920b5bbc55cdeb8ad69c7909f68539b592889dfc07b04de" dependencies = [ "anyhow", "arrow", diff --git a/docs/pymc-usage.qmd b/docs/pymc-usage.qmd index 9e6d7c0..a045490 100644 --- a/docs/pymc-usage.qmd +++ b/docs/pymc-usage.qmd @@ -110,7 +110,7 @@ az.ess(trace) and take a look at a trace plot: ```{python} -az.plot_trace_dist(trace); +az.plot_trace(trace); ``` ### Choosing the backend diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 65b5cc8..cb8a425 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -98,16 +98,31 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): stats_posterior, max_posterior, stat_posterior, i, n_chains, dims, skip_vars ) - return arviz.from_dict( - { - "posterior": data_posterior, - "sample_stats": stats_posterior, - "warmup_posterior": data_tune, - "warmup_sample_stats": stats_tune, - }, - dims=dims, - **kwargs, - ) + from importlib.metadata import version + + arviz_version = version("arviz") + if tuple(map(int, arviz_version.split(".")[:2])) >= (1, 0): + return arviz.from_dict( + { + "posterior": data_posterior, + "sample_stats": stats_posterior, + "warmup_posterior": data_tune, + "warmup_sample_stats": stats_tune, + }, + dims=dims, + **kwargs, + ) + else: + return arviz.from_dict( + **{ + "posterior": data_posterior, + "sample_stats": stats_posterior, + "warmup_posterior": data_tune, + "warmup_sample_stats": stats_tune, + }, # ty:ignore[invalid-argument-type] + dims=dims, + **kwargs, + ) def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_vars):