Skip to content
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 41 additions & 19 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def _compile_pymc_model_numba(
expand_fn_pt,
initial_point_fn,
shape_info,
unconstrained_info,
) = _make_functions(
model,
mode="NUMBA",
Expand Down Expand Up @@ -326,7 +327,7 @@ def _compile_pymc_model_numba(

expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)

dims, coords = _prepare_dims_and_coords(model, shape_info)
dims, coords = _prepare_dims_and_coords(model, shape_info, unconstrained_info)

return CompiledPyMCModel(
_n_dim=n_dim,
Expand All @@ -345,26 +346,29 @@ def _compile_pymc_model_numba(
)


def _prepare_dims_and_coords(model, shape_info):
def _prepare_dims_and_coords(model, shape_info, unconstrained_info):
coords = {}
for name, vals in model.coords.items():
if vals is None:
vals = pd.RangeIndex(int(model.dim_lengths[name].eval()))
coords[name] = pd.Index(vals)
idx = pd.Index(vals)
if idx.dtype == "object" or idx.dtype == "string":
coords[name] = idx.tolist()
else:
coords[name] = idx

if "unconstrained_parameter" in coords:
raise ValueError("Model contains invalid name 'unconstrained_parameter'.")

unconstrained_names, unconstrained_shapes = unconstrained_info
names = []
for base, _, shape in zip(*shape_info):
if base not in [var.name for var in model.value_vars]:
continue
for base, shape in zip(unconstrained_names, unconstrained_shapes):
for idx in itertools.product(*[range(length) for length in shape]):
if len(idx) == 0:
names.append(base)
else:
names.append(f"{base}_{'.'.join(str(i) for i in idx)}")
coords["unconstrained_parameter"] = pd.Index(names)
coords["unconstrained_parameter"] = names

dims = model.named_vars_to_dims
return dims, coords
Expand Down Expand Up @@ -399,6 +403,7 @@ def _compile_pymc_model_jax(
expand_fn_pt,
initial_point_fn,
shape_info,
unconstrained_info,
) = _make_functions(
model,
mode="JAX",
Expand Down Expand Up @@ -464,7 +469,7 @@ def expand(_x, **shared):

return expand

dims, coords = _prepare_dims_and_coords(model, shape_info)
dims, coords = _prepare_dims_and_coords(model, shape_info, unconstrained_info)

return from_pyfunc(
ndim=n_dim,
Expand Down Expand Up @@ -641,6 +646,7 @@ def _make_functions(
Callable,
Callable,
tuple[list[str], list[slice], list[tuple[int, ...]]],
tuple[list[str], list[tuple[int, ...]]],
]:
"""
Compile functions required by nuts-rs from a given PyMC model.
Expand Down Expand Up @@ -747,7 +753,7 @@ def _make_functions(
if use_split:
variables = pt.split(joined, splits, len(splits))
else:
variables = [joined[slice_val] for slice_val in zip(joined_slices)]
variables = [joined[slice_val] for slice_val in joined_slices]

replacements = {
model.rvs_to_values[var]: value.reshape(shape).astype(var.dtype)
Expand All @@ -767,6 +773,12 @@ def _make_functions(
with model:
logp_fn_pt = compile_pymc((joined,), (logp,), mode=mode)

transformed_value_names = set()
Comment thread
fonnesbeck marked this conversation as resolved.
Outdated
for var in model.free_RVs:
value_var = model.rvs_to_values[var]
if model.rvs_to_transforms.get(var) is not None:
transformed_value_names.add(value_var.name)

# Make function that computes remaining variables for the trace
remaining_rvs = [
var for var in model.unobserved_value_vars if var.name not in joined_names
Expand All @@ -776,27 +788,36 @@ def _make_functions(
names = set(var_names)
remaining_rvs = [var for var in remaining_rvs if var.name in names]

all_names = joined_names + remaining_rvs

all_names = joined_names.copy()
all_slices = joined_slices.copy()
all_shapes = joined_shapes.copy()
all_names = []
all_slices = []
all_shapes = []
count_expanded = 0

untransformed_free = []
Comment thread
fonnesbeck marked this conversation as resolved.
Outdated
for var_expr, name, shape in zip(variables, joined_names, joined_shapes):
if name not in transformed_value_names:
length = prod(shape)
all_names.append(name)
all_shapes.append(shape)
all_slices.append(slice(count_expanded, count_expanded + length))
count_expanded += length
untransformed_free.append(var_expr)

for var in remaining_rvs:
all_names.append(var.name)
shape = cast(tuple[int, ...], shapes[var.name])
all_shapes.append(shape)
length = prod(shape)
all_slices.append(slice(count, count + length))
count += length
all_slices.append(slice(count_expanded, count_expanded + length))
count_expanded += length

num_expanded = count
num_expanded = count_expanded

if join_expanded:
allvars = [
pt.concatenate(
[
joined,
*[v.ravel() for v in untransformed_free],
*[
pt.as_tensor(var, allow_xtensor_conversion=True).ravel()
for var in remaining_rvs
Expand All @@ -805,7 +826,7 @@ def _make_functions(
)
]
else:
allvars = [*variables, *remaining_rvs]
allvars = [*untransformed_free, *remaining_rvs]
with model:
expand_fn_pt = compile_pymc(
(joined,),
Expand All @@ -821,6 +842,7 @@ def _make_functions(
expand_fn_pt,
initial_point_fn,
(all_names, all_slices, all_shapes),
(joined_names, joined_shapes),
)


Expand Down
Loading