diff --git a/docs/source/funsors.rst b/docs/source/funsors.rst index 2334ed50..cd876512 100644 --- a/docs/source/funsors.rst +++ b/docs/source/funsors.rst @@ -64,3 +64,11 @@ Constant :undoc-members: :show-inheritance: :member-order: bysource + +Provenance +---------- +.. automodule:: funsor.provenance + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/funsor/__init__.py b/funsor/__init__.py index adb0ae63..34a2e22d 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -7,6 +7,7 @@ from funsor.integrate import Integrate from funsor.interpreter import interpretation, reinterpret from funsor.op_factory import make_op +from funsor.provenance import Provenance from funsor.sum_product import MarkovProduct from funsor.tensor import Tensor, function from funsor.terms import ( @@ -47,6 +48,7 @@ montecarlo, ops, precondition, + provenance, recipes, sum_product, terms, @@ -71,6 +73,7 @@ "Number", "Real", "Reals", + "Provenance", "Slice", "Stack", "Tensor", @@ -105,6 +108,7 @@ "ops", "precondition", "pretty", + "provenance", "quote", "reals", "recipes", diff --git a/funsor/distribution.py b/funsor/distribution.py index 0e16cf85..cbb90804 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -19,6 +19,7 @@ from funsor.domains import Array, Real, Reals from funsor.gaussian import Gaussian from funsor.interpreter import gensym +from funsor.provenance import Provenance from funsor.tensor import ( Tensor, align_tensors, @@ -458,6 +459,14 @@ def backenddist_to_funsor( for param_name in funsor_dist_class._ast_fields if param_name != "value" ] + provenance = frozenset().union( + *[param.provenance for param in params if isinstance(param, Provenance)] + ) + if provenance: + params = [ + param.term if isinstance(param, Provenance) else param for param in params + ] + return Provenance(funsor_dist_class(*params), provenance) return funsor_dist_class(*params) diff --git a/funsor/integrate.py b/funsor/integrate.py index ab6f98ad..285592a9 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -11,6 +11,7 @@ from funsor.delta import Delta from funsor.gaussian import Gaussian, _norm2, _vm, align_gaussian from funsor.interpretations import eager, normalize +from funsor.provenance import Provenance from funsor.tensor import Tensor from funsor.terms import ( Funsor, @@ -139,6 +140,7 @@ def normalize_integrate_contraction(log_measure, integrand, reduced_vars): Tensor, GaussianMixture, EagerConstant, + Provenance, ), ) def eager_contraction_binary_to_integrate(red_op, bin_op, reduced_vars, lhs, rhs): diff --git a/funsor/provenance.py b/funsor/provenance.py new file mode 100644 index 00000000..3cbd7847 --- /dev/null +++ b/funsor/provenance.py @@ -0,0 +1,105 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict + +import funsor.ops as ops +from funsor.tensor import Tensor +from funsor.terms import Binary, Funsor, FunsorMeta, Number, Unary, Variable, eager + + +class ProvenanceMeta(FunsorMeta): + """ + Wrapper to combine provenance information from the term. + """ + + def __call__(cls, term, provenance): + while isinstance(term, Provenance): + provenance |= term.provenance + term = term.term + + return super(ProvenanceMeta, cls).__call__(term, provenance) + + +class Provenance(Funsor, metaclass=ProvenanceMeta): + """ + Provenance funsor for tracking the dependence of terms on ``(name, point)`` + of sampled random variables. + + **References** + + [1] David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind (2011) + Nonstandard Interpretations of Probabilistic Programs for Efficient Inference + http://papers.neurips.cc/paper/4309-nonstandard-interpretations-of-probabilistic-programs-for-efficient-inference.pdf + + :param funsor term: A term that depends on tracked variables. + :param frozenset provenance: A set of tuples of the form ``(name, point)``. + """ + + def __init__(self, term, provenance): + assert isinstance(term, Funsor) + assert isinstance(provenance, frozenset) + + provenance_names = frozenset([name for name, point in provenance]) + assert provenance_names.isdisjoint(term.inputs) + inputs = OrderedDict() + for name, point in provenance: + assert isinstance(name, str) + assert isinstance(point, Funsor) + assert name not in point.inputs + inputs.update({name: point.output}) + inputs.update(point.inputs) + + inputs.update(term.inputs) + output = term.output + fresh = provenance_names + bound = {} + super(Provenance, self).__init__(inputs, output, fresh, bound) + self.term = term + self.provenance = provenance + + def eager_subs(self, subs): + assert isinstance(subs, tuple) + subs = OrderedDict(subs) + assert set(subs).issubset(self.fresh) + new_provenance = frozenset() + new_term = self.term + for name, point in self.provenance: + if name in subs: + value = subs[name] + if isinstance(value, Variable): + new_provenance |= frozenset([(value.name, point)]) + continue + + # leave out the substituted provenance variable + # make sure that the value matches the point + assert value is point + else: + new_provenance |= frozenset([(name, point)]) + return Provenance(new_term, new_provenance) if new_provenance else new_term + + def _sample(self, sampled_vars, sample_inputs, rng_key): + result = self.term._sample(sampled_vars, sample_inputs, rng_key) + return Provenance(result, self.provenance) + + +@eager.register(Binary, ops.BinaryOp, Provenance, Provenance) +def eager_binary_provenance_provenance(op, lhs, rhs): + return Provenance(op(lhs.term, rhs.term), lhs.provenance | rhs.provenance) + + +@eager.register(Binary, ops.BinaryOp, Provenance, (Number, Tensor)) +def eager_binary_provenance_tensor(op, lhs, rhs): + assert lhs.fresh.isdisjoint(rhs.inputs) + return Provenance(op(lhs.term, rhs), lhs.provenance) + + +@eager.register(Binary, ops.BinaryOp, (Number, Tensor), Provenance) +def eager_binary_tensor_provenance(op, lhs, rhs): + assert rhs.fresh.isdisjoint(lhs.inputs) + return Provenance(op(lhs, rhs.term), rhs.provenance) + + +@eager.register(Unary, ops.UnaryOp, Provenance) +def eager_unary(op, arg): + return Provenance(op(arg.term), arg.provenance) diff --git a/funsor/torch/__init__.py b/funsor/torch/__init__.py index 71f2c698..b4c36e36 100644 --- a/funsor/torch/__init__.py +++ b/funsor/torch/__init__.py @@ -1,12 +1,10 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict - import torch from multipledispatch import dispatch -from funsor.constant import Constant +from funsor.provenance import Provenance from funsor.tensor import tensor_to_funsor from funsor.terms import to_data, to_funsor from funsor.torch.provenance import ProvenanceTensor @@ -31,15 +29,15 @@ def _quote(x, indent, out): @to_funsor.register(ProvenanceTensor) -def provenance_to_funsor(x, output=None, dim_to_name=None): - ret = to_funsor(x._t, output=output, dim_to_name=dim_to_name) - return Constant(OrderedDict(x._provenance), ret) +def provenancetensor_to_funsor(x, output=None, dim_to_name=None): + term = to_funsor(x._t, output=output, dim_to_name=dim_to_name) + return Provenance(term, x._provenance) -@to_data.register(Constant) -def constant_to_data(x, name_to_dim=None): - data = to_data(x.arg, name_to_dim=name_to_dim) - return ProvenanceTensor(data, provenance=frozenset(x.const_inputs.items())) +@to_data.register(Provenance) +def provenance_to_data(x, name_to_dim=None): + data = to_data(x.term, name_to_dim=name_to_dim) + return ProvenanceTensor(data, provenance=x.provenance) to_funsor.register(torch.Tensor)(tensor_to_funsor) diff --git a/test/test_constant.py b/test/test_constant.py index 06679c3d..7e287376 100644 --- a/test/test_constant.py +++ b/test/test_constant.py @@ -8,8 +8,8 @@ from funsor.delta import Delta from funsor.domains import Bint, Real from funsor.tensor import Tensor -from funsor.terms import Number, Variable, to_data, to_funsor -from funsor.testing import assert_close, randn, requires_backend +from funsor.terms import Number, Variable +from funsor.testing import assert_close, randn def test_eager_subs_variable(): @@ -81,29 +81,3 @@ def test_align(): for i in range(2): for j in range(3): assert x(a=0, b=b, i=i, j=j) == y(a=0, b=b, i=i, j=j) - - -@requires_backend("torch", reason="requires ProvenanceTensor") -def test_to_funsor(): - import torch - - from funsor.torch.provenance import ProvenanceTensor - - data = torch.zeros(3, 3) - pt = ProvenanceTensor(data, frozenset({("x", Real)})) - c = to_funsor(pt) - assert c is Constant(OrderedDict(x=Real), Tensor(data)) - - -@requires_backend("torch", reason="requires ProvenanceTensor") -def test_to_data(): - import torch - - from funsor.torch.provenance import ProvenanceTensor - - data = torch.zeros(3, 3) - c = Constant(OrderedDict(x=Real), Tensor(data)) - pt = to_data(c) - assert isinstance(pt, ProvenanceTensor) - assert pt._t is data - assert pt._provenance == frozenset({("x", Real)})