diff --git a/Makefile b/Makefile
index 04c6112d47..f09c7eb694 100644
--- a/Makefile
+++ b/Makefile
@@ -40,7 +40,8 @@ scrub: FORCE
 doctest: FORCE
 	# We skip testing pyro.distributions.torch wrapper classes because
 	# they include torch docstrings which are tested upstream.
-	python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro --ignore=pyro/distributions/torch.py
+	python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro --ignore=pyro/distributions/torch.py \
+		--ignore=pyro/contrib/named
 
 perf-test: FORCE
 	bash scripts/perf_test.sh ${ref}
diff --git a/pyro/contrib/named/__init__.py b/pyro/contrib/named/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pyro/contrib/named/infer/__init__.py b/pyro/contrib/named/infer/__init__.py
new file mode 100644
index 0000000000..77832766f4
--- /dev/null
+++ b/pyro/contrib/named/infer/__init__.py
@@ -0,0 +1,6 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+from pyro.contrib.named.infer.elbo import Trace_ELBO
+
+__all__ = ["Trace_ELBO"]
diff --git a/pyro/contrib/named/infer/elbo.py b/pyro/contrib/named/infer/elbo.py
new file mode 100644
index 0000000000..505946b896
--- /dev/null
+++ b/pyro/contrib/named/infer/elbo.py
@@ -0,0 +1,177 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Tuple
+
+import torch
+from functorch.dim import Dim
+from typing_extensions import ParamSpec
+
+import pyro
+from pyro import poutine
+from pyro.distributions.torch_distribution import TorchDistributionMixin
+from pyro.infer import ELBO as _OrigELBO
+from pyro.poutine.messenger import Messenger
+from pyro.poutine.runtime import Message
+
+_P = ParamSpec("_P")
+
+
+class ELBO(_OrigELBO):
+    """
+    Base class for the ELBO losses in this module.
+
+    Subclasses implement :meth:`differentiable_loss`; :meth:`loss` and
+    :meth:`loss_and_grads` are derived from it.
+    """
+
+    def _get_trace(self, *args, **kwargs):
+        # Traces come from get_importance_trace() below, never from here.
+        raise RuntimeError("shouldn't be here!")
+
+    def differentiable_loss(self, model, guide, *args, **kwargs):
+        """Return the loss as a differentiable tensor. Must be overridden."""
+        raise NotImplementedError("Must implement differentiable_loss")
+
+    def loss(self, model, guide, *args, **kwargs):
+        """Return the loss as a Python number, without a backward pass."""
+        return self.differentiable_loss(model, guide, *args, **kwargs).detach().item()
+
+    def loss_and_grads(self, model, guide, *args, **kwargs):
+        """Run ``backward()`` on the loss and return it as a Python number."""
+        loss = self.differentiable_loss(model, guide, *args, **kwargs)
+        loss.backward()
+        return loss.item()
+
+
+def track_provenance(x: torch.Tensor, provenance: Dim) -> torch.Tensor:
+    """
+    Add a leading axis of size 1 to ``x`` and bind it to ``provenance``.
+
+    The returned tensor carries ``provenance`` as a first-class dim.
+    """
+    return x.unsqueeze(0)[provenance]
+
+
+class track_nonreparam(Messenger):
+    """
+    Tag the value of every latent, non-reparameterized sample site with a
+    first-class dim named after the site (see :func:`track_provenance`).
+    """
+
+    def _pyro_post_sample(self, msg: Message) -> None:
+        if (
+            msg["type"] == "sample"
+            and isinstance(msg["fn"], TorchDistributionMixin)
+            and not msg["is_observed"]
+            and not msg["fn"].has_rsample
+        ):
+            # The dim is named after the site so that Trace_ELBO can map it
+            # back to the guide node via ``str(dim)``.
+            provenance = Dim(msg["name"])
+            msg["value"] = track_provenance(msg["value"], provenance)
+
+
+def get_importance_trace(
+    model: Callable[_P, Any],
+    guide: Callable[_P, Any],
+    *args: _P.args,
+    **kwargs: _P.kwargs
+) -> Tuple[poutine.Trace, poutine.Trace]:
+    """
+    Run the guide, replay the model against it, and score both traces.
+
+    Nodes that are not ``sample`` sites over a
+    :class:`TorchDistributionMixin` are removed from both traces. Every
+    remaining site stores its ``log_prob``; non-reparameterized guide sites
+    also store ``log_measure = log_prob - log_prob.detach()``.
+
+    :returns: ``(model_trace, guide_trace)``, in that order.
+    """
+    with track_nonreparam():
+        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
+    replay_model = poutine.replay(model, trace=guide_trace)
+    model_trace = poutine.trace(replay_model).get_trace(*args, **kwargs)
+
+    for is_guide, trace in zip((True, False), (guide_trace, model_trace)):
+        # Iterate over a copy: nodes are removed from the trace in the loop.
+        for site in list(trace.nodes.values()):
+            if site["type"] == "sample" and isinstance(
+                site["fn"], TorchDistributionMixin
+            ):
+                log_prob = site["fn"].log_prob(site["value"])
+                site["log_prob"] = log_prob
+
+                if is_guide and not site["fn"].has_rsample:
+                    # importance sampling weights
+                    site["log_measure"] = log_prob - log_prob.detach()
+            else:
+                trace.remove_node(site["name"])
+    return model_trace, guide_trace
+
+
+def _cost_term(cost, site):
+    """Pair ``cost`` with the scale, plate dims and remaining dims of ``site``."""
+    batch_dims = tuple(f.dim for f in site["cond_indep_stack"])
+    # Dims on the cost that are not accounted for by an enclosing plate.
+    deps = tuple(set(getattr(cost, "dims", ())) - set(batch_dims))
+    return cost, site["scale"], batch_dims, deps
+
+
+class Trace_ELBO(ELBO):
+    """
+    ELBO estimator built on first-class dims.
+
+    A cost term that carries the dim of a non-reparameterized guide site is
+    multiplied by ``exp(log_measure)`` of that site and averaged over the dim.
+    """
+
+    def differentiable_loss(
+        self,
+        model: Callable[_P, Any],
+        guide: Callable[_P, Any],
+        *args: _P.args,
+        **kwargs: _P.kwargs
+    ) -> torch.Tensor:
+        """
+        Return the negative ELBO estimate.
+
+        When ``num_particles > 1`` the model and the guide are both wrapped in
+        a ``"num_particles"`` plate bound to a fresh first-class dim.
+        """
+        if self.num_particles > 1:
+            vectorize = pyro.plate(
+                "num_particles", self.num_particles, dim=Dim("num_particles")
+            )
+            model = vectorize(model)
+            guide = vectorize(guide)
+
+        model_trace, guide_trace = get_importance_trace(model, guide, *args, **kwargs)
+
+        # logp terms first, then -logq terms
+        cost_terms = [
+            _cost_term(site["log_prob"], site) for site in model_trace.nodes.values()
+        ]
+        cost_terms += [
+            _cost_term(-site["log_prob"], site) for site in guide_trace.nodes.values()
+        ]
+
+        elbo = 0.0
+        for cost, scale, batch_dims, deps in cost_terms:
+            if deps:
+                dice_factor = 0.0
+                for key in deps:
+                    dice_factor += guide_trace.nodes[str(key)]["log_measure"]
+                # Reduce dims of the weight that the cost itself does not have.
+                dice_factor_dims = getattr(dice_factor, "dims", ())
+                cost_dims = getattr(cost, "dims", ())
+                sum_dims = tuple(set(dice_factor_dims) - set(cost_dims))
+                if sum_dims:
+                    dice_factor = dice_factor.sum(sum_dims)
+                cost = torch.exp(dice_factor) * cost
+                cost = cost.mean(deps)
+            if scale is not None:
+                cost = cost * scale
+            elbo += cost.sum(batch_dims) / self.num_particles
+
+        return -elbo
diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py
index ace02da72a..9097c10668 100644
--- a/pyro/distributions/torch_distribution.py
+++ b/pyro/distributions/torch_distribution.py
@@ -3,10 +3,11 @@
 
 import warnings
 from collections import OrderedDict
