diff --git a/docs/requirements.txt b/docs/requirements.txt index da561a25c1..ce14b39d4e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,4 +6,4 @@ observations>=0.1.4 opt_einsum>=2.3.2 pyro-api>=0.1.1 tqdm>=4.36 -funsor[torch] +funsor[torch] @ git+https://github.com/pyro-ppl/funsor.git@sampled-funsor diff --git a/pyro/contrib/funsor/handlers/__init__.py b/pyro/contrib/funsor/handlers/__init__.py index 80d01740f9..d3ac3bf471 100644 --- a/pyro/contrib/funsor/handlers/__init__.py +++ b/pyro/contrib/funsor/handlers/__init__.py @@ -15,7 +15,7 @@ ) from pyro.poutine.handlers import _make_handler -from .enum_messenger import EnumMessenger, queue # noqa: F401 +from .enum_messenger import EnumMessenger, ProvenanceMessenger, queue # noqa: F401 from .named_messenger import MarkovMessenger, NamedMessenger from .plate_messenger import PlateMessenger, VectorizedMarkovMessenger from .replay_messenger import ReplayMessenger @@ -26,6 +26,7 @@ MarkovMessenger, NamedMessenger, PlateMessenger, + ProvenanceMessenger, ReplayMessenger, TraceMessenger, VectorizedMarkovMessenger, diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index 3815f1934e..ca7e9fdeeb 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -18,7 +18,9 @@ from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor from pyro.contrib.funsor.handlers.replay_messenger import ReplayMessenger from pyro.contrib.funsor.handlers.trace_messenger import TraceMessenger +from pyro.ops.provenance import detach_provenance, extract_provenance from pyro.poutine.escape_messenger import EscapeMessenger +from pyro.poutine.reentrant_messenger import ReentrantMessenger from pyro.poutine.subsample_messenger import _Subsample funsor.set_backend("torch") @@ -58,6 +60,13 @@ def _get_support_value_tensor(funsor_dist, name, **kwargs): ) +@_get_support_value.register(funsor.Provenance) +def _get_support_value_sampled(funsor_dist, name, **kwargs): + assert name in funsor_dist.inputs + value = _get_support_value(funsor_dist.term, name, **kwargs) + return funsor.Provenance(value, funsor_dist.provenance) + + @_get_support_value.register(funsor.distribution.Distribution) def _get_support_value_distribution(funsor_dist, name, expand=False): assert name == funsor_dist.value.name @@ -179,6 +188,49 @@ def enumerate_site(dist, msg): raise ValueError("{} not valid enum strategy".format(msg)) +@extract_provenance.register(funsor.Provenance) +def _extract_provenance_funsor(x): + return x.term, x.provenance + + +class ProvenanceMessenger(ReentrantMessenger): + """ + Adds provenance information for all sample sites that are not enumerated. + """ + + def _pyro_sample(self, msg): + if ( + msg["done"] + or msg["is_observed"] + or msg["infer"].get("enumerate") == "parallel" + or isinstance(msg["fn"], _Subsample) + ): + return + + if "funsor" not in msg: + msg["funsor"] = {} + + with funsor.terms.lazy: + unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)( + value=msg["name"] + ) + # TODO delegate to enumerate_site + log_measure = _enum_strategy_default(unsampled_log_measure, msg) + msg["funsor"]["log_measure"] = detach_provenance(log_measure) + support_value = _get_support_value( + log_measure, + msg["name"], + expand=msg["infer"].get("expand", False), + ) + # TODO delegate to _get_support_value + msg["funsor"]["value"] = funsor.Provenance( + support_value, + frozenset([(msg["name"], detach_provenance(support_value))]), + ) + msg["value"] = to_data(msg["funsor"]["value"]) + msg["done"] = True + + class EnumMessenger(NamedMessenger): """ This version of :class:`~EnumMessenger` uses :func:`~pyro.contrib.funsor.to_data` @@ -200,9 +252,10 @@ def _pyro_sample(self, msg): unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)( value=msg["name"] ) - msg["funsor"]["log_measure"] = enumerate_site(unsampled_log_measure, msg) + log_measure = enumerate_site(unsampled_log_measure, msg) + msg["funsor"]["log_measure"] = detach_provenance(log_measure) msg["funsor"]["value"] = _get_support_value( - msg["funsor"]["log_measure"], + log_measure, msg["name"], expand=msg["infer"].get("expand", False), ) diff --git a/pyro/contrib/funsor/handlers/named_messenger.py b/pyro/contrib/funsor/handlers/named_messenger.py index ff65a18223..0394584aa1 100644 --- a/pyro/contrib/funsor/handlers/named_messenger.py +++ b/pyro/contrib/funsor/handlers/named_messenger.py @@ -4,6 +4,8 @@ from collections import OrderedDict from contextlib import ExitStack +import funsor + from pyro.contrib.funsor.handlers.runtime import ( _DIM_STACK, DimRequest, @@ -64,7 +66,10 @@ def _pyro_to_data(msg): name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict()) dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL) - batch_names = tuple(funsor_value.inputs.keys()) + if isinstance(funsor_value, funsor.Provenance): + batch_names = tuple(funsor_value.term.inputs.keys()) + else: + batch_names = tuple(funsor_value.inputs.keys()) # interpret all names/dims as requests since we only run this function once name_to_dim_request = name_to_dim.copy() diff --git a/pyro/contrib/funsor/infer/discrete.py b/pyro/contrib/funsor/infer/discrete.py index 3518882c16..c138a6613f 100644 --- a/pyro/contrib/funsor/infer/discrete.py +++ b/pyro/contrib/funsor/infer/discrete.py @@ -36,7 +36,7 @@ def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs): log_prob = funsor.sum_product.sum_product( sum_op, prod_op, - terms["log_factors"] + terms["log_measures"], + terms["log_factors"] + list(terms["log_measures"].values()), eliminate=terms["measure_vars"] | terms["plate_vars"], plates=terms["plate_vars"], ) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index e91787732f..01b6f71490 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -4,10 +4,10 @@ import contextlib import funsor +from funsor.sum_product import _partition from pyro.contrib.funsor import to_data, to_funsor -from pyro.contrib.funsor.handlers import enum, plate, replay, trace -from pyro.contrib.funsor.infer import config_enumerate +from pyro.contrib.funsor.handlers import enum, plate, provenance, replay, trace from pyro.distributions.util import copy_docs_from from pyro.infer import Trace_ELBO as _OrigTrace_ELBO @@ -18,32 +18,107 @@ @copy_docs_from(_OrigTrace_ELBO) class Trace_ELBO(ELBO): def differentiable_loss(self, model, guide, *args, **kwargs): - with enum(), plate( - size=self.num_particles + with enum( + first_available_dim=(-self.max_plate_nesting - 1) + if self.max_plate_nesting is not None + and self.max_plate_nesting != float("inf") + else None + ), provenance(), plate( + name="num_particles_vectorized", + size=self.num_particles, + dim=-self.max_plate_nesting, ) if self.num_particles > 1 else contextlib.ExitStack(): - guide_tr = trace(config_enumerate(default="flat")(guide)).get_trace( - *args, **kwargs - ) + guide_tr = trace(guide).get_trace(*args, **kwargs) model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) model_terms = terms_from_trace(model_tr) guide_terms = terms_from_trace(guide_tr) - log_measures = guide_terms["log_measures"] + model_terms["log_measures"] - log_factors = model_terms["log_factors"] + [ - -f for f in guide_terms["log_factors"] - ] - plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"] - measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"] - - elbo = funsor.Integrate( - sum(log_measures, to_funsor(0.0)), - sum(log_factors, to_funsor(0.0)), - measure_vars, + particle_var = ( + frozenset({"num_particles_vectorized"}) + if self.num_particles > 1 + else frozenset() ) - elbo = elbo.reduce(funsor.ops.add, plate_vars) + plate_vars = ( + guide_terms["plate_vars"] | model_terms["plate_vars"] + ) - particle_var - return -to_data(elbo) + model_measure_vars = model_terms["measure_vars"] - guide_terms["measure_vars"] + with funsor.terms.lazy: + # identify and contract out auxiliary variables in the model with partial_sum_product + contracted_factors, uncontracted_factors = [], [] + for f in model_terms["log_factors"]: + if model_measure_vars.intersection(f.inputs): + contracted_factors.append(f) + else: + uncontracted_factors.append(f) + contracted_costs = [] + # incorporate the effects of subsampling and handlers.scale through a common scale factor + for group_factors, group_vars in _partition( + list(model_terms["log_measures"].values()) + contracted_factors, + model_terms["measure_vars"], + ): + group_factor_vars = frozenset().union( + *[f.inputs for f in group_factors] + ) + group_plates = model_terms["plate_vars"] & group_factor_vars + outermost_plates = frozenset.intersection( + *(frozenset(f.inputs) & group_plates for f in group_factors) + ) + elim_plates = group_plates - outermost_plates + for f in funsor.sum_product.partial_sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + group_factors, + plates=group_plates, + eliminate=group_vars | elim_plates, + ): + contracted_costs.append(model_terms["scale"] * f) + + # accumulate costs from model (logp) and guide (-logq) + costs = contracted_costs + uncontracted_factors # model costs: logp + costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq + + # compute log_measures corresponding to each cost term + # the goal is to achieve fine-grained Rao-Blackwellization + log_measures = dict() + for cost in costs: + if cost.input_vars not in log_measures: + log_probs = [ + f + for name, f in guide_terms["log_measures"].items() + if name in cost.inputs + ] + log_prob = funsor.sum_product.sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + log_probs, + plates=plate_vars, + eliminate=(plate_vars | guide_terms["measure_vars"]) + - frozenset(cost.inputs), + ) + log_measures[cost.input_vars] = funsor.optimizer.apply_optimizer( + log_prob + ) + + with funsor.terms.lazy: + # finally, integrate out guide variables in the elbo and all plates + elbo = to_funsor(0, output=funsor.Real) + for cost in costs: + log_measure = log_measures[cost.input_vars] + measure_vars = (frozenset(cost.inputs) - plate_vars) - particle_var + elbo_term = funsor.Integrate( + log_measure, + cost, + measure_vars, + ) + elbo += elbo_term.reduce( + funsor.ops.add, plate_vars & frozenset(cost.inputs) + ) + # average over Monte-Carlo particles + elbo = elbo.reduce(funsor.ops.mean, particle_var) + + return -to_data(funsor.optimizer.apply_optimizer(elbo)) class JitTrace_ELBO(Jit_ELBO, Trace_ELBO): diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index 4655bc09ac..691c03c9eb 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -31,7 +31,7 @@ def terms_from_trace(tr): # of free variables as either product (plate) variables or sum (measure) variables terms = { "log_factors": [], - "log_measures": [], + "log_measures": {}, "scale": to_funsor(1.0), "plate_vars": frozenset(), "measure_vars": frozenset(), @@ -62,7 +62,7 @@ def terms_from_trace(tr): ) # grab the log-measure, found only at sites that are not replayed or observed if node["funsor"].get("log_measure", None) is not None: - terms["log_measures"].append(node["funsor"]["log_measure"]) + terms["log_measures"][name] = node["funsor"]["log_measure"] # sum (measure) variables: the fresh non-plate variables at a site terms["measure_vars"] |= ( frozenset(node["funsor"]["value"].inputs) | {name} @@ -132,7 +132,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): for f in funsor.sum_product.dynamic_partial_sum_product( funsor.ops.logaddexp, funsor.ops.add, - model_terms["log_measures"] + contracted_factors, + list(model_terms["log_measures"].values()) + contracted_factors, plate_to_step=model_terms["plate_to_step"], eliminate=model_terms["measure_vars"] | markov_dims, ) @@ -149,7 +149,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): log_prob = funsor.sum_product.sum_product( funsor.ops.logaddexp, funsor.ops.add, - guide_terms["log_measures"], + list(guide_terms["log_measures"].values()), plates=plate_vars, eliminate=(plate_vars | guide_terms["measure_vars"]) - frozenset(cost.inputs), @@ -198,7 +198,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): contracted_costs = [] # incorporate the effects of subsampling and handlers.scale through a common scale factor for group_factors, group_vars in _partition( - model_terms["log_measures"] + contracted_factors, + list(model_terms["log_measures"].values()) + contracted_factors, model_terms["measure_vars"], ): group_factor_vars = frozenset().union( @@ -244,7 +244,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): logzq = funsor.sum_product.sum_product( funsor.ops.logaddexp, funsor.ops.add, - guide_terms["log_measures"] + list(targets.values()), + list(guide_terms["log_measures"].values()) + list(targets.values()), plates=plate_vars, eliminate=(plate_vars | guide_terms["measure_vars"]), ) diff --git a/pyro/contrib/funsor/infer/tracetmc_elbo.py b/pyro/contrib/funsor/infer/tracetmc_elbo.py index 7cf4ba805e..e57b4db76a 100644 --- a/pyro/contrib/funsor/infer/tracetmc_elbo.py +++ b/pyro/contrib/funsor/infer/tracetmc_elbo.py @@ -29,7 +29,9 @@ def differentiable_loss(self, model, guide, *args, **kwargs): model_terms = terms_from_trace(model_tr) guide_terms = terms_from_trace(guide_tr) - log_measures = guide_terms["log_measures"] + model_terms["log_measures"] + log_measures = list(guide_terms["log_measures"].values()) + list( + model_terms["log_measures"].values() + ) log_factors = model_terms["log_factors"] + [ -f for f in guide_terms["log_factors"] ] diff --git a/setup.py b/setup.py index 7b44fa7522..c99e4d038b 100644 --- a/setup.py +++ b/setup.py @@ -139,8 +139,8 @@ "horovod": ["horovod[pytorch]>=0.19"], "funsor": [ # This must be a released version when Pyro is released. - # "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461", - "funsor[torch]==0.4.3", + "funsor[torch] @ git+https://github.com/pyro-ppl/funsor.git@sampled-funsor", + # "funsor[torch]==0.4.3", ], }, python_requires=">=3.7", diff --git a/tests/contrib/funsor/test_gradient.py b/tests/contrib/funsor/test_gradient.py new file mode 100644 index 0000000000..ac441f1469 --- /dev/null +++ b/tests/contrib/funsor/test_gradient.py @@ -0,0 +1,727 @@ +# Copyright (c) 2017-2019 Uber Technologies, Inc. +# SPDX-License-Identifier: Apache-2.0 + +import logging + +import pyroapi +import pytest +import torch + +from pyro.ops.indexing import Vindex +from tests.common import assert_equal + +# put all funsor-related imports here, so test collection works without funsor +try: + import funsor + + import pyro.contrib.funsor + + funsor.set_backend("torch") + from pyroapi import distributions as dist + from pyroapi import handlers, infer, pyro +except ImportError: + pytestmark = pytest.mark.skip(reason="funsor is not installed") + +logger = logging.getLogger(__name__) + + +def model_0(data): + with pyro.plate("data", len(data)): + z = pyro.sample("z", dist.Categorical(torch.tensor([0.3, 0.7]))) + pyro.sample("x", dist.Normal(z.to(data.dtype), 1), obs=data) + + +def guide_0(data): + with pyro.plate("data", len(data)): + probs = pyro.param("probs", lambda: torch.tensor([[0.4, 0.6], [0.5, 0.5]])) + pyro.sample("z", dist.Categorical(probs)) + + +def model_1(data): + a = pyro.sample("a", dist.Categorical(torch.tensor([0.3, 0.7]))) + with pyro.plate("data", len(data)): + probs_b = torch.tensor([[0.1, 0.9], [0.2, 0.8]]) + b = pyro.sample("b", dist.Categorical(probs_b[a.long()])) + pyro.sample("c", dist.Normal(b.to(data.dtype), 1), obs=data) + + +def guide_1(data): + probs_a = pyro.param( + "probs_a", + lambda: torch.tensor([0.5, 0.5]), + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("data", len(data)) as idx: + probs_b = pyro.param( + "probs_b", + lambda: torch.tensor( + [[[0.5, 0.5], [0.6, 0.4]], [[0.4, 0.6], [0.35, 0.65]]] + ), + ) + pyro.sample("b", dist.Categorical(Vindex(probs_b)[a.long(), idx])) + + +def model_2(data): + prob_b = torch.tensor([[0.3, 0.7], [0.4, 0.6]]) + prob_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]]) + prob_d = torch.tensor([[0.2, 0.8], [0.3, 0.7]]) + prob_e = torch.tensor([[0.5, 0.5], [0.1, 0.9]]) + a = pyro.sample("a", dist.Categorical(torch.tensor([0.3, 0.7]))) + with pyro.plate("data", len(data)): + b = pyro.sample("b", dist.Categorical(prob_b[a.long()])) + c = pyro.sample("c", dist.Categorical(prob_c[b.long()])) + pyro.sample("d", dist.Categorical(prob_d[b.long()])) + pyro.sample("e", dist.Categorical(prob_e[c.long()]), obs=data) + + +def guide_2(data): + prob_a = pyro.param("prob_a", lambda: torch.tensor([0.5, 0.5])) + prob_b = pyro.param("prob_b", lambda: torch.tensor([[0.4, 0.6], [0.3, 0.7]])) + prob_c = pyro.param( + "prob_c", + lambda: torch.tensor([[[0.3, 0.7], [0.8, 0.2]], [[0.2, 0.8], [0.5, 0.5]]]), + ) + prob_d = pyro.param( + "prob_d", + lambda: torch.tensor([[[0.2, 0.8], [0.9, 0.1]], [[0.1, 0.9], [0.4, 0.6]]]), + ) + a = pyro.sample("a", dist.Categorical(prob_a)) + with pyro.plate("data", len(data)) as idx: + b = pyro.sample("b", dist.Categorical(prob_b[a.long()])) + pyro.sample("c", dist.Categorical(Vindex(prob_c)[b.long(), idx])) + pyro.sample("d", dist.Categorical(Vindex(prob_d)[b.long(), idx])) + + +@pytest.mark.parametrize( + "model,guide,data", + [ + (model_0, guide_0, torch.tensor([-0.5, 2.0])), + (model_1, guide_1, torch.tensor([-0.5, 2.0])), + (model_2, guide_2, torch.tensor([0.0, 1.0])), + ], +) +def test_gradient(model, guide, data): + + # Expected grads based on exact integration + with pyroapi.pyro_backend("pyro"): + pyro.clear_param_store() + elbo = infer.TraceEnum_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + strict_enumeration_warning=False, + ) + elbo.loss_and_grads(model, infer.config_enumerate(guide), data) + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = { + name: param.grad.detach().cpu() for name, param in params.items() + } + + # Actual grads averaged over num_particles + with pyroapi.pyro_backend("contrib.funsor"): + pyro.clear_param_store() + elbo = infer.Trace_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=50000, + vectorize_particles=True, + strict_enumeration_warning=False, + ) + elbo.loss_and_grads(model, guide, data) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = { + name: param.grad.detach().cpu() for name, param in params.items() + } + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + assert_equal(actual_grads, expected_grads, prec=0.02) + + +@pyroapi.pyro_backend("contrib.funsor") +def test_particle_gradient_0(): + # model + # +---------+ + # | z --> x | + # +---------+ + # + # guide + # +---+ + # | z | + # +---+ + data = torch.tensor([-0.5, 2.0]) + + def model(): + with pyro.plate("data", len(data)): + z = pyro.sample("z", dist.Poisson(3)) + pyro.sample("x", dist.Normal(z, 1), obs=data) + + def guide(): + # set this to ensure rng agrees across runs + # this should be ok since we are comparing a single particle gradients + pyro.set_rng_seed(0) + with pyro.plate("data", len(data)): + rate = pyro.param("rate", lambda: torch.tensor([3.5, 1.5])) + pyro.sample("z", dist.Poisson(rate)) + + elbo = infer.Trace_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) + + # Trace_ELBO gradients + pyro.clear_param_store() + elbo.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + # Hand derived gradients + # elbo = MonteCarlo( + # [q(z_i) * log_pz_i].sum(i) + # + [q(z_i) * log_px_i].sum(i) + # - [q(z_i) * log_qx_i].sum(i) + # ) + pyro.clear_param_store() + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + # log factors + logpx = model_tr.nodes["x"]["log_prob"] + logpz = model_tr.nodes["z"]["log_prob"] + logqz = guide_tr.nodes["z"]["log_prob"] + # dice factor + df_z = (logqz - logqz.detach()).exp() + # dice elbo + dice_elbo = (df_z * (logpz + logpx - logqz)).sum() + # backward run + loss = -dice_elbo + loss.backward() + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + assert_equal(actual_grads, expected_grads, prec=1e-4) + + +@pyroapi.pyro_backend("contrib.funsor") +def test_particle_gradient_1(): + # model + # +-----------+ + # a -|-> b --> c | + # +-----------+ + # + # guide + # +-----+ + # a -|-> b | + # +-----+ + data = torch.tensor([-0.5, 2.0]) + + def model(): + a = pyro.sample("a", dist.Bernoulli(0.3)) + with pyro.plate("data", len(data)): + rate = torch.tensor([2.0, 3.0]) + b = pyro.sample("b", dist.Poisson(rate[a.long()])) + pyro.sample("c", dist.Normal(b, 1), obs=data) + + def guide(): + # set this to ensure rng agrees across runs + # this should be ok since we are comparing a single particle gradients + pyro.set_rng_seed(0) + prob = pyro.param( + "prob", + lambda: torch.tensor(0.5), + ) + a = pyro.sample("a", dist.Bernoulli(prob)) + with pyro.plate("data", len(data)): + rate = pyro.param("rate", lambda: torch.tensor([[3.5, 1.5], [0.5, 2.5]])) + pyro.sample("b", dist.Poisson(rate[a.long()])) + + elbo = infer.Trace_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) + + # Trace_ELBO gradients + pyro.clear_param_store() + elbo.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + # Hand derived gradients + # elbo = MonteCarlo( + # q(a) * log_pa + # + q(a) * [q(b_i|a) * log_pb_i].sum(i) + # + q(a) * [q(b_i|a) * log_pc_i].sum(i) + # - q(a) * log_qa + # - q(a) * [q(b_i|a) * log_qb_i].sum(i) + # ) + pyro.clear_param_store() + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + # log factors + logpa = model_tr.nodes["a"]["log_prob"] + logpb = model_tr.nodes["b"]["log_prob"] + logpc = model_tr.nodes["c"]["log_prob"] + logqa = guide_tr.nodes["a"]["log_prob"] + logqb = guide_tr.nodes["b"]["log_prob"] + # dice factors + df_a = (logqa - logqa.detach()).exp() + df_b = (logqb - logqb.detach()).exp() + # dice elbo + dice_elbo = ( + df_a * logpa + + df_a * (df_b * logpb).sum() + + df_a * (df_b * logpc).sum() + - df_a * logqa + - df_a * (df_b * logqb).sum() + ) + # backward run + loss = -dice_elbo + loss.backward() + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + assert_equal(actual_grads, expected_grads, prec=1e-4) + + +@pyroapi.pyro_backend("contrib.funsor") +def test_particle_gradient_2(): + # model + # +-----------------+ + # a -|-> b --> c --> e | + # | \--> d | + # +-----------------+ + # + # guide + # +-----------+ + # a -|-> b --> c | + # | \--> d | + # +-----------+ + data = torch.tensor([0.0, 1.0]) + + def model(): + prob_b = torch.tensor([0.3, 0.4]) + prob_c = torch.tensor([0.5, 0.6]) + prob_d = torch.tensor([0.2, 0.3]) + prob_e = torch.tensor([0.5, 0.1]) + a = pyro.sample("a", dist.Bernoulli(0.3)) + with pyro.plate("data", len(data)): + b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) + c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()])) + pyro.sample("d", dist.Bernoulli(prob_d[b.long()])) + pyro.sample("e", dist.Bernoulli(prob_e[c.long()]), obs=data) + + def guide(): + # set this to ensure rng agrees across runs + # this should be ok since we are comparing a single particle gradients + pyro.set_rng_seed(0) + prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5)) + prob_b = pyro.param("prob_b", lambda: torch.tensor([0.4, 0.3])) + prob_c = pyro.param("prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]])) + prob_d = pyro.param("prob_d", lambda: torch.tensor([[0.2, 0.9], [0.1, 0.4]])) + a = pyro.sample("a", dist.Bernoulli(prob_a)) + with pyro.plate("data", len(data)) as idx: + b = pyro.sample("b", dist.Bernoulli(prob_b[a.long()])) + pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx])) + pyro.sample("d", dist.Bernoulli(Vindex(prob_d)[b.long(), idx])) + + elbo = infer.Trace_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) + + # Trace_ELBO gradients + pyro.clear_param_store() + elbo.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + # Hand derived gradients + # elbo = MonteCarlo( + # q(a) * log_pa + # + q(a) * [q(b_i|a) * log_pb_i].sum(i) + # + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pc_i].sum(i) + # + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pe_i].sum(i) + # + q(a) * [q(b_i|a) * q(d_i|b_i) * log_pd_i].sum(i) + # - q(a) * log_qa + # - q(a) * [q(b_i|a) * log_qb_i].sum(i) + # - q(a) * [q(b_i|a) * q(c_i|b_i) * log_qc_i].sum(i) + # - q(a) * [q(b_i|a) * q(d_i|b_i) * log_qd_i].sum(i) + # ) + pyro.clear_param_store() + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + # log factors + logpa = model_tr.nodes["a"]["log_prob"] + logpb = model_tr.nodes["b"]["log_prob"] + logpc = model_tr.nodes["c"]["log_prob"] + logpd = model_tr.nodes["d"]["log_prob"] + logpe = model_tr.nodes["e"]["log_prob"] + + logqa = guide_tr.nodes["a"]["log_prob"] + logqb = guide_tr.nodes["b"]["log_prob"] + logqc = guide_tr.nodes["c"]["log_prob"] + logqd = guide_tr.nodes["d"]["log_prob"] + # dice factors + df_a = (logqa - logqa.detach()).exp() + df_b = (logqb - logqb.detach()).exp() + df_c = (logqc - logqc.detach()).exp() + df_d = (logqd - logqd.detach()).exp() + # dice elbo + dice_elbo = ( + df_a * logpa + + df_a * (df_b * logpb).sum() + + df_a * (df_b * df_c * logpc).sum() + + df_a * (df_b * df_c * logpe).sum() + + df_a * (df_b * df_d * logpd).sum() + - df_a * logqa + - df_a * (df_b * logqb).sum() + - df_a * (df_b * df_c * logqc).sum() + - df_a * (df_b * df_d * logqd).sum() + ) + # backward run + loss = -dice_elbo + loss.backward() + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + assert_equal(actual_grads, expected_grads, prec=1e-4) + + +@pyroapi.pyro_backend("contrib.funsor") +def test_particle_gradient_3(): + # model + # +-----------------+ + # a -|-> b --> c --> d | + # +-----------------+ + # + # guide (b is enumerated) + # +-----------+ + # a -|-> b --> c | + # +-----------+ + data = torch.tensor([0.0, 1.0]) + + def model(): + prob_b = torch.tensor([[0.3, 0.7], [0.4, 0.6]]) + prob_c = torch.tensor([0.5, 0.6]) + prob_d = torch.tensor([0.5, 0.1]) + a = pyro.sample("a", dist.Bernoulli(0.3)) + with pyro.plate("data", len(data)): + b = pyro.sample("b", dist.Categorical(prob_b[a.long()])) + c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()])) + pyro.sample("d", dist.Bernoulli(prob_d[c.long()]), obs=data) + + def guide(): + # set this to ensure rng agrees across runs + # this should be ok since we are comparing a single particle gradients + pyro.set_rng_seed(0) + prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5)) + prob_b = pyro.param("prob_b", lambda: torch.tensor([[0.4, 0.6], [0.3, 0.7]])) + prob_c = pyro.param("prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]])) + a = pyro.sample("a", dist.Bernoulli(prob_a)) + with pyro.plate("data", len(data)) as idx: + b = pyro.sample( + "b", dist.Categorical(prob_b[a.long()]), infer={"enumerate": "parallel"} + ) + pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx])) + + elbo = infer.Trace_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) + + # Trace_ELBO gradients + pyro.clear_param_store() + elbo.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + # Hand derived gradients (b is exactly integrated) + # elbo = MonteCarlo( + # q(a) * log_pa + # + q(a) * [q(b_i|a) * log_pb_i].sum(i, b) + # + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pc_i].sum(i, b) + # + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pd_i].sum(i, b) + # - q(a) * log_qa + # - q(a) * [q(b_i|a) * log_qb_i].sum(i) + # - q(a) * [q(b_i|a) * q(c_i|b_i) * log_qc_i].sum(i, b) + # ) + pyro.clear_param_store() + with handlers.enum(first_available_dim=(-2)), handlers.provenance(): + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + # log factors + logpa = model_tr.nodes["a"]["log_prob"] + logpb = model_tr.nodes["b"]["log_prob"] + logpc = model_tr.nodes["c"]["log_prob"] + logpd = model_tr.nodes["d"]["log_prob"] + + logqa = guide_tr.nodes["a"]["log_prob"] + logqb = guide_tr.nodes["b"]["log_prob"] + logqc = guide_tr.nodes["c"]["log_prob"] + # dice factors + df_a = (logqa - logqa.detach()).exp() + qb = logqb.exp() + df_c = (logqc - logqc.detach()).exp() + # dice elbo + dice_elbo = ( + df_a * logpa + + df_a * (qb * logpb).sum() + + df_a * (qb * df_c * logpc).sum() + + df_a * (qb * df_c * logpd).sum() + - df_a * logqa + - df_a * (qb * logqb).sum() + - df_a * (qb * df_c * logqc).sum() + ) + # backward run + loss = -dice_elbo + loss.backward() + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + assert_equal(actual_grads, expected_grads, prec=1e-4) + + +@pyroapi.pyro_backend("contrib.funsor") +def test_particle_gradient_4(): + # model + # +-----------------+ + # a -|-> b --> c --> d | + # +-----------------+ + # + # guide (c is enumerated) + # +-----------+ + # a -|-> b --> c | + # +-----------+ + data = torch.tensor([0.0, 1.0]) + + def model(): + prob_b = torch.tensor([[0.3, 0.7], [0.4, 0.6]]) + prob_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]]) + prob_d = torch.tensor([0.5, 0.1]) + a = pyro.sample("a", dist.Bernoulli(0.3)) + with pyro.plate("data", len(data)): + b = pyro.sample("b", dist.Categorical(prob_b[a.long()])) + c = pyro.sample("c", dist.Categorical(prob_c[b.long()])) + pyro.sample("d", dist.Bernoulli(prob_d[c.long()]), obs=data) + + def guide(): + # set this to ensure rng agrees across runs + # this should be ok since we are comparing a single particle gradients + pyro.set_rng_seed(0) + prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5)) + prob_b = pyro.param("prob_b", lambda: torch.tensor([[0.4, 0.6], [0.3, 0.7]])) + prob_c = pyro.param( + "prob_c", + lambda: torch.tensor([[[0.3, 0.7], [0.8, 0.2]], [[0.2, 0.8], [0.5, 0.5]]]), + ) + a = pyro.sample("a", dist.Bernoulli(prob_a)) + with pyro.plate("data", len(data)) as idx: + b = pyro.sample("b", dist.Categorical(prob_b[a.long()])) + pyro.sample( + "c", + dist.Categorical(Vindex(prob_c)[b.long(), idx]), + infer={"enumerate": "parallel"}, + ) + + elbo = infer.Trace_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) + + # Trace_ELBO gradients + pyro.clear_param_store() + elbo.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + # Hand derived gradients (c is exactly integrated) + # elbo = MonteCarlo( + # q(a) * log_pa + # + q(a) * [q(b_i|a) * log_pb_i].sum(i) + # + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pc_i].sum(i, c) + # + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pd_i].sum(i, c) + # - q(a) * log_qa + # - q(a) * [q(b_i|a) * log_qb_i].sum(i) + # - q(a) * [q(b_i|a) * q(c_i|b_i) * log_qc_i].sum(i, c) + # ) + pyro.clear_param_store() + with handlers.enum(first_available_dim=(-2)), handlers.provenance(): + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + # log factors + logpa = model_tr.nodes["a"]["log_prob"] + logpb = model_tr.nodes["b"]["log_prob"] + logpc = model_tr.nodes["c"]["log_prob"] + logpd = model_tr.nodes["d"]["log_prob"] + + logqa = guide_tr.nodes["a"]["log_prob"] + logqb = guide_tr.nodes["b"]["log_prob"] + logqc = guide_tr.nodes["c"]["log_prob"] + # dice factors + df_a = (logqa - logqa.detach()).exp() + df_b = (logqb - logqb.detach()).exp() + qc = logqc.exp() + # dice elbo + dice_elbo = ( + df_a * logpa + + df_a * (df_b * logpb).sum() + + df_a * (df_b * qc * logpc).sum() + + df_a * (df_b * qc * logpd).sum() + - df_a * logqa + - df_a * (df_b * logqb).sum() + - df_a * (df_b * qc * logqc).sum() + ) + # backward run + loss = -dice_elbo + loss.backward() + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + assert_equal(actual_grads, expected_grads, prec=1e-4) + + +@pyroapi.pyro_backend("contrib.funsor") +def test_particle_gradient_5(): + # model + # +-----------------+ + # a -|-> b --> c --> e | + # | \--> d | + # +-----------------+ + # + # guide (b is enumerated) + # +-----------+ + # a -|-> b --> c | + # | \--> d | + # +-----------+ + data = torch.tensor([0.0, 1.0]) + + def model(): + prob_b = torch.tensor([[0.3, 0.7], [0.4, 0.6]]) + prob_c = torch.tensor([0.5, 0.6]) + prob_d = torch.tensor([0.2, 0.3]) + prob_e = torch.tensor([0.5, 0.1]) + a = pyro.sample("a", dist.Bernoulli(0.3)) + with pyro.plate("data", len(data)): + b = pyro.sample("b", dist.Categorical(prob_b[a.long()])) + c = pyro.sample("c", dist.Bernoulli(prob_c[b.long()])) + pyro.sample("d", dist.Bernoulli(prob_d[b.long()])) + pyro.sample("e", dist.Bernoulli(prob_e[c.long()]), obs=data) + + def guide(): + # set this to ensure rng agrees across runs + # this should be ok since we are comparing a single particle gradients + pyro.set_rng_seed(0) + prob_a = pyro.param("prob_a", lambda: torch.tensor(0.5)) + prob_b = pyro.param("prob_b", lambda: torch.tensor([[0.4, 0.6], [0.3, 0.7]])) + prob_c = pyro.param("prob_c", lambda: torch.tensor([[0.3, 0.8], [0.2, 0.5]])) + prob_d = pyro.param("prob_d", lambda: torch.tensor([[0.2, 0.9], [0.1, 0.4]])) + a = pyro.sample("a", dist.Bernoulli(prob_a)) + with pyro.plate("data", len(data)) as idx: + b = pyro.sample( + "b", dist.Categorical(prob_b[a.long()]), infer={"enumerate": "parallel"} + ) + pyro.sample("c", dist.Bernoulli(Vindex(prob_c)[b.long(), idx])) + pyro.sample("d", dist.Bernoulli(Vindex(prob_d)[b.long(), idx])) + + elbo = infer.Trace_ELBO( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=1, + strict_enumeration_warning=False, + ) + + # Trace_ELBO gradients + pyro.clear_param_store() + elbo.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + # Hand derived gradients (b exactly integrated out) + # elbo = MonteCarlo( + # q(a) * log_pa + # + q(a) * [q(b_i|a) * log_pb_i].sum(i, b) + # + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pc_i].sum(i, b) + # + q(a) * [q(b_i|a) * q(c_i|b_i) * log_pe_i].sum(i, b) + # + q(a) * [q(b_i|a) * q(d_i|b_i) * log_pd_i].sum(i, b) + # - q(a) * log_qa + # - q(a) * [q(b_i|a) * log_qb_i].sum(i, b) + # - q(a) * [q(b_i|a) * q(c_i|b_i) * log_qc_i].sum(i, b) + # - q(a) * [q(b_i|a) * q(d_i|b_i) * log_qd_i].sum(i, b) + # ) + pyro.clear_param_store() + with handlers.enum(first_available_dim=(-2)), handlers.provenance(): + guide_tr = handlers.trace(guide).get_trace() + model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace() + guide_tr.compute_log_prob() + model_tr.compute_log_prob() + # log factors + logpa = model_tr.nodes["a"]["log_prob"] + logpb = model_tr.nodes["b"]["log_prob"] + logpc = model_tr.nodes["c"]["log_prob"] + logpd = model_tr.nodes["d"]["log_prob"] + logpe = model_tr.nodes["e"]["log_prob"] + + logqa = guide_tr.nodes["a"]["log_prob"] + logqb = guide_tr.nodes["b"]["log_prob"] + logqc = guide_tr.nodes["c"]["log_prob"] + logqd = guide_tr.nodes["d"]["log_prob"] + # dice factors + df_a = (logqa - logqa.detach()).exp() + qb = logqb.exp() + df_c = (logqc - logqc.detach()).exp() + df_d = (logqd - logqd.detach()).exp() + # dice elbo + dice_elbo = ( + df_a * logpa + + df_a * (qb * logpb).sum() + + df_a * (qb * df_c * logpc).sum() + + df_a * (qb * df_c * logpe).sum() + + df_a * (qb * df_d * logpd).sum() + - df_a * logqa + - df_a * (qb * logqb).sum() + - df_a * (qb * df_c * logqc).sum() + - df_a * (qb * df_d * logqd).sum() + ) + # backward run + loss = -dice_elbo + loss.backward() + params = dict(pyro.get_param_store().named_parameters()) + expected_grads = {name: param.grad.detach().cpu() for name, param in params.items()} + + for name in sorted(params): + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) + + assert_equal(actual_grads, expected_grads, prec=1e-4)