diff --git a/pyro/contrib/funsor/__init__.py b/pyro/contrib/funsor/__init__.py index d8a4d3eea9..ba9b330bbf 100644 --- a/pyro/contrib/funsor/__init__.py +++ b/pyro/contrib/funsor/__init__.py @@ -1,6 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import os + import pyroapi from pyro.contrib.funsor.handlers import condition, do, markov @@ -38,6 +40,8 @@ def plate(*args, **kwargs): }, ) +os.environ["PYRO_FUNSOR_ACTIVE"] = "1" # TODO better toggle + __all__ = [ "clear_param_store", "condition", diff --git a/pyro/contrib/funsor/handlers/runtime.py b/pyro/contrib/funsor/handlers/runtime.py index d1d3a8081a..979bdb4c13 100644 --- a/pyro/contrib/funsor/handlers/runtime.py +++ b/pyro/contrib/funsor/handlers/runtime.py @@ -99,14 +99,12 @@ def push_global(self, frame): self._global_stack.append(frame) def pop_global(self): - assert self._global_stack, "cannot pop the global frame" return self._global_stack.pop() def push_iter(self, frame): self._iter_stack.append(frame) def pop_iter(self): - assert self._iter_stack, "cannot pop the global frame" return self._iter_stack.pop() def push_local(self, frame): diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index 234c0764b8..4c7666b5d3 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -9,7 +9,7 @@ from pyro.util import ignore_jit_warnings from .messenger import Messenger -from .runtime import _ENUM_ALLOCATOR +from .runtime import _ENUM_ALLOCATOR, _is_funsor_active def _tmc_mixture_sample(msg): @@ -138,6 +138,10 @@ def __init__(self, first_available_dim=None): first_available_dim is None or first_available_dim < 0 ), first_available_dim self.first_available_dim = first_available_dim + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger + + self._funsor_named = NamedMessenger() super().__init__() def __enter__(self): @@ -146,7 +150,15 @@ def __enter__(self): self._markov_depths = {} # site name -> depth (nonnegative integer) self._param_dims = {} # site name -> (enum dim -> unique id) self._value_dims = {} # site name -> (enum dim -> unique id) - return super().__enter__() + result = super().__enter__() + if hasattr(self, "_funsor_named"): + self._funsor_named.__enter__() + return result + + def __exit__(self, *args): + if hasattr(self, "_funsor_named"): + self._funsor_named.__exit__(*args) + return super().__exit__(*args) @ignore_jit_warnings() def _pyro_sample(self, msg): diff --git a/pyro/poutine/markov_messenger.py b/pyro/poutine/markov_messenger.py index 1d68c9e06a..8283cdb9a6 100644 --- a/pyro/poutine/markov_messenger.py +++ b/pyro/poutine/markov_messenger.py @@ -5,6 +5,7 @@ from contextlib import ExitStack # python 3 from .reentrant_messenger import ReentrantMessenger +from .runtime import _is_funsor_active class MarkovMessenger(ReentrantMessenger): @@ -44,6 +45,14 @@ def __init__(self, history=1, keep=False, dim=None, name=None): self._iterable = None self._pos = -1 self._stack = [] + + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.named_messenger import ( + MarkovMessenger as FunsorMarkovMessenger, + ) + + self._funsor_markov = FunsorMarkovMessenger(history=history, keep=keep) + super().__init__() def generator(self, iterable): @@ -60,12 +69,22 @@ def __enter__(self): self._pos += 1 if len(self._stack) <= self._pos: self._stack.append(set()) - return super().__enter__() + + result = super().__enter__() + + if hasattr(self, "_funsor_markov"): + self._funsor_markov.__enter__() + + return result def __exit__(self, *args, **kwargs): if not self.keep: self._stack.pop() self._pos -= 1 + + if hasattr(self, "_funsor_markov"): + self._funsor_markov.__exit__(*args, **kwargs) + return super().__exit__(*args, **kwargs) def _pyro_sample(self, msg): diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index 59b27c8911..f877333cb4 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import functools +import os +from collections import OrderedDict from typing import Dict from pyro.params.param_store import ( # noqa: F401 @@ -16,6 +18,10 @@ _PYRO_PARAM_STORE = ParamStoreDict() +def _is_funsor_active() -> bool: + return "PYRO_FUNSOR_ACTIVE" in os.environ + + class _DimAllocator: """ Dimension allocator for internal use by :class:`plate`. @@ -62,6 +68,16 @@ def allocate(self, name, dim): ) ) self._stack[-1 - dim] = name + + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.runtime import ( + _DIM_STACK, + DimRequest, + DimType, + ) + + _DIM_STACK.allocate({name: DimRequest(dim, DimType.VISIBLE)}) + return dim def free(self, name, dim): @@ -74,6 +90,11 @@ def free(self, name, dim): while self._stack and self._stack[-1] is None: self._stack.pop() + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK + + del _DIM_STACK.global_frame[name] + # Handles placement of plate dimensions _DIM_ALLOCATOR = _DimAllocator() @@ -100,6 +121,12 @@ def set_first_available_dim(self, first_available_dim): self.next_available_id = 0 self.dim_to_id = {} # only the global ids + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, StackFrame + + _DIM_STACK.set_first_available_dim(first_available_dim) + self.global_frame = StackFrame(OrderedDict(), OrderedDict()) + def allocate(self, scope_dims=None): """ Allocate a new recyclable dim and a unique id. @@ -132,8 +159,38 @@ def allocate(self, scope_dims=None): while dim in scope_dims: dim -= 1 + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.runtime import ( + _DIM_STACK, + DimRequest, + DimType, + ) + + dim_ = dim + name = f"_enum_dim_{id_}" + if scope_dims is None: + dim = _DIM_STACK.allocate({name: DimRequest(None, DimType.GLOBAL)})[ + name + ] + self.dim_to_id[dim] = self.dim_to_id.pop(dim_) + else: + dim = _DIM_STACK.allocate({name: DimRequest(None, DimType.LOCAL)})[name] + assert dim not in scope_dims + return dim, id_ + def restore_globals(self): + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK + + _DIM_STACK.push_global(self.global_frame) + + def remove_globals(self): + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK + + self.global_frame = _DIM_STACK.pop_global() + # Handles placement of enumeration dimensions _ENUM_ALLOCATOR = _EnumAllocator()