-from typing import Callable
+from typing import TYPE_CHECKING, Callable, Tuple
 
 import torch
 from torch.distributions.kl import kl_divergence, register_kl
+from typing_extensions import Self
 
 import pyro.distributions.torch
 
@@ -15,6 +16,9 @@
 from .score_parts import ScoreParts
 from .util import broadcast_shape, scale_and_mask
 
+if TYPE_CHECKING:
+    from functorch.dim import Dim
+
 
 class TorchDistributionMixin(Distribution, Callable):
     """
@@ -45,11 +49,68 @@ def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
           batched). The shape of the result should be `self.shape()`.
         :rtype: torch.Tensor
         """
-        return (
+        sample_shape = self.named_sample_shape + sample_shape
+        result = (
             self.rsample(sample_shape)
             if self.has_rsample
             else self.sample(sample_shape)
         )
+        bind_named_dims = self.named_shape[
+            len(self.named_shape) - len(self.named_sample_shape) :
+        ]
+        if bind_named_dims:
+            result = result[bind_named_dims]
+        return result
+
+    @property
+    def named_shape(self) -> Tuple["Dim", ...]:
+        """
+        First-class dims found on the distribution's parameters, in order of
+        first appearance, followed by any dims appended by
+        :meth:`expand_named_shape`. Computed on first access and cached.
+        """
+        if getattr(self, "_named_shape", None) is None:
+            result = []
+            # Membership is tested against a set rather than the list:
+            # `dim in result` on a list or a tuple fails with
+            # RuntimeError: vmap: It looks like you're attempting to use
+            # a Tensor in some data-dependent control flow. We don't support
+            # that yet, please shout over at
+            # https://github.com/pytorch/functorch/issues/257
+            seen = set()
+            for param in self.arg_constraints:
+                value = getattr(self, param)
+                for dim in getattr(value, "dims", ()):
+                    if dim not in seen:
+                        seen.add(dim)
+                        result.append(dim)
+            self._named_shape = tuple(result)
+        return self._named_shape
+
+    def expand_named_shape(self, named_shape: Tuple["Dim", ...]) -> Self:
+        """
+        Append every dim in ``named_shape`` that is not already part of
+        :attr:`named_shape`, and append its size to :attr:`named_sample_shape`.
+
+        NOTE(review): this mutates and returns ``self`` (no copy is made);
+        confirm that distribution objects are never shared between sites.
+        """
+        for dim in named_shape:
+            if dim not in set(self.named_shape):
+                self._named_shape += (dim,)
+                self.named_sample_shape = self.named_sample_shape + (dim.size,)
+        return self
+
+    @property
+    def named_sample_shape(self) -> torch.Size:
+        """Sizes appended by :meth:`expand_named_shape` (empty by default)."""
+        if getattr(self, "_named_sample_shape", None) is None:
+            self._named_sample_shape = torch.Size()
+        return self._named_sample_shape
+
+    @named_sample_shape.setter
+    def named_sample_shape(self, value: torch.Size) -> None:
+        self._named_sample_shape = value
 
     @property
     def batch_shape(self) -> torch.Size:
diff --git a/pyro/ops/indexing.py b/pyro/ops/indexing.py
index 2fc57aa9f8..d8f8155214 100644
--- a/pyro/ops/indexing.py
+++ b/pyro/ops/indexing.py
@@ -215,3 +215,13 @@ def __init__(self, tensor):
 
     def __getitem__(self, args):
         return vindex(self._tensor, args)
+
+
+def index_select(input, dim, index):
+    """
+    Return ``input.order(dim)[index]``.
+
+    NOTE(review): presumably ``dim`` is a first-class dim of ``input`` that
+    ``order`` turns into the leading positional axis -- confirm.
+    """
+    return input.order(dim)[index]
diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py
index 87e9dd2f7b..eedd1b1364 100644
--- a/pyro/poutine/broadcast_messenger.py
+++ b/pyro/poutine/broadcast_messenger.py
@@ -8,6 +8,8 @@
 from pyro.util import ignore_jit_warnings
 
 if TYPE_CHECKING:
+    from functorch.dim import Dim
+
     from pyro.poutine.runtime import Message
 
 
@@ -59,7 +61,11 @@ def _pyro_sample(msg: "Message") -> None:
         target_batch_shape = [
             None if size == 1 else size for size in actual_batch_shape
         ]
+        named_shape: List["Dim"] = []
         for f in msg["cond_indep_stack"]:
+            if hasattr(f.dim, "is_bound"):
+                named_shape.append(f.dim)
+                continue
             if f.dim is None or f.size == -1:
                 continue
             assert f.dim < 0
@@ -88,6 +94,10 @@ def _pyro_sample(msg: "Message") -> None:
                 target_batch_shape[i] = (
                     actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1
                 )
-        msg["fn"] = dist.expand(target_batch_shape)
+        if named_shape:
+            assert len(target_batch_shape) == 0
+            msg["fn"] = dist.expand_named_shape(tuple(named_shape))
+        else:
+            msg["fn"] = dist.expand(target_batch_shape)
         if msg["fn"].has_rsample != dist.has_rsample:
             msg["fn"].has_rsample = dist.has_rsample  # copy custom attribute
diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py
index 69d41756f6..2ae2863501 100644
--- a/pyro/poutine/indep_messenger.py
+++ b/pyro/poutine/indep_messenger.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import numbers
-from typing import Iterator, NamedTuple, Optional, Tuple
+from typing import TYPE_CHECKING, Iterator, NamedTuple, Optional, Tuple, Union
 
 import torch
 from typing_extensions import Self
@@ -11,10 +11,13 @@
 from pyro.poutine.runtime import _DIM_ALLOCATOR, Message
 from pyro.util import ignore_jit_warnings
 
+if TYPE_CHECKING:
+    from functorch.dim import Dim
+
 
 class CondIndepStackFrame(NamedTuple):
     name: str
-    dim: Optional[int]
+    dim: Optional[Union[int, "Dim"]]
     size: int
     counter: int
     full_size: Optional[int] = None
@@ -23,7 +26,7 @@ class CondIndepStackFrame(NamedTuple):
     def vectorized(self) -> bool:
         return self.dim is not None
 
-    def _key(self) -> Tuple[str, Optional[int], int, int]:
+    def _key(self) -> Tuple[str, Optional[Union[int, "Dim"]], int, int]:
         size = self.size
         with ignore_jit_warnings(["Converting a tensor to a Python number"]):
             if isinstance(size, torch.Tensor):  # type: ignore[unreachable]
@@ -69,7 +72,7 @@ def __init__(
         self,
         name: str,
         size: int,
-        dim: Optional[int] = None,
+        dim: Optional[Union[int, "Dim"]] = None,
         device: Optional[str] = None,
     ) -> None:
         if not torch._C._get_tracing_state() and size == 0:
@@ -97,13 +100,13 @@ def __enter__(self) -> Self:
         if self._vectorized is not False:
             self._vectorized = True
 
-        if self._vectorized is True:
+        if self._vectorized is True and not hasattr(self.dim, "is_bound"):
             self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim)
 
         return super().__enter__()
 
     def __exit__(self, *args) -> None:
