diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index cdde09f65d..760511fd25 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -193,6 +193,85 @@ def median(self, *args, **kwargs): """ raise NotImplementedError + @torch.no_grad() + def _predict(self, _guide_samples, *args, **kwargs): + model = poutine.condition(self.model, _guide_samples) + model = poutine.mask(model, False) # disables unnecessary computation + trace = poutine.trace(model).get_trace(*args, **kwargs) + samples = { + name: site["value"] + for name, site in trace.nodes.items() + if site["type"] == "sample" + if not site_is_subsample(site) + } + return samples + + def predict(self, *args, **kwargs): + """ + Draws a single posterior latent sample and replays the model against + latent samples. + + :returns: A dict mapping sample site name to sampled value. This includes + latent variables, ``pyro.deterministic`` sites, and observations. + :rtype dict" + """ + data = self(*args, **kwargs) + return self._predict(data, *args, **kwargs) + + def predict_sample(self, sample_shape, *args, **kwargs): + """ + Draws a batch of posterior latent samples and replays the model against + latent samples. + + This may conflict with enumeration for mode-side enumerated models. + + :returns: A dict mapping sample site name to sampled value. This includes + latent variables, ``pyro.deterministic`` sites, and observations. + :rtype dict" + """ + if not isinstance(sample_shape, tuple): + sample_shape = (sample_shape,) + if self.prototype_trace is None: + self._setup_prototype(*args, **kwargs) + dim = -1 + if self._prototype_frames: + dim += min(f.dim for f in self._prototype_frames.values()) + with pyro.plate_stack("particles", sample_shape, dim): + return self.predict(*args, **kwargs) + + def predict_median(self, *args, **kwargs): + """ + Computes posterior median of latent values and replays model against + that median. + + .. warning:: downstream deterministic sites computed from latent + medians may not themselves be medians, e.g. ``abs(median(z)) != + median(abs(z))`` in general. + + :returns: A dict mapping sample site name to value. This includes + latent variables, ``pyro.deterministic`` sites, and observations. + :rtype dict" + """ + data = self.median(*args, **kwargs) + return self._predict(data, *args, **kwargs) + + def predict_quantiles(self, quantiles, *args, **kwargs): + """ + Computes posterior median of latent values and replays model against + that median. + + .. warning:: downstream deterministic sites computed from latent + quantiles may not themselves be quantiles, e.g. ``abs(quantiles(z, + q)) != quantiles(abs(z), q)`` in general. + + :returns: A dict mapping sample site name to value. This includes + latent variables, ``pyro.deterministic`` sites, and observations. + :rtype dict" + """ + # FIXME this does not correctly align the batch dimension + data = self.quantiles(quantiles, *args, **kwargs) + return self._predict(data, *args, **kwargs) + class AutoGuideList(AutoGuide, nn.ModuleList): """