diff --git a/Makefile b/Makefile index f98e2cae72..713070cb29 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ tutorial: FORCE lint: FORCE flake8 - black --check *.py pyro examples tests scripts profiler + black --check . isort --check . python scripts/update_headers.py --check mypy --install-types --non-interactive pyro scripts @@ -28,7 +28,7 @@ license: FORCE python scripts/update_headers.py format: license FORCE - black *.py pyro examples tests scripts profiler + black . isort . version: FORCE diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..60398beec7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,10 @@ +[tool.black] +include = ''' +( + pyro/.*\.py + | examples/.*\.py + | tests/.*\.py + | scripts/.*\.py + | profiler/.*\.py +) +''' diff --git a/pyro/infer/renyi_elbo.py b/pyro/infer/renyi_elbo.py index 349f7c43d4..0c44bae929 100644 --- a/pyro/infer/renyi_elbo.py +++ b/pyro/infer/renyi_elbo.py @@ -9,7 +9,7 @@ from pyro.distributions.util import is_identically_zero from pyro.infer.elbo import ELBO from pyro.infer.enum import get_importance_trace -from pyro.infer.util import get_dependent_plate_dims, is_validation_enabled, torch_sum +from pyro.infer.util import get_nonparticle_plate_dims, is_validation_enabled, torch_sum from pyro.util import check_if_enumerated, warn_if_nan @@ -104,7 +104,7 @@ def loss(self, model, guide, *args, **kwargs): # grab a vectorized trace from the generator for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): elbo_particle = 0.0 - sum_dims = get_dependent_plate_dims(model_trace.nodes.values()) + sum_dims = get_nonparticle_plate_dims(model_trace.nodes.values()) # compute elbo for name, site in model_trace.nodes.items(): @@ -152,7 +152,7 @@ def loss_and_grads(self, model, guide, *args, **kwargs): for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): elbo_particle = 0 surrogate_elbo_particle = 0 - sum_dims = get_dependent_plate_dims(model_trace.nodes.values()) + sum_dims = get_nonparticle_plate_dims(model_trace.nodes.values()) # compute elbo and surrogate elbo for name, site in model_trace.nodes.items(): diff --git a/pyro/infer/util.py b/pyro/infer/util.py index 3ee94b884d..670d31495b 100644 --- a/pyro/infer/util.py +++ b/pyro/infer/util.py @@ -99,17 +99,21 @@ def get_plate_stacks(trace): } -def get_dependent_plate_dims(sites): +def get_nonparticle_plate_dims(sites): """ - Return a list of unique dims for plates that are not common to all sites. + Return a list of unique dims of all plates except vectorized particles """ plate_sets = [ site["cond_indep_stack"] for site in sites if site["type"] == "sample" ] all_plates = set().union(*plate_sets) - common_plates = all_plates.intersection(*plate_sets) - sum_plates = all_plates - common_plates - sum_dims = sorted({f.dim for f in sum_plates if f.dim is not None}) + sum_dims = sorted( + { + f.dim + for f in all_plates + if f.dim is not None and f.name != "num_particles_vectorized" + } + ) return sum_dims diff --git a/tests/infer/test_inference.py b/tests/infer/test_inference.py index a598ea9c54..7ce2cd72a9 100644 --- a/tests/infer/test_inference.py +++ b/tests/infer/test_inference.py @@ -1005,3 +1005,72 @@ def guide(data, weights): loss = svi.step(data, weights) if step % 20 == 0: logger.info("step {} loss = {:0.4g}".format(step, loss)) + + +@pytest.mark.stage("integration", "integration_batch_2") +class OneWayNormalRandomEffects(TestCase): + def setUp(self) -> None: + self.n_groups = 3 + self.n_experiments = 5 + self.data = torch.tensor( + [ + [4.1, 3.5, 0.2, -3.3, 3.3], + [2.4, -6.5, -0.7, 4.4, -4.8], + [1.1, -0.6, 1.3, -1.3, -1.1], + ] + ) + self.group_locs = torch.tensor([[3.0], [-2.0], [0.0]]) + self.group_prec = torch.tensor([[0.2], [0.1], [0.3]]) + self.obs_prec = torch.tensor(6.0) + obs_prec = self.obs_prec * self.n_experiments + self.post_group_locs = ( + self.data.mean(1, keepdim=True) * obs_prec + + self.group_locs * self.group_prec + ) / (obs_prec + self.group_prec) + + def test_renyi_reparameterized(self): + self.do_elbo_test(True, 10_000, RenyiELBO(num_particles=2)) + + def test_renyi_vectorized(self): + self.do_elbo_test( + True, + 15_000, + RenyiELBO(num_particles=2, vectorize_particles=True, max_plate_nesting=3), + ) + + def test_renyi_nonreparameterized(self): + self.do_elbo_test(False, 15000, RenyiELBO(alpha=0.2, num_particles=2)) + + def do_elbo_test(self, reparameterized, n_steps, loss, debug=False): + def model(): + with pyro.plate("groups", self.n_groups, dim=-2): + group_loc = pyro.sample( + "group_loc", + dist.Normal(self.group_locs, torch.pow(self.group_prec, -0.5)), + ) + with pyro.plate("data", self.n_experiments, dim=-1): + pyro.sample( + "y", + dist.Normal(group_loc, torch.pow(self.obs_prec, -0.5)), + obs=self.data, + ) + + def guide(): + gloc = pyro.param( + "group_loc_param", + self.post_group_locs + torch.tensor([[0.05], [-0.08], [0.14]]), + ) + with pyro.plate("groups", self.n_groups, dim=-2): + Normal = ( + dist.Normal if reparameterized else fakes.NonreparameterizedNormal + ) + pyro.sample("group_loc", Normal(gloc, torch.pow(self.group_prec, -0.5))) + + adam = optim.Adam({"lr": 0.0005, "betas": (0.97, 0.999)}) + svi = SVI(model, guide, adam, loss=loss) + + for k in range(n_steps): + svi.step() + + group_loc_error = param_abs_error("group_loc_param", self.post_group_locs) + assert_equal(0.0, group_loc_error, prec=0.08)