Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions pyro/infer/autoguide/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,85 @@ def median(self, *args, **kwargs):
"""
raise NotImplementedError

@torch.no_grad()
def _predict(self, _guide_samples, *args, **kwargs):
model = poutine.condition(self.model, _guide_samples)
model = poutine.mask(model, False) # disables unnecessary computation
trace = poutine.trace(model).get_trace(*args, **kwargs)
samples = {
name: site["value"]
for name, site in trace.nodes.items()
if site["type"] == "sample"
if not site_is_subsample(site)
}
return samples

def predict(self, *args, **kwargs):
"""
Draws a single posterior latent sample and replays the model against
latent samples.

:returns: A dict mapping sample site name to sampled value. This includes
latent variables, ``pyro.deterministic`` sites, and observations.
:rtype dict"
"""
data = self(*args, **kwargs)
return self._predict(data, *args, **kwargs)

def predict_sample(self, sample_shape, *args, **kwargs):
"""
Draws a batch of posterior latent samples and replays the model against
latent samples.

This may conflict with enumeration for mode-side enumerated models.

:returns: A dict mapping sample site name to sampled value. This includes
latent variables, ``pyro.deterministic`` sites, and observations.
:rtype dict"
"""
if not isinstance(sample_shape, tuple):
sample_shape = (sample_shape,)
if self.prototype_trace is None:
self._setup_prototype(*args, **kwargs)
dim = -1
if self._prototype_frames:
dim += min(f.dim for f in self._prototype_frames.values())
with pyro.plate_stack("particles", sample_shape, dim):
return self.predict(*args, **kwargs)

def predict_median(self, *args, **kwargs):
"""
Computes posterior median of latent values and replays model against
that median.

.. warning:: downstream deterministic sites computed from latent
medians may not themselves be medians, e.g. ``abs(median(z)) !=
median(abs(z))`` in general.

:returns: A dict mapping sample site name to value. This includes
latent variables, ``pyro.deterministic`` sites, and observations.
:rtype dict"
"""
data = self.median(*args, **kwargs)
return self._predict(data, *args, **kwargs)

def predict_quantiles(self, quantiles, *args, **kwargs):
"""
Computes posterior median of latent values and replays model against
that median.

.. warning:: downstream deterministic sites computed from latent
quantiles may not themselves be quantiles, e.g. ``abs(quantiles(z,
q)) != quantiles(abs(z), q)`` in general.

:returns: A dict mapping sample site name to value. This includes
latent variables, ``pyro.deterministic`` sites, and observations.
:rtype dict"
"""
# FIXME this does not correctly align the batch dimension
data = self.quantiles(quantiles, *args, **kwargs)
return self._predict(data, *args, **kwargs)


class AutoGuideList(AutoGuide, nn.ModuleList):
"""
Expand Down