-
Notifications
You must be signed in to change notification settings - Fork 24
Exclude transformed variables from trace #298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8aada92
166ca16
a8dd5ca
5c18907
9648aed
53cf74d
b1c610b
051ef1b
b514ef3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,4 @@ | ||
| import dataclasses | ||
| import itertools | ||
| import threading | ||
| import warnings | ||
| from collections.abc import Iterable | ||
|
|
@@ -273,6 +272,7 @@ def _compile_pymc_model_numba( | |
| expand_fn_pt, | ||
| initial_point_fn, | ||
| shape_info, | ||
| reparameterized_names, | ||
| ) = _make_functions( | ||
| model, | ||
| mode="NUMBA", | ||
|
|
@@ -326,7 +326,7 @@ def _compile_pymc_model_numba( | |
|
|
||
| expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw) | ||
|
|
||
| dims, coords = _prepare_dims_and_coords(model, shape_info) | ||
| dims, coords = _prepare_dims_and_coords(model, shape_info, reparameterized_names) | ||
|
|
||
| return CompiledPyMCModel( | ||
| _n_dim=n_dim, | ||
|
|
@@ -342,31 +342,34 @@ def _compile_pymc_model_numba( | |
| shape_info=shape_info, | ||
| logp_func=logp_fn_pt, | ||
| expand_func=expand_fn_pt, | ||
| reparameterized_names=reparameterized_names, | ||
| ) | ||
|
|
||
|
|
||
| def _prepare_dims_and_coords(model, shape_info): | ||
| def _prepare_dims_and_coords(model, shape_info, reparameterized_names): | ||
| coords = {} | ||
| for name, vals in model.coords.items(): | ||
| if vals is None: | ||
| vals = pd.RangeIndex(int(model.dim_lengths[name].eval())) | ||
| coords[name] = pd.Index(vals) | ||
|
|
||
| if "unconstrained_parameter" in coords: | ||
| raise ValueError("Model contains invalid name 'unconstrained_parameter'.") | ||
|
|
||
| names = [] | ||
| for base, _, shape in zip(*shape_info): | ||
| if base not in [var.name for var in model.value_vars]: | ||
| idx = pd.Index(vals) | ||
| if idx.dtype == "object" or idx.dtype == "string": | ||
| coords[name] = idx.tolist() | ||
| else: | ||
| coords[name] = idx | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we still want the |
||
|
|
||
| names, _, shape_list = shape_info | ||
| shape_by_name = {n: tuple(s) for n, s in zip(names, shape_list)} | ||
| value_to_rv = {model.rvs_to_values[var].name: var.name for var in model.free_RVs} | ||
|
|
||
| dims = dict(model.named_vars_to_dims) | ||
| for value_name in reparameterized_names: | ||
| rv_name = value_to_rv[value_name] | ||
| rv_dims = dims.get(rv_name) | ||
| if rv_dims is None: | ||
| continue | ||
| for idx in itertools.product(*[range(length) for length in shape]): | ||
| if len(idx) == 0: | ||
| names.append(base) | ||
| else: | ||
| names.append(f"{base}_{'.'.join(str(i) for i in idx)}") | ||
| coords["unconstrained_parameter"] = pd.Index(names) | ||
|
|
||
| dims = model.named_vars_to_dims | ||
| if shape_by_name.get(rv_name) == shape_by_name.get(value_name): | ||
| dims[value_name] = rv_dims | ||
|
|
||
| return dims, coords | ||
|
|
||
|
|
||
|
|
@@ -399,6 +402,7 @@ def _compile_pymc_model_jax( | |
| expand_fn_pt, | ||
| initial_point_fn, | ||
| shape_info, | ||
| reparameterized_names, | ||
| ) = _make_functions( | ||
| model, | ||
| mode="JAX", | ||
|
|
@@ -464,7 +468,7 @@ def expand(_x, **shared): | |
|
|
||
| return expand | ||
|
|
||
| dims, coords = _prepare_dims_and_coords(model, shape_info) | ||
| dims, coords = _prepare_dims_and_coords(model, shape_info, reparameterized_names) | ||
|
|
||
| return from_pyfunc( | ||
| ndim=n_dim, | ||
|
|
@@ -478,6 +482,7 @@ def expand(_x, **shared): | |
| dims=dims, | ||
| coords=coords, | ||
| raw_logp_fn=orig_logp_fn, | ||
| reparameterized_names=reparameterized_names, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -641,6 +646,7 @@ def _make_functions( | |
| Callable, | ||
| Callable, | ||
| tuple[list[str], list[slice], list[tuple[int, ...]]], | ||
| list[str], | ||
| ]: | ||
| """ | ||
| Compile functions required by nuts-rs from a given PyMC model. | ||
|
|
@@ -748,7 +754,7 @@ def _make_functions( | |
| if use_split: | ||
| variables = pt.split(joined, splits, len(splits)) | ||
| else: | ||
| variables = [joined[slice_val] for slice_val in zip(joined_slices)] | ||
| variables = [joined[slice_val] for slice_val in joined_slices] | ||
|
|
||
| replacements = { | ||
| model.rvs_to_values[var]: value.reshape(shape).astype(var.dtype) | ||
|
|
@@ -768,6 +774,12 @@ def _make_functions( | |
| with model: | ||
| logp_fn_pt = compile_pymc((joined,), (logp,), mode=mode) | ||
|
|
||
| reparameterized_names = [ | ||
| model.rvs_to_values[var].name | ||
| for var in model.free_RVs | ||
| if model.rvs_to_transforms.get(var) is not None | ||
| ] | ||
|
|
||
| # Make function that computes remaining variables for the trace | ||
| remaining_rvs = [ | ||
| var for var in model.unobserved_value_vars if var.name not in joined_names | ||
|
|
@@ -777,27 +789,35 @@ def _make_functions( | |
| names = set(var_names) | ||
| remaining_rvs = [var for var in remaining_rvs if var.name in names] | ||
|
|
||
| all_names = joined_names + remaining_rvs | ||
| all_names = [] | ||
| all_slices = [] | ||
| all_shapes = [] | ||
| count_expanded = 0 | ||
|
|
||
| all_names = joined_names.copy() | ||
| all_slices = joined_slices.copy() | ||
| all_shapes = joined_shapes.copy() | ||
| identity_free = [] | ||
| for var_expr, name, shape in zip(variables, joined_names, joined_shapes): | ||
| length = prod(shape) | ||
| all_names.append(name) | ||
| all_shapes.append(shape) | ||
| all_slices.append(slice(count_expanded, count_expanded + length)) | ||
| count_expanded += length | ||
| identity_free.append(var_expr) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are these changes here still necessary? I don't quite get what this does. |
||
|
|
||
| for var in remaining_rvs: | ||
| all_names.append(var.name) | ||
| shape = cast(tuple[int, ...], shapes[var.name]) | ||
| all_shapes.append(shape) | ||
| length = prod(shape) | ||
| all_slices.append(slice(count, count + length)) | ||
| count += length | ||
| all_slices.append(slice(count_expanded, count_expanded + length)) | ||
| count_expanded += length | ||
|
|
||
| num_expanded = count | ||
| num_expanded = count_expanded | ||
|
|
||
| if join_expanded: | ||
| allvars = [ | ||
| pt.concatenate( | ||
| [ | ||
| joined, | ||
| *[v.ravel() for v in identity_free], | ||
| *[ | ||
| pt.as_tensor(var, allow_xtensor_conversion=True).ravel() | ||
| for var in remaining_rvs | ||
|
|
@@ -806,7 +826,7 @@ def _make_functions( | |
| ) | ||
| ] | ||
| else: | ||
| allvars = [*variables, *remaining_rvs] | ||
| allvars = [*identity_free, *remaining_rvs] | ||
| with model: | ||
| expand_fn_pt = compile_pymc( | ||
| (joined,), | ||
|
|
@@ -822,6 +842,7 @@ def _make_functions( | |
| expand_fn_pt, | ||
| initial_point_fn, | ||
| (all_names, all_slices, all_shapes), | ||
| reparameterized_names, | ||
| ) | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import os | ||
| import warnings | ||
| from dataclasses import dataclass | ||
| from dataclasses import dataclass, field | ||
| from importlib.metadata import version | ||
| from typing import Any, Literal, Optional, cast, get_args, overload | ||
|
|
||
| import arviz | ||
|
|
@@ -15,6 +16,7 @@ | |
| @dataclass(frozen=True) | ||
| class CompiledModel: | ||
| dims: Optional[dict[str, tuple[str, ...]]] | ||
| reparameterized_names: list[str] | None = field(default=None, kw_only=True) | ||
|
|
||
| @property | ||
| def n_dim(self) -> int: | ||
|
|
@@ -56,9 +58,18 @@ def benchmark_logp(self, point, num_evals, cores): | |
| return pd.concat(times) | ||
|
|
||
|
|
||
| def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): | ||
| def _arrow_to_arviz( | ||
| draw_batches, | ||
| stat_batches, | ||
| skip_vars=None, | ||
| reparameterized_names=None, | ||
| keep_unconstrained_draw=False, | ||
| **kwargs, | ||
| ): | ||
| if skip_vars is None: | ||
| skip_vars = [] | ||
| if reparameterized_names is None: | ||
| reparameterized_names = [] | ||
|
|
||
| n_chains = len(draw_batches) | ||
| assert n_chains == len(stat_batches) | ||
|
|
@@ -98,11 +109,23 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): | |
| stats_posterior, max_posterior, stat_posterior, i, n_chains, dims, skip_vars | ||
| ) | ||
|
|
||
| from importlib.metadata import version | ||
| uc_data_posterior = { | ||
| name: data_posterior.pop(name) | ||
| for name in reparameterized_names | ||
| if name in data_posterior | ||
| } | ||
| uc_data_tune = { | ||
| name: data_tune.pop(name) for name in reparameterized_names if name in data_tune | ||
| } | ||
|
|
||
| if not keep_unconstrained_draw: | ||
| stats_posterior.pop("unconstrained_draw", None) | ||
| stats_tune.pop("unconstrained_draw", None) | ||
| dims.pop("unconstrained_draw", None) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is no such dim. Maybe you meant |
||
|
|
||
| arviz_version = version("arviz") | ||
| if tuple(map(int, arviz_version.split(".")[:2])) >= (1, 0): | ||
| return arviz.from_dict( | ||
| idata = arviz.from_dict( | ||
| { | ||
| "posterior": data_posterior, | ||
| "sample_stats": stats_posterior, | ||
|
|
@@ -113,7 +136,7 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): | |
| **kwargs, | ||
| ) | ||
| else: | ||
| return arviz.from_dict( | ||
| idata = arviz.from_dict( | ||
| **{ | ||
| "posterior": data_posterior, | ||
| "sample_stats": stats_posterior, | ||
|
|
@@ -124,6 +147,22 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): | |
| **kwargs, | ||
| ) | ||
|
|
||
| if keep_unconstrained_draw and uc_data_posterior: | ||
| coords = kwargs.get("coords") | ||
| uc_dims = {name: dims.get(name, []) for name in uc_data_posterior} | ||
| groups = { | ||
| "unconstrained_posterior": arviz.dict_to_dataset( | ||
| uc_data_posterior, coords=coords, dims=uc_dims | ||
| ) | ||
| } | ||
| if uc_data_tune: | ||
| groups["warmup_unconstrained_posterior"] = arviz.dict_to_dataset( | ||
| uc_data_tune, coords=coords, dims=uc_dims | ||
| ) | ||
| idata.add_groups(groups) | ||
|
|
||
| return idata | ||
|
|
||
|
|
||
| def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_vars): | ||
| num_draws = batch.num_rows | ||
|
|
@@ -463,11 +502,13 @@ def __init__( | |
| progress_style=None, | ||
| progress_rate=100, | ||
| store=None, | ||
| store_unconstrained=False, | ||
| ): | ||
| self._settings = settings | ||
| self._compiled_model = compiled_model | ||
| self._save_warmup = save_warmup | ||
| self._return_raw_trace = return_raw_trace | ||
| self._store_unconstrained = store_unconstrained | ||
|
|
||
| self._html = None | ||
|
|
||
|
|
@@ -622,6 +663,8 @@ def _get_nested(settings, name, default): | |
| draw_batches, | ||
| stat_batches, | ||
| skip_vars=skip_vars, | ||
| reparameterized_names=self._compiled_model.reparameterized_names, | ||
| keep_unconstrained_draw=self._store_unconstrained, | ||
| coords={ | ||
| name: pd.Index(vals) | ||
| for name, vals in self._compiled_model.coords.items() | ||
|
|
@@ -687,6 +730,7 @@ def sample( | |
| progress_style: str | None = None, | ||
| progress_rate: int = 100, | ||
| zarr_store: _ZarrStoreType | None = None, | ||
| store_unconstrained: bool = False, | ||
| ) -> xr.DataTree: ... | ||
|
|
||
|
|
||
|
|
@@ -710,6 +754,7 @@ def sample( | |
| progress_style: str | None = None, | ||
| progress_rate: int = 100, | ||
| zarr_store: _ZarrStoreType | None = None, | ||
| store_unconstrained: bool = False, | ||
| **kwargs, | ||
| ) -> xr.DataTree: ... | ||
|
|
||
|
|
@@ -734,6 +779,7 @@ def sample( | |
| progress_style: str | None = None, | ||
| progress_rate: int = 100, | ||
| zarr_store: _ZarrStoreType | None = None, | ||
| store_unconstrained: bool = False, | ||
| **kwargs, | ||
| ) -> _BackgroundSampler: ... | ||
|
|
||
|
|
@@ -781,6 +827,7 @@ def sample( | |
| progress_style: str | None = None, | ||
| progress_rate: int = 100, | ||
| zarr_store: _ZarrStoreType | None = None, | ||
| store_unconstrained: bool = False, | ||
| **kwargs, | ||
| ) -> xr.DataTree | _BackgroundSampler: | ||
| """Sample the posterior distribution for a compiled model. | ||
|
|
@@ -817,8 +864,11 @@ def sample( | |
| point on the transformed parameter space. Defaults to | ||
| zeros. | ||
| store_unconstrained: bool | ||
| If True, store each draw in the unconstrained (transformed) | ||
| space in the sample stats. | ||
| If True, store the unconstrained (transformed) draws in two forms: | ||
| a flat ``unconstrained_draw`` vector in ``sample_stats`` and a | ||
| per-variable ``unconstrained_posterior`` group (with | ||
| ``warmup_unconstrained_posterior`` when ``save_warmup=True``) whose | ||
| dims are copied from the corresponding RV. | ||
| store_gradient: bool | ||
| If True, store the logp gradient of each draw in the unconstrained | ||
| space in the sample stats. | ||
|
|
@@ -992,6 +1042,9 @@ def sample( | |
|
|
||
| settings.update(updates) | ||
|
|
||
| if store_unconstrained: | ||
| settings.store_unconstrained = True | ||
|
|
||
| if cores is None: | ||
| try: | ||
| # Only available in python>=3.13 | ||
|
|
@@ -1019,6 +1072,7 @@ def sample( | |
| progress_style=progress_style, | ||
| progress_rate=progress_rate, | ||
| store=zarr_store, | ||
| store_unconstrained=store_unconstrained, | ||
| ) | ||
|
|
||
| if not blocking: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did this fix a problem? Why not just keep it an Index?