diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index a7b750a..0a71a27 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -1,5 +1,4 @@ import dataclasses -import itertools import threading import warnings from collections.abc import Iterable @@ -273,6 +272,7 @@ def _compile_pymc_model_numba( expand_fn_pt, initial_point_fn, shape_info, + reparameterized_names, ) = _make_functions( model, mode="NUMBA", @@ -326,7 +326,7 @@ def _compile_pymc_model_numba( expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw) - dims, coords = _prepare_dims_and_coords(model, shape_info) + dims, coords = _prepare_dims_and_coords(model, shape_info, reparameterized_names) return CompiledPyMCModel( _n_dim=n_dim, @@ -342,31 +342,34 @@ def _compile_pymc_model_numba( shape_info=shape_info, logp_func=logp_fn_pt, expand_func=expand_fn_pt, + reparameterized_names=reparameterized_names, ) -def _prepare_dims_and_coords(model, shape_info): +def _prepare_dims_and_coords(model, shape_info, reparameterized_names): coords = {} for name, vals in model.coords.items(): if vals is None: vals = pd.RangeIndex(int(model.dim_lengths[name].eval())) - coords[name] = pd.Index(vals) - - if "unconstrained_parameter" in coords: - raise ValueError("Model contains invalid name 'unconstrained_parameter'.") - - names = [] - for base, _, shape in zip(*shape_info): - if base not in [var.name for var in model.value_vars]: + idx = pd.Index(vals) + if idx.dtype == "object" or idx.dtype == "string": + coords[name] = idx.tolist() + else: + coords[name] = idx + + names, _, shape_list = shape_info + shape_by_name = {n: tuple(s) for n, s in zip(names, shape_list)} + value_to_rv = {model.rvs_to_values[var].name: var.name for var in model.free_RVs} + + dims = dict(model.named_vars_to_dims) + for value_name in reparameterized_names: + rv_name = value_to_rv[value_name] + rv_dims = dims.get(rv_name) + if rv_dims is None: continue - for idx in itertools.product(*[range(length) for length in shape]): - if len(idx) == 0: - names.append(base) - else: - names.append(f"{base}_{'.'.join(str(i) for i in idx)}") - coords["unconstrained_parameter"] = pd.Index(names) - - dims = model.named_vars_to_dims + if shape_by_name.get(rv_name) == shape_by_name.get(value_name): + dims[value_name] = rv_dims + return dims, coords @@ -399,6 +402,7 @@ def _compile_pymc_model_jax( expand_fn_pt, initial_point_fn, shape_info, + reparameterized_names, ) = _make_functions( model, mode="JAX", @@ -464,7 +468,7 @@ def expand(_x, **shared): return expand - dims, coords = _prepare_dims_and_coords(model, shape_info) + dims, coords = _prepare_dims_and_coords(model, shape_info, reparameterized_names) return from_pyfunc( ndim=n_dim, @@ -478,6 +482,7 @@ def expand(_x, **shared): dims=dims, coords=coords, raw_logp_fn=orig_logp_fn, + reparameterized_names=reparameterized_names, ) @@ -641,6 +646,7 @@ def _make_functions( Callable, Callable, tuple[list[str], list[slice], list[tuple[int, ...]]], + list[str], ]: """ Compile functions required by nuts-rs from a given PyMC model. @@ -748,7 +754,7 @@ def _make_functions( if use_split: variables = pt.split(joined, splits, len(splits)) else: - variables = [joined[slice_val] for slice_val in zip(joined_slices)] + variables = [joined[slice_val] for slice_val in joined_slices] replacements = { model.rvs_to_values[var]: value.reshape(shape).astype(var.dtype) @@ -768,6 +774,12 @@ def _make_functions( with model: logp_fn_pt = compile_pymc((joined,), (logp,), mode=mode) + reparameterized_names = [ + model.rvs_to_values[var].name + for var in model.free_RVs + if model.rvs_to_transforms.get(var) is not None + ] + # Make function that computes remaining variables for the trace remaining_rvs = [ var for var in model.unobserved_value_vars if var.name not in joined_names @@ -777,27 +789,35 @@ def _make_functions( names = set(var_names) remaining_rvs = [var for var in remaining_rvs if var.name in names] - all_names = joined_names + remaining_rvs + all_names = [] + all_slices = [] + all_shapes = [] + count_expanded = 0 - all_names = joined_names.copy() - all_slices = joined_slices.copy() - all_shapes = joined_shapes.copy() + identity_free = [] + for var_expr, name, shape in zip(variables, joined_names, joined_shapes): + length = prod(shape) + all_names.append(name) + all_shapes.append(shape) + all_slices.append(slice(count_expanded, count_expanded + length)) + count_expanded += length + identity_free.append(var_expr) for var in remaining_rvs: all_names.append(var.name) shape = cast(tuple[int, ...], shapes[var.name]) all_shapes.append(shape) length = prod(shape) - all_slices.append(slice(count, count + length)) - count += length + all_slices.append(slice(count_expanded, count_expanded + length)) + count_expanded += length - num_expanded = count + num_expanded = count_expanded if join_expanded: allvars = [ pt.concatenate( [ - joined, + *[v.ravel() for v in identity_free], *[ pt.as_tensor(var, allow_xtensor_conversion=True).ravel() for var in remaining_rvs @@ -806,7 +826,7 @@ def _make_functions( ) ] else: - allvars = [*variables, *remaining_rvs] + allvars = [*identity_free, *remaining_rvs] with model: expand_fn_pt = compile_pymc( (joined,), @@ -822,6 +842,7 @@ def _make_functions( expand_fn_pt, initial_point_fn, (all_names, all_slices, all_shapes), + reparameterized_names, ) diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index ee6400f..07602ee 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -119,6 +119,7 @@ def from_pyfunc( make_initial_point_fn: Callable[[SeedType], np.ndarray] | None = None, make_transform_adapter=None, raw_logp_fn=None, + reparameterized_names=None, ): if coords is None: coords = {} @@ -150,4 +151,5 @@ def from_pyfunc( _variables=variables, _shared_data=shared_data, _raw_logp_fn=raw_logp_fn, + reparameterized_names=reparameterized_names, ) diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index cb8a425..845a5b9 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -1,6 +1,7 @@ import os import warnings -from dataclasses import dataclass +from dataclasses import dataclass, field +from importlib.metadata import version from typing import Any, Literal, Optional, cast, get_args, overload import arviz @@ -15,6 +16,7 @@ @dataclass(frozen=True) class CompiledModel: dims: Optional[dict[str, tuple[str, ...]]] + reparameterized_names: list[str] | None = field(default=None, kw_only=True) @property def n_dim(self) -> int: @@ -56,9 +58,18 @@ def benchmark_logp(self, point, num_evals, cores): return pd.concat(times) -def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): +def _arrow_to_arviz( + draw_batches, + stat_batches, + skip_vars=None, + reparameterized_names=None, + keep_unconstrained_draw=False, + **kwargs, +): if skip_vars is None: skip_vars = [] + if reparameterized_names is None: + reparameterized_names = [] n_chains = len(draw_batches) assert n_chains == len(stat_batches) @@ -98,11 +109,23 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): stats_posterior, max_posterior, stat_posterior, i, n_chains, dims, skip_vars ) - from importlib.metadata import version + uc_data_posterior = { + name: data_posterior.pop(name) + for name in reparameterized_names + if name in data_posterior + } + uc_data_tune = { + name: data_tune.pop(name) for name in reparameterized_names if name in data_tune + } + + if not keep_unconstrained_draw: + stats_posterior.pop("unconstrained_draw", None) + stats_tune.pop("unconstrained_draw", None) + dims.pop("unconstrained_draw", None) arviz_version = version("arviz") if tuple(map(int, arviz_version.split(".")[:2])) >= (1, 0): - return arviz.from_dict( + idata = arviz.from_dict( { "posterior": data_posterior, "sample_stats": stats_posterior, @@ -113,7 +136,7 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): **kwargs, ) else: - return arviz.from_dict( + idata = arviz.from_dict( **{ "posterior": data_posterior, "sample_stats": stats_posterior, @@ -124,6 +147,22 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): **kwargs, ) + if keep_unconstrained_draw and uc_data_posterior: + coords = kwargs.get("coords") + uc_dims = {name: dims.get(name, []) for name in uc_data_posterior} + groups = { + "unconstrained_posterior": arviz.dict_to_dataset( + uc_data_posterior, coords=coords, dims=uc_dims + ) + } + if uc_data_tune: + groups["warmup_unconstrained_posterior"] = arviz.dict_to_dataset( + uc_data_tune, coords=coords, dims=uc_dims + ) + idata.add_groups(groups) + + return idata + def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_vars): num_draws = batch.num_rows @@ -463,11 +502,13 @@ def __init__( progress_style=None, progress_rate=100, store=None, + store_unconstrained=False, ): self._settings = settings self._compiled_model = compiled_model self._save_warmup = save_warmup self._return_raw_trace = return_raw_trace + self._store_unconstrained = store_unconstrained self._html = None @@ -622,6 +663,8 @@ def _get_nested(settings, name, default): draw_batches, stat_batches, skip_vars=skip_vars, + reparameterized_names=self._compiled_model.reparameterized_names, + keep_unconstrained_draw=self._store_unconstrained, coords={ name: pd.Index(vals) for name, vals in self._compiled_model.coords.items() @@ -687,6 +730,7 @@ def sample( progress_style: str | None = None, progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, + store_unconstrained: bool = False, ) -> xr.DataTree: ... @@ -710,6 +754,7 @@ def sample( progress_style: str | None = None, progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, + store_unconstrained: bool = False, **kwargs, ) -> xr.DataTree: ... @@ -734,6 +779,7 @@ def sample( progress_style: str | None = None, progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, + store_unconstrained: bool = False, **kwargs, ) -> _BackgroundSampler: ... @@ -781,6 +827,7 @@ def sample( progress_style: str | None = None, progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, + store_unconstrained: bool = False, **kwargs, ) -> xr.DataTree | _BackgroundSampler: """Sample the posterior distribution for a compiled model. @@ -817,8 +864,11 @@ def sample( point on the transformed parameter space. Defaults to zeros. store_unconstrained: bool - If True, store each draw in the unconstrained (transformed) - space in the sample stats. + If True, store the unconstrained (transformed) draws in two forms: + a flat ``unconstrained_draw`` vector in ``sample_stats`` and a + per-variable ``unconstrained_posterior`` group (with + ``warmup_unconstrained_posterior`` when ``save_warmup=True``) whose + dims are copied from the corresponding RV. store_gradient: bool If True, store the logp gradient of each draw in the unconstrained space in the sample stats. @@ -992,6 +1042,9 @@ def sample( settings.update(updates) + if store_unconstrained: + settings.store_unconstrained = True + if cores is None: try: # Only available in python>=3.13 @@ -1019,6 +1072,7 @@ def sample( progress_style=progress_style, progress_rate=progress_rate, store=zarr_store, + store_unconstrained=store_unconstrained, ) if not blocking: diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 692eb5f..655e84b 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -305,6 +305,7 @@ def test_pymc_model_store_extra(backend, gradient_backend): with pm.Model() as model: model.add_coord("foo", length=5) pm.Normal("a", dims="foo") + pm.HalfNormal("b", sigma=1.0, dims="foo") compiled = nutpie.compile_pymc_model( model, backend=backend, gradient_backend=gradient_backend @@ -318,6 +319,8 @@ def test_pymc_model_store_extra(backend, gradient_backend): store_gradient=True, ) trace.posterior.a # noqa: B018 + trace.posterior.b # noqa: B018 + assert trace.unconstrained_posterior.b_log__.dims == ("chain", "draw", "foo") _ = trace.sample_stats.unconstrained_draw _ = trace.sample_stats.gradient _ = trace.sample_stats.divergence_start