-        if self._vectorized is True:
+        if self._vectorized is True and not hasattr(self.dim, "is_bound"):
             assert self.dim is not None
             _DIM_ALLOCATOR.free(self.name, self.dim)
         return super().__exit__(*args)
@@ -124,7 +127,7 @@ def __iter__(self) -> Iterator[int]:
                     yield i if isinstance(i, numbers.Number) else i.item()
 
     def _reset(self) -> None:
-        if self._vectorized:
+        if self._vectorized and not hasattr(self.dim, "is_bound"):
             assert self.dim is not None
             _DIM_ALLOCATOR.free(self.name, self.dim)
         self._vectorized = None
@@ -134,6 +137,8 @@ def _reset(self) -> None:
     def indices(self) -> torch.Tensor:
         if self._indices is None:
             self._indices = torch.arange(self.size, dtype=torch.long).to(self.device)
+        if hasattr(self.dim, "is_bound"):
+            return self._indices[self.dim]
         return self._indices
 
     def _process_message(self, msg: Message) -> None:
diff --git a/tests/contrib/named/infer/test_gradient.py b/tests/contrib/named/infer/test_gradient.py
new file mode 100644
index 0000000000..69ea59353f
--- /dev/null
+++ b/tests/contrib/named/infer/test_gradient.py
@@ -0,0 +1,67 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+from functorch.dim import dims
+
+import pyro
+import pyro.distributions as dist
+from pyro.contrib.named.infer import Trace_ELBO
+from pyro.distributions.testing import fakes
+from pyro.infer import SVI
+from pyro.optim import Adam
+from tests.common import assert_equal
+
+
+@pytest.mark.parametrize(
+    "reparameterized", [True, False], ids=["reparam", "nonreparam"]
+)
+def test_plate_elbo_vectorized_particles(reparameterized):
+    """
+    Gradients of the named-dim ``Trace_ELBO`` w.r.t. the guide parameters
+    match the expected values, with and without reparameterization.
+    """
+    pyro.enable_validation(False)
+    pyro.clear_param_store()
+    data = torch.tensor([-0.5, 2.0])
+    num_particles = 200000
+    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
+    i = dims()
+
+    def model():
+        data_plate = pyro.plate("data", len(data), dim=i)
+
+        pyro.sample("nuisance_a", Normal(0, 1))
+        with data_plate:
+            z = pyro.sample("z", Normal(0, 1))
+        pyro.sample("nuisance_b", Normal(2, 3))
+        with data_plate as idx:
+            pyro.sample("x", Normal(z, torch.ones(len(data))[idx]), obs=data[idx])
+        pyro.sample("nuisance_c", Normal(4, 5))
+
+    def guide():
+        loc = pyro.param("loc", torch.zeros(len(data)))
+        scale = pyro.param("scale", torch.ones(len(data)))
+
+        pyro.sample("nuisance_c", Normal(4, 5))
+        with pyro.plate("data", len(data), dim=i) as idx:
+            pyro.sample("z", Normal(loc[idx], scale[idx]))
+        pyro.sample("nuisance_b", Normal(2, 3))
+        pyro.sample("nuisance_a", Normal(0, 1))
+
+    optim = Adam({"lr": 0.1})
+    loss = Trace_ELBO(
+        num_particles=num_particles,
+        vectorize_particles=True,
+    )
+    inference = SVI(model, guide, optim, loss=loss)
+    inference.loss_and_grads(model, guide)
+    params = dict(pyro.get_param_store().named_parameters())
+    actual_grads = {name: param.grad.detach() for name, param in params.items()}
+
+    expected_grads = {
+        "loc": torch.tensor([0.5, -2.0]),
+        "scale": torch.tensor([1.0, 1.0]),
+    }
+    assert_equal(actual_grads, expected_grads, prec=0.06)