Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 51 additions & 30 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
import itertools
import threading
import warnings
from collections.abc import Iterable
Expand Down Expand Up @@ -273,6 +272,7 @@ def _compile_pymc_model_numba(
expand_fn_pt,
initial_point_fn,
shape_info,
reparameterized_names,
) = _make_functions(
model,
mode="NUMBA",
Expand Down Expand Up @@ -326,7 +326,7 @@ def _compile_pymc_model_numba(

expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)

dims, coords = _prepare_dims_and_coords(model, shape_info)
dims, coords = _prepare_dims_and_coords(model, shape_info, reparameterized_names)

return CompiledPyMCModel(
_n_dim=n_dim,
Expand All @@ -342,31 +342,34 @@ def _compile_pymc_model_numba(
shape_info=shape_info,
logp_func=logp_fn_pt,
expand_func=expand_fn_pt,
reparameterized_names=reparameterized_names,
)


def _prepare_dims_and_coords(model, shape_info):
def _prepare_dims_and_coords(model, shape_info, reparameterized_names):
coords = {}
for name, vals in model.coords.items():
if vals is None:
vals = pd.RangeIndex(int(model.dim_lengths[name].eval()))
coords[name] = pd.Index(vals)

if "unconstrained_parameter" in coords:
raise ValueError("Model contains invalid name 'unconstrained_parameter'.")

names = []
for base, _, shape in zip(*shape_info):
if base not in [var.name for var in model.value_vars]:
idx = pd.Index(vals)
if idx.dtype == "object" or idx.dtype == "string":
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did this fix a problem? Why not just keep it as an `Index`?

coords[name] = idx.tolist()
else:
coords[name] = idx
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we still want the `unconstrained_parameter` check, or are you sure this is ok now? The `unconstrained_parameter` dimension shouldn't really be used in the posterior group, but it is used quite a lot in `sample_stats`, and I'm not sure everything (with zarr and arrow) works if we have a variable with that name.
We also definitely still want the coordinate for `unconstrained_parameter`.


names, _, shape_list = shape_info
shape_by_name = {n: tuple(s) for n, s in zip(names, shape_list)}
value_to_rv = {model.rvs_to_values[var].name: var.name for var in model.free_RVs}

dims = dict(model.named_vars_to_dims)
for value_name in reparameterized_names:
rv_name = value_to_rv[value_name]
rv_dims = dims.get(rv_name)
if rv_dims is None:
continue
for idx in itertools.product(*[range(length) for length in shape]):
if len(idx) == 0:
names.append(base)
else:
names.append(f"{base}_{'.'.join(str(i) for i in idx)}")
coords["unconstrained_parameter"] = pd.Index(names)

dims = model.named_vars_to_dims
if shape_by_name.get(rv_name) == shape_by_name.get(value_name):
dims[value_name] = rv_dims

return dims, coords


Expand Down Expand Up @@ -399,6 +402,7 @@ def _compile_pymc_model_jax(
expand_fn_pt,
initial_point_fn,
shape_info,
reparameterized_names,
) = _make_functions(
model,
mode="JAX",
Expand Down Expand Up @@ -464,7 +468,7 @@ def expand(_x, **shared):

return expand

dims, coords = _prepare_dims_and_coords(model, shape_info)
dims, coords = _prepare_dims_and_coords(model, shape_info, reparameterized_names)

return from_pyfunc(
ndim=n_dim,
Expand All @@ -478,6 +482,7 @@ def expand(_x, **shared):
dims=dims,
coords=coords,
raw_logp_fn=orig_logp_fn,
reparameterized_names=reparameterized_names,
)


Expand Down Expand Up @@ -641,6 +646,7 @@ def _make_functions(
Callable,
Callable,
tuple[list[str], list[slice], list[tuple[int, ...]]],
list[str],
]:
"""
Compile functions required by nuts-rs from a given PyMC model.
Expand Down Expand Up @@ -748,7 +754,7 @@ def _make_functions(
if use_split:
variables = pt.split(joined, splits, len(splits))
else:
variables = [joined[slice_val] for slice_val in zip(joined_slices)]
variables = [joined[slice_val] for slice_val in joined_slices]

replacements = {
model.rvs_to_values[var]: value.reshape(shape).astype(var.dtype)
Expand All @@ -768,6 +774,12 @@ def _make_functions(
with model:
logp_fn_pt = compile_pymc((joined,), (logp,), mode=mode)

reparameterized_names = [
model.rvs_to_values[var].name
for var in model.free_RVs
if model.rvs_to_transforms.get(var) is not None
]

# Make function that computes remaining variables for the trace
remaining_rvs = [
var for var in model.unobserved_value_vars if var.name not in joined_names
Expand All @@ -777,27 +789,35 @@ def _make_functions(
names = set(var_names)
remaining_rvs = [var for var in remaining_rvs if var.name in names]

all_names = joined_names + remaining_rvs
all_names = []
all_slices = []
all_shapes = []
count_expanded = 0

all_names = joined_names.copy()
all_slices = joined_slices.copy()
all_shapes = joined_shapes.copy()
identity_free = []
for var_expr, name, shape in zip(variables, joined_names, joined_shapes):
length = prod(shape)
all_names.append(name)
all_shapes.append(shape)
all_slices.append(slice(count_expanded, count_expanded + length))
count_expanded += length
identity_free.append(var_expr)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these changes here still necessary? I don't quite get what this does.


for var in remaining_rvs:
all_names.append(var.name)
shape = cast(tuple[int, ...], shapes[var.name])
all_shapes.append(shape)
length = prod(shape)
all_slices.append(slice(count, count + length))
count += length
all_slices.append(slice(count_expanded, count_expanded + length))
count_expanded += length

num_expanded = count
num_expanded = count_expanded

if join_expanded:
allvars = [
pt.concatenate(
[
joined,
*[v.ravel() for v in identity_free],
*[
pt.as_tensor(var, allow_xtensor_conversion=True).ravel()
for var in remaining_rvs
Expand All @@ -806,7 +826,7 @@ def _make_functions(
)
]
else:
allvars = [*variables, *remaining_rvs]
allvars = [*identity_free, *remaining_rvs]
with model:
expand_fn_pt = compile_pymc(
(joined,),
Expand All @@ -822,6 +842,7 @@ def _make_functions(
expand_fn_pt,
initial_point_fn,
(all_names, all_slices, all_shapes),
reparameterized_names,
)


Expand Down
2 changes: 2 additions & 0 deletions python/nutpie/compiled_pyfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def from_pyfunc(
make_initial_point_fn: Callable[[SeedType], np.ndarray] | None = None,
make_transform_adapter=None,
raw_logp_fn=None,
reparameterized_names=None,
):
if coords is None:
coords = {}
Expand Down Expand Up @@ -150,4 +151,5 @@ def from_pyfunc(
_variables=variables,
_shared_data=shared_data,
_raw_logp_fn=raw_logp_fn,
reparameterized_names=reparameterized_names,
)
68 changes: 61 additions & 7 deletions python/nutpie/sample.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import warnings
from dataclasses import dataclass
from dataclasses import dataclass, field
from importlib.metadata import version
from typing import Any, Literal, Optional, cast, get_args, overload

import arviz
Expand All @@ -15,6 +16,7 @@
@dataclass(frozen=True)
class CompiledModel:
dims: Optional[dict[str, tuple[str, ...]]]
reparameterized_names: list[str] | None = field(default=None, kw_only=True)

@property
def n_dim(self) -> int:
Expand Down Expand Up @@ -56,9 +58,18 @@ def benchmark_logp(self, point, num_evals, cores):
return pd.concat(times)


def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs):
def _arrow_to_arviz(
draw_batches,
stat_batches,
skip_vars=None,
reparameterized_names=None,
keep_unconstrained_draw=False,
**kwargs,
):
if skip_vars is None:
skip_vars = []
if reparameterized_names is None:
reparameterized_names = []

n_chains = len(draw_batches)
assert n_chains == len(stat_batches)
Expand Down Expand Up @@ -98,11 +109,23 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs):
stats_posterior, max_posterior, stat_posterior, i, n_chains, dims, skip_vars
)

from importlib.metadata import version
uc_data_posterior = {
name: data_posterior.pop(name)
for name in reparameterized_names
if name in data_posterior
}
uc_data_tune = {
name: data_tune.pop(name) for name in reparameterized_names if name in data_tune
}

if not keep_unconstrained_draw:
stats_posterior.pop("unconstrained_draw", None)
stats_tune.pop("unconstrained_draw", None)
dims.pop("unconstrained_draw", None)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no such dim. Maybe you meant `unconstrained_parameter`? But that is also used for other things (mass matrix, transformed, grad, etc.).


arviz_version = version("arviz")
if tuple(map(int, arviz_version.split(".")[:2])) >= (1, 0):
return arviz.from_dict(
idata = arviz.from_dict(
{
"posterior": data_posterior,
"sample_stats": stats_posterior,
Expand All @@ -113,7 +136,7 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs):
**kwargs,
)
else:
return arviz.from_dict(
idata = arviz.from_dict(
**{
"posterior": data_posterior,
"sample_stats": stats_posterior,
Expand All @@ -124,6 +147,22 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs):
**kwargs,
)

if keep_unconstrained_draw and uc_data_posterior:
coords = kwargs.get("coords")
uc_dims = {name: dims.get(name, []) for name in uc_data_posterior}
groups = {
"unconstrained_posterior": arviz.dict_to_dataset(
uc_data_posterior, coords=coords, dims=uc_dims
)
}
if uc_data_tune:
groups["warmup_unconstrained_posterior"] = arviz.dict_to_dataset(
uc_data_tune, coords=coords, dims=uc_dims
)
idata.add_groups(groups)

return idata


def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_vars):
num_draws = batch.num_rows
Expand Down Expand Up @@ -463,11 +502,13 @@ def __init__(
progress_style=None,
progress_rate=100,
store=None,
store_unconstrained=False,
):
self._settings = settings
self._compiled_model = compiled_model
self._save_warmup = save_warmup
self._return_raw_trace = return_raw_trace
self._store_unconstrained = store_unconstrained

self._html = None

Expand Down Expand Up @@ -622,6 +663,8 @@ def _get_nested(settings, name, default):
draw_batches,
stat_batches,
skip_vars=skip_vars,
reparameterized_names=self._compiled_model.reparameterized_names,
keep_unconstrained_draw=self._store_unconstrained,
coords={
name: pd.Index(vals)
for name, vals in self._compiled_model.coords.items()
Expand Down Expand Up @@ -687,6 +730,7 @@ def sample(
progress_style: str | None = None,
progress_rate: int = 100,
zarr_store: _ZarrStoreType | None = None,
store_unconstrained: bool = False,
) -> xr.DataTree: ...


Expand All @@ -710,6 +754,7 @@ def sample(
progress_style: str | None = None,
progress_rate: int = 100,
zarr_store: _ZarrStoreType | None = None,
store_unconstrained: bool = False,
**kwargs,
) -> xr.DataTree: ...

Expand All @@ -734,6 +779,7 @@ def sample(
progress_style: str | None = None,
progress_rate: int = 100,
zarr_store: _ZarrStoreType | None = None,
store_unconstrained: bool = False,
**kwargs,
) -> _BackgroundSampler: ...

Expand Down Expand Up @@ -781,6 +827,7 @@ def sample(
progress_style: str | None = None,
progress_rate: int = 100,
zarr_store: _ZarrStoreType | None = None,
store_unconstrained: bool = False,
**kwargs,
) -> xr.DataTree | _BackgroundSampler:
"""Sample the posterior distribution for a compiled model.
Expand Down Expand Up @@ -817,8 +864,11 @@ def sample(
point on the transformed parameter space. Defaults to
zeros.
store_unconstrained: bool
If True, store each draw in the unconstrained (transformed)
space in the sample stats.
If True, store the unconstrained (transformed) draws in two forms:
a flat ``unconstrained_draw`` vector in ``sample_stats`` and a
per-variable ``unconstrained_posterior`` group (with
``warmup_unconstrained_posterior`` when ``save_warmup=True``) whose
dims are copied from the corresponding RV.
store_gradient: bool
If True, store the logp gradient of each draw in the unconstrained
space in the sample stats.
Expand Down Expand Up @@ -992,6 +1042,9 @@ def sample(

settings.update(updates)

if store_unconstrained:
settings.store_unconstrained = True

if cores is None:
try:
# Only available in python>=3.13
Expand Down Expand Up @@ -1019,6 +1072,7 @@ def sample(
progress_style=progress_style,
progress_rate=progress_rate,
store=zarr_store,
store_unconstrained=store_unconstrained,
)

if not blocking:
Expand Down
Loading
Loading