From 448fc43d81f1b268b260b0a4b8826006c73cbeb5 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 27 Oct 2021 12:17:14 -0400 Subject: [PATCH 1/3] stash --- pyro/poutine/enum_messenger.py | 6 +++++ pyro/poutine/markov_messenger.py | 13 +++++++++++ pyro/poutine/runtime.py | 39 +++++++++++++++++++++++++++++++- 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index 234c0764b8..17398afd6c 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -143,11 +143,17 @@ def __init__(self, first_available_dim=None): def __enter__(self): if self.first_available_dim is not None: _ENUM_ALLOCATOR.set_first_available_dim(self.first_available_dim) + else: + _ENUM_ALLOCATOR.restore_globals() self._markov_depths = {} # site name -> depth (nonnegative integer) self._param_dims = {} # site name -> (enum dim -> unique id) self._value_dims = {} # site name -> (enum dim -> unique id) return super().__enter__() + def __exit__(self, *args): + _ENUM_ALLOCATOR.remove_globals() + return super().__exit__(*args) + @ignore_jit_warnings() def _pyro_sample(self, msg): """ diff --git a/pyro/poutine/markov_messenger.py b/pyro/poutine/markov_messenger.py index 1d68c9e06a..c78c46df5b 100644 --- a/pyro/poutine/markov_messenger.py +++ b/pyro/poutine/markov_messenger.py @@ -44,6 +44,11 @@ def __init__(self, history=1, keep=False, dim=None, name=None): self._iterable = None self._pos = -1 self._stack = [] + + if _FUNSOR_ACTIVE: + from pyro.contrib.funsor.handlers.named_messenger import MarkovMessenger as FunsorMarkovMessenger + self._funsor_markov = FunsorMarkovMessenger(history=history, keep=keep) + super().__init__() def generator(self, iterable): @@ -60,12 +65,20 @@ def __enter__(self): self._pos += 1 if len(self._stack) <= self._pos: self._stack.append(set()) + + if hasattr(self, "_funsor_markov"): + self._funsor_markov.__enter__() + return super().__enter__() def __exit__(self, *args, **kwargs): if not self.keep: self._stack.pop() self._pos -= 1 + + if hasattr(self, "_funsor_markov"): + self._funsor_markov.__exit__(*args, **kwargs) + return super().__exit__(*args, **kwargs) def _pyro_sample(self, msg): diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index 59b27c8911..34eda3fd64 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import functools + +from collections import OrderedDict from typing import Dict from pyro.params.param_store import ( # noqa: F401 @@ -62,6 +64,11 @@ def allocate(self, name, dim): ) ) self._stack[-1 - dim] = name + + if _FUNSOR_ACTIVE: + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType + _DIM_STACK.allocate({name: DimRequest(dim, DimType.VISIBLE)}) + return dim def free(self, name, dim): @@ -74,6 +81,10 @@ def free(self, name, dim): while self._stack and self._stack[-1] is None: self._stack.pop() + if _FUNSOR_ACTIVE: + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK + del _DIM_STACK.global_frame[name] + # Handles placement of plate dimensions _DIM_ALLOCATOR = _DimAllocator() @@ -98,7 +109,12 @@ def set_first_available_dim(self, first_available_dim): assert first_available_dim < 0, first_available_dim self.next_available_dim = first_available_dim self.next_available_id = 0 - self.dim_to_id = {} # only the global ids + self.dim_to_id, prev_dim_to_id = {}, getattr(self, "dim_to_id", {}) # only the global ids + + if _FUNSOR_ACTIVE: + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, StackFrame + _DIM_STACK.set_first_available_dim(first_available_dim) + self.global_frame = StackFrame(OrderedDict(), OrderedDict()) def allocate(self, scope_dims=None): """ @@ -132,8 +148,29 @@ def allocate(self, scope_dims=None): while dim in scope_dims: dim -= 1 + if _FUNSOR_ACTIVE: + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType + dim_ = dim + name = f"_enum_dim_{id_}" + if scope_dims is None: + dim = _DIM_STACK.allocate({name: DimRequest(None, DimType.GLOBAL)})[name] + self.dim_to_id[dim] = self.dim_to_id.pop(dim_) + else: + dim = _DIM_STACK.allocate({name: DimRequest(None, DimType.LOCAL)})[name] + assert dim not in scope_dims + return dim, id_ + def restore_globals(self): + if _FUNSOR_ACTIVE: + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType + _DIM_STACK.push_global(self.global_frame) + + def remove_globals(self): + if _FUNSOR_ACTIVE: + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK + self.global_frame = _DIM_STACK.pop_global() + # Handles placement of enumeration dimensions _ENUM_ALLOCATOR = _EnumAllocator() From 55952f53554b454917555c901fd3e638184b66a0 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 15 Dec 2021 16:58:59 -0500 Subject: [PATCH 2/3] fix _funsor_active --- pyro/contrib/funsor/handlers/runtime.py | 5 +-- pyro/poutine/markov_messenger.py | 12 +++++-- pyro/poutine/runtime.py | 46 ++++++++++++++++++------- 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/pyro/contrib/funsor/handlers/runtime.py b/pyro/contrib/funsor/handlers/runtime.py index d1d3a8081a..87830fe9d6 100644 --- a/pyro/contrib/funsor/handlers/runtime.py +++ b/pyro/contrib/funsor/handlers/runtime.py @@ -1,9 +1,12 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import os from collections import Counter, OrderedDict, namedtuple from enum import Enum +os.environ["PYRO_FUNSOR_ACTIVE"] = "1" # TODO better toggle + class StackFrame: """ @@ -99,14 +102,12 @@ def push_global(self, frame): self._global_stack.append(frame) def pop_global(self): - assert self._global_stack, "cannot pop the global frame" return self._global_stack.pop() def push_iter(self, frame): self._iter_stack.append(frame) def pop_iter(self): - assert self._iter_stack, "cannot pop the global frame" return self._iter_stack.pop() def push_local(self, frame): diff --git a/pyro/poutine/markov_messenger.py b/pyro/poutine/markov_messenger.py index c78c46df5b..8283cdb9a6 100644 --- a/pyro/poutine/markov_messenger.py +++ b/pyro/poutine/markov_messenger.py @@ -5,6 +5,7 @@ from contextlib import ExitStack # python 3 from .reentrant_messenger import ReentrantMessenger +from .runtime import _is_funsor_active class MarkovMessenger(ReentrantMessenger): @@ -45,8 +46,11 @@ def __init__(self, history=1, keep=False, dim=None, name=None): self._pos = -1 self._stack = [] - if _FUNSOR_ACTIVE: - from pyro.contrib.funsor.handlers.named_messenger import MarkovMessenger as FunsorMarkovMessenger + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.named_messenger import ( + MarkovMessenger as FunsorMarkovMessenger, + ) + self._funsor_markov = FunsorMarkovMessenger(history=history, keep=keep) super().__init__() @@ -66,10 +70,12 @@ def __enter__(self): if len(self._stack) <= self._pos: self._stack.append(set()) + result = super().__enter__() + if hasattr(self, "_funsor_markov"): self._funsor_markov.__enter__() - return super().__enter__() + return result def __exit__(self, *args, **kwargs): if not self.keep: diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index 34eda3fd64..da15e758f0 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools - +import os from collections import OrderedDict from typing import Dict @@ -18,6 +18,10 @@ _PYRO_PARAM_STORE = ParamStoreDict() +def _is_funsor_active() -> bool: + return "PYRO_FUNSOR_ACTIVE" in os.environ + + class _DimAllocator: """ Dimension allocator for internal use by :class:`plate`. @@ -65,8 +69,13 @@ def allocate(self, name, dim): ) self._stack[-1 - dim] = name - if _FUNSOR_ACTIVE: - from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.runtime import ( + _DIM_STACK, + DimRequest, + DimType, + ) + _DIM_STACK.allocate({name: DimRequest(dim, DimType.VISIBLE)}) return dim @@ -81,8 +90,9 @@ def free(self, name, dim): while self._stack and self._stack[-1] is None: self._stack.pop() - if _FUNSOR_ACTIVE: + if _is_funsor_active(): from pyro.contrib.funsor.handlers.runtime import _DIM_STACK + del _DIM_STACK.global_frame[name] @@ -109,10 +119,13 @@ def set_first_available_dim(self, first_available_dim): assert first_available_dim < 0, first_available_dim self.next_available_dim = first_available_dim self.next_available_id = 0 - self.dim_to_id, prev_dim_to_id = {}, getattr(self, "dim_to_id", {}) # only the global ids + self.dim_to_id, prev_dim_to_id = {}, getattr( + self, "dim_to_id", {} + ) # only the global ids - if _FUNSOR_ACTIVE: + if _is_funsor_active(): from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, StackFrame + _DIM_STACK.set_first_available_dim(first_available_dim) self.global_frame = StackFrame(OrderedDict(), OrderedDict()) @@ -148,12 +161,19 @@ def allocate(self, scope_dims=None): while dim in scope_dims: dim -= 1 - if _FUNSOR_ACTIVE: - from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.runtime import ( + _DIM_STACK, + DimRequest, + DimType, + ) + dim_ = dim name = f"_enum_dim_{id_}" if scope_dims is None: - dim = _DIM_STACK.allocate({name: DimRequest(None, DimType.GLOBAL)})[name] + dim = _DIM_STACK.allocate({name: DimRequest(None, DimType.GLOBAL)})[ + name + ] self.dim_to_id[dim] = self.dim_to_id.pop(dim_) else: dim = _DIM_STACK.allocate({name: DimRequest(None, DimType.LOCAL)})[name] @@ -162,13 +182,15 @@ def allocate(self, scope_dims=None): return dim, id_ def restore_globals(self): - if _FUNSOR_ACTIVE: - from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK + _DIM_STACK.push_global(self.global_frame) def remove_globals(self): - if _FUNSOR_ACTIVE: + if _is_funsor_active(): from pyro.contrib.funsor.handlers.runtime import _DIM_STACK + self.global_frame = _DIM_STACK.pop_global() From 34e7974e1c68018f139ce9d9d7b8b145749846ae Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 15 Dec 2021 17:16:14 -0500 Subject: [PATCH 3/3] fix --- pyro/contrib/funsor/__init__.py | 4 ++++ pyro/contrib/funsor/handlers/runtime.py | 3 --- pyro/poutine/enum_messenger.py | 16 +++++++++++----- pyro/poutine/runtime.py | 4 +--- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/pyro/contrib/funsor/__init__.py b/pyro/contrib/funsor/__init__.py index d8a4d3eea9..ba9b330bbf 100644 --- a/pyro/contrib/funsor/__init__.py +++ b/pyro/contrib/funsor/__init__.py @@ -1,6 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import os + import pyroapi from pyro.contrib.funsor.handlers import condition, do, markov @@ -38,6 +40,8 @@ def plate(*args, **kwargs): }, ) +os.environ["PYRO_FUNSOR_ACTIVE"] = "1" # TODO better toggle + __all__ = [ "clear_param_store", "condition", diff --git a/pyro/contrib/funsor/handlers/runtime.py b/pyro/contrib/funsor/handlers/runtime.py index 87830fe9d6..979bdb4c13 100644 --- a/pyro/contrib/funsor/handlers/runtime.py +++ b/pyro/contrib/funsor/handlers/runtime.py @@ -1,12 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import os from collections import Counter, OrderedDict, namedtuple from enum import Enum -os.environ["PYRO_FUNSOR_ACTIVE"] = "1" # TODO better toggle - class StackFrame: """ diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index 17398afd6c..4c7666b5d3 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -9,7 +9,7 @@ from pyro.util import ignore_jit_warnings from .messenger import Messenger -from .runtime import _ENUM_ALLOCATOR +from .runtime import _ENUM_ALLOCATOR, _is_funsor_active def _tmc_mixture_sample(msg): @@ -138,20 +138,26 @@ def __init__(self, first_available_dim=None): first_available_dim is None or first_available_dim < 0 ), first_available_dim self.first_available_dim = first_available_dim + if _is_funsor_active(): + from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger + + self._funsor_named = NamedMessenger() super().__init__() def __enter__(self): if self.first_available_dim is not None: _ENUM_ALLOCATOR.set_first_available_dim(self.first_available_dim) - else: - _ENUM_ALLOCATOR.restore_globals() self._markov_depths = {} # site name -> depth (nonnegative integer) self._param_dims = {} # site name -> (enum dim -> unique id) self._value_dims = {} # site name -> (enum dim -> unique id) - return super().__enter__() + result = super().__enter__() + if hasattr(self, "_funsor_named"): + self._funsor_named.__enter__() + return result def __exit__(self, *args): - _ENUM_ALLOCATOR.remove_globals() + if hasattr(self, "_funsor_named"): + self._funsor_named.__exit__(*args) return super().__exit__(*args) @ignore_jit_warnings() diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index da15e758f0..f877333cb4 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -119,9 +119,7 @@ def set_first_available_dim(self, first_available_dim): assert first_available_dim < 0, first_available_dim self.next_available_dim = first_available_dim self.next_available_id = 0 - self.dim_to_id, prev_dim_to_id = {}, getattr( - self, "dim_to_id", {} - ) # only the global ids + self.dim_to_id = {} # only the global ids if _is_funsor_active(): from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, StackFrame