From 6d0c8d679db3b4f6403463442db1cc8f487cdca6 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 14 Oct 2021 23:06:08 -0400 Subject: [PATCH] eager categorical constant pattern --- funsor/torch/distributions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py index 0b9e202a2..208a6ed38 100644 --- a/funsor/torch/distributions.py +++ b/funsor/torch/distributions.py @@ -15,6 +15,7 @@ import funsor.ops as ops from funsor.cnf import Contraction +from funsor.constant import Constant from funsor.distribution import ( # noqa: F401 FUNSOR_DIST_NAMES, Bernoulli, @@ -362,6 +363,9 @@ def deltadist_to_funsor(pyro_dist, output=None, dim_to_name=None): eager.register(Multinomial, Tensor, Tensor, Tensor)(eager_multinomial) # noqa: F821) eager.register(Categorical, Funsor, Tensor)(eager_categorical_funsor) # noqa: F821) eager.register(Categorical, Tensor, Variable)(eager_categorical_tensor) # noqa: F821) +eager.register(Categorical, Constant[Tuple, Tensor], Variable)( + eager_categorical_tensor +) # noqa: F821) eager.register(Delta, Tensor, Tensor, Tensor)(eager_delta_tensor) # noqa: F821 eager.register(Delta, Funsor, Funsor, Variable)( eager_delta_funsor_variable