diff --git a/pyro/nn/auto_reg_nn.py b/pyro/nn/auto_reg_nn.py index 3ae06bd055..de4ac594e4 100644 --- a/pyro/nn/auto_reg_nn.py +++ b/pyro/nn/auto_reg_nn.py @@ -3,12 +3,14 @@ import warnings +from typing import Union, List, Optional + import torch import torch.nn as nn from torch.nn import functional as F -def sample_mask_indices(input_dim, hidden_dim, simple=True): +def sample_mask_indices(input_dim: int, hidden_dim: int, simple: bool=True) -> Union[int,torch.Tensor]: """ Samples the indices assigned to hidden units during the construction of MADE masks @@ -32,9 +34,7 @@ def sample_mask_indices(input_dim, hidden_dim, simple=True): return ints -def create_mask( - input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier -): +def create_mask(input_dim: int , context_dim: int, hidden_dims: List[int], permutation: torch.LongTensor, output_dim_multiplier: int): """ Creates MADE masks for a conditional distribution @@ -109,7 +109,7 @@ class MaskedLinear(nn.Linear): :type bias: bool """ - def __init__(self, in_features, out_features, mask, bias=True): + def __init__(self, in_features: int, out_features: int, mask: torch.Tensor, bias: bool=True): super().__init__(in_features, out_features, bias) self.register_buffer("mask", mask.data) @@ -165,15 +165,15 @@ class ConditionalAutoRegressiveNN(nn.Module): """ def __init__( - self, - input_dim, - context_dim, - hidden_dims, - param_dims=[1, 1], - permutation=None, - skip_connections=False, - nonlinearity=nn.ReLU(), - ): + self, + input_dim: int, + context_dim: int, + hidden_dims: List[int], + param_dims: List[int]=[1, 1], + permutation: Optional[torch.LongTensor]=None, + skip_connections: bool=False, + nonlinearity=nn.ReLU()): + super().__init__() if input_dim == 1: warnings.warn( @@ -327,14 +327,13 @@ class AutoRegressiveNN(ConditionalAutoRegressiveNN): """ def __init__( - self, - input_dim, - hidden_dims, - param_dims=[1, 1], - permutation=None, - skip_connections=False, - nonlinearity=nn.ReLU(), - ): + self, + input_dim: int, + hidden_dims: List[int], + param_dims: List=[1, 1], + permutation: torch.LongTensor=None, + skip_connections: bool=False, + nonlinearity=nn.ReLU()): super(AutoRegressiveNN, self).__init__( input_dim, 0, diff --git a/pyro/nn/dense_nn.py b/pyro/nn/dense_nn.py index a7a9a7e645..920ee02497 100644 --- a/pyro/nn/dense_nn.py +++ b/pyro/nn/dense_nn.py @@ -1,6 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +from typing import List, Union, Tuple + import torch @@ -34,13 +36,13 @@ class ConditionalDenseNN(torch.nn.Module): """ def __init__( - self, - input_dim, - context_dim, - hidden_dims, - param_dims=[1, 1], - nonlinearity=torch.nn.ReLU(), - ): + self, + input_dim:int, + context_dim:int, + hidden_dims: List[int], + param_dims: List[int] = [1, 1], + nonlinearity: torch.nn.Module = torch.nn.ReLU()): + super().__init__() self.input_dim = input_dim @@ -65,14 +67,14 @@ def __init__( # Save the nonlinearity self.f = nonlinearity - def forward(self, x, context): + def forward(self, x:torch.Tensor, context:torch.Tensor) -> Union[torch.Tensor,Tuple[torch.Tensor]]: # We must be able to broadcast the size of the context over the input context = context.expand(x.size()[:-1] + (context.size(-1),)) x = torch.cat([context, x], dim=-1) return self._forward(x) - def _forward(self, x): + def _forward(self, x:torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]: """ The forward method """ @@ -122,11 +124,15 @@ class DenseNN(ConditionalDenseNN): """ def __init__( - self, input_dim, hidden_dims, param_dims=[1, 1], nonlinearity=torch.nn.ReLU() - ): + self, + input_dim: int, + hidden_dims: List[int], + param_dims: List[int] = [1, 1], + nonlinearity: torch.nn.module = torch.nn.ReLU()) -> None: + super(DenseNN, self).__init__( input_dim, 0, hidden_dims, param_dims=param_dims, nonlinearity=nonlinearity ) - def forward(self, x): + def forward(self, x: torch.Tensor): return self._forward(x) diff --git a/pyro/nn/module.py b/pyro/nn/module.py index ad11567537..a1f9b2f55a 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -15,12 +15,13 @@ import functools import inspect from collections import OrderedDict, namedtuple +from typing import Callable, Dict, Union, Optional import torch from torch.distributions import constraints, transform_to -import pyro -from pyro.poutine.runtime import _PYRO_PARAM_STORE +import pyro # type: ignore +from pyro.poutine.runtime import _PYRO_PARAM_STORE # type: ignore class PyroParam(namedtuple("PyroParam", ("init_value", "constraint", "event_dim"))): @@ -156,7 +157,7 @@ def __get__(self, obj, obj_type): return obj.__getattr__(self.name) -def _make_name(prefix, name): +def _make_name(prefix:str, name:str): return "{}.{}".format(prefix, name) if prefix else name @@ -210,7 +211,7 @@ def _get_pyro_params(module): class _PyroModuleMeta(type): - _pyro_mixin_cache = {} + _pyro_mixin_cach: Dict = {} # Unpickling helper to create an empty object of type PyroModule[Module]. class _New: @@ -374,15 +375,15 @@ class PyroLinear(nn.Linear, PyroModule): :param str name: Optional name for a root PyroModule. This is ignored in sub-PyroModules of another PyroModule. """ - - def __init__(self, name=""): + + def __init__(self, name:str=""): self._pyro_name = name self._pyro_context = _Context() # shared among sub-PyroModules - self._pyro_params = OrderedDict() - self._pyro_samples = OrderedDict() + self._pyro_params: OrderedDict = OrderedDict() + self._pyro_samples: OrderedDict = OrderedDict() super().__init__() - def add_module(self, name, module): + def add_module(self, name:str, module): """ Adds a child module to the current module. """ @@ -392,7 +393,7 @@ def add_module(self, name, module): ) super().add_module(name, module) - def named_pyro_params(self, prefix="", recurse=True): + def named_pyro_params(self, prefix:str='', recurse:bool=True): """ Returns an iterator over PyroModule parameters, yielding both the name of the parameter as well as the parameter itself. @@ -407,7 +408,7 @@ def named_pyro_params(self, prefix="", recurse=True): for elem in gen: yield elem - def _pyro_set_supermodule(self, name, context): + def _pyro_set_supermodule(self, name:str, context): self._pyro_name = name self._pyro_context = context for key, value in self._modules.items(): @@ -417,6 +418,10 @@ def _pyro_set_supermodule(self, name, context): ), "submodule {} has executed outside of supermodule".format(name) value._pyro_set_supermodule(_make_name(name, key), context) + def _pyro_get_fullname(self, name:str): + assert self.__dict__['_pyro_context'].used, "fullname is not yet defined" + return _make_name(self.__dict__['_pyro_name'], name) + def _pyro_get_fullname(self, name): assert self.__dict__["_pyro_context"].used, "fullname is not yet defined" return _make_name(self.__dict__["_pyro_name"], name) @@ -425,7 +430,7 @@ def __call__(self, *args, **kwargs): with self._pyro_context: return super().__call__(*args, **kwargs) - def __getattr__(self, name): + def __getattr__(self, name:str): # PyroParams trigger pyro.param statements. if "_pyro_params" in self.__dict__: _pyro_params = self.__dict__["_pyro_params"] @@ -507,7 +512,7 @@ def __getattr__(self, name): return result - def __setattr__(self, name, value): + def __setattr__(self, name:str, value:Union[PyroParam,"PyroModule",torch.nn.Parameter,torch.Tensor]): if isinstance(value, PyroModule): # Create a new sub PyroModule, overwriting any old value. try: @@ -585,7 +590,7 @@ def __setattr__(self, name, value): super().__setattr__(name, value) - def __delattr__(self, name): + def __delattr__(self, name:str): if name in self._parameters: del self._parameters[name] if self._pyro_context.used: @@ -621,7 +626,7 @@ def __delattr__(self, name): super().__delattr__(name) -def pyro_method(fn): +def pyro_method(fn: Callable): """ Decorator for top-level methods of a :class:`PyroModule` to enable pyro effects and cache ``pyro.sample`` statements. @@ -638,7 +643,7 @@ def cached_fn(self, *args, **kwargs): return cached_fn -def clear(mod): +def clear(mod:PyroModule): """ Removes data from both a :class:`PyroModule` and the param store. @@ -653,7 +658,7 @@ def clear(mod): delattr(mod, name) -def to_pyro_module_(m, recurse=True): +def to_pyro_module_(m:torch.nn.Module, recurse:bool=True): """ Converts an ordinary :class:`torch.nn.Module` instance to a :class:`PyroModule` **in-place**. @@ -714,7 +719,7 @@ def to_pyro_module_(m, recurse=True): # attribute. This is required if any attribute is set to a PyroParam or # PyroSample. For motivation, see https://github.com/pyro-ppl/pyro/issues/2390 class _FlatWeightsDescriptor: - def __get__(self, obj, obj_type=None): + def __get__(self, obj, obj_type:Optional=None): if obj is None: return self return [getattr(obj, name) for name in obj._flat_weights_names] diff --git a/pyro/optim/dct_adam.py b/pyro/optim/dct_adam.py index 1a0047c239..8c026f677b 100644 --- a/pyro/optim/dct_adam.py +++ b/pyro/optim/dct_adam.py @@ -119,7 +119,7 @@ def step(self, closure: Optional[Callable] = None) -> Optional[float]: return loss - def _step_param(self, group: Dict, p) -> None: + def _step_param(self, group: Dict, p: torch.Tensor) -> None: grad = p.grad.data grad.clamp_(-group["clip_norm"], group["clip_norm"]) @@ -160,7 +160,7 @@ def _step_param(self, group: Dict, p) -> None: step = _transform_inverse(exp_avg / denom, time_dim, duration) p.data.add_(step.mul_(-step_size)) - def _step_param_subsample(self, group: Dict, p, subsample) -> None: + def _step_param_subsample(self, group: Dict, p: torch.Tensor, subsample) -> None: mask = _get_mask(p, subsample) grad = p.grad.data.masked_select(mask) diff --git a/pyro/params/param_store.py b/pyro/params/param_store.py index 7142e34a85..5a9aaba61f 100644 --- a/pyro/params/param_store.py +++ b/pyro/params/param_store.py @@ -4,10 +4,13 @@ import re import warnings import weakref +from typing import Iterable, Dict, Union, Callable, Optional +from collections import KeysView import torch from torch.distributions import constraints, transform_to +import pyro class ParamStoreDict: """ @@ -46,7 +49,7 @@ def __init__(self): self._param_to_name = {} # dictionary from unconstrained param to param name self._constraints = {} # dictionary from param name to constraint object - def clear(self): + def clear(self) -> None: """ Clear the ParamStore """ @@ -62,7 +65,7 @@ def items(self): for name in self._params: yield name, self[name] - def keys(self): + def keys(self) -> KeysView: """ Iterate over param names. """ @@ -75,22 +78,22 @@ def values(self): for name, constrained_param in self.items(): yield constrained_param - def __bool__(self): + def __bool__(self) -> bool: return bool(self._params) - def __len__(self): + def __len__(self) -> int: return len(self._params) - def __contains__(self, name): + def __contains__(self, name:str) -> bool: return name in self._params - def __iter__(self): + def __iter__(self) -> Iterable: """ Iterate over param names. """ return iter(self.keys()) - def __delitem__(self, name): + def __delitem__(self, name: str) -> None: """ Remove a parameter from the param store. """ @@ -98,7 +101,7 @@ def __delitem__(self, name): self._param_to_name.pop(unconstrained_value) self._constraints.pop(name) - def __getitem__(self, name): + def __getitem__(self, name: str): """ Get the *constrained* value of a named parameter. """ @@ -111,7 +114,7 @@ def __getitem__(self, name): return constrained_value - def __setitem__(self, name, new_constrained_value): + def __setitem__(self, name: str, new_constrained_value): """ Set the constrained value of an existing parameter, or the value of a new *unconstrained* parameter. To declare a new parameter with @@ -131,7 +134,7 @@ def __setitem__(self, name, new_constrained_value): self._params[name] = unconstrained_value self._param_to_name[unconstrained_value] = name - def setdefault(self, name, init_constrained_value, constraint=constraints.real): + def setdefault(self, name:str, init_constrained_value: Union[torch.Tensor,Callable[[],torch.Tensor]], constraint:constraints.Constraint=constraints.real) -> torch.Tensor: """ Retrieve a *constrained* parameter value from the if it exists, otherwise set the initial value. Note that this is a little fancier than @@ -169,7 +172,7 @@ def setdefault(self, name, init_constrained_value, constraint=constraints.real): # ------------------------------------------------------------------------------- # Old non-dict interface - def named_parameters(self): + def named_parameters(self) -> Iterable: """ Returns an iterator over ``(name, unconstrained_value)`` tuples for each parameter in the ParamStore. Note that, in the event the parameter is constrained, @@ -177,24 +180,18 @@ def named_parameters(self): """ return self._params.items() - def get_all_param_names(self): - warnings.warn( - "ParamStore.get_all_param_names() is deprecated; use .keys() instead.", - DeprecationWarning, - ) + def get_all_param_names(self) -> KeysView: + warnings.warn("ParamStore.get_all_param_names() is deprecated; use .keys() instead.", + DeprecationWarning) return self.keys() - def replace_param(self, param_name, new_param, old_param): - warnings.warn( - "ParamStore.replace_param() is deprecated; use .__setitem__() instead.", - DeprecationWarning, - ) + def replace_param(self, param_name:str, new_param: pyro.param, old_param: pyro.param): + warnings.warn("ParamStore.replace_param() is deprecated; use .__setitem__() instead.", + DeprecationWarning) assert self._params[param_name] is old_param.unconstrained() self[param_name] = new_param - def get_param( - self, name, init_tensor=None, constraint=constraints.real, event_dim=None - ): + def get_param(self, name: str, init_tensor: Optional[torch.Tensor] = None, constraint:constraints.Constraint=constraints.real, event_dim:Optional[int] = None): """ Get parameter from its name. If it does not yet exist in the ParamStore, it will be created and stored. @@ -215,7 +212,7 @@ def get_param( else: return self.setdefault(name, init_tensor, constraint) - def match(self, name): + def match(self, name:str) -> Dict: """ Get all parameters that match regex. The parameter must exist. @@ -226,7 +223,7 @@ def match(self, name): pattern = re.compile(name) return {name: self[name] for name in self if pattern.match(name)} - def param_name(self, p): + def param_name(self, p) -> str: """ Get parameter name from parameter @@ -235,7 +232,7 @@ def param_name(self, p): """ return self._param_to_name.get(p) - def get_state(self): + def get_state(self) -> Dict: """ Get the ParamStore state. """ @@ -245,7 +242,7 @@ def get_state(self): } return state - def set_state(self, state): + def set_state(self, state:Dict): """ Set the ParamStore state using state from a previous get_state() call """ @@ -264,7 +261,7 @@ def set_state(self, state): constraint = constraints.real self._constraints[param_name] = constraint - def save(self, filename): + def save(self, filename:str) -> None: """ Save parameters to disk @@ -274,7 +271,7 @@ def save(self, filename): with open(filename, "wb") as output_file: torch.save(self.get_state(), output_file) - def load(self, filename, map_location=None): + def load(self, filename: str, map_location: Optional[Union[Callable, torch.device, str, Dict]] = None) -> None: """ Loads parameters from disk @@ -300,19 +297,19 @@ def load(self, filename, map_location=None): _MODULE_NAMESPACE_DIVIDER = "$$$" -def param_with_module_name(pyro_name, param_name): +def param_with_module_name(pyro_name: str, param_name: str) -> str: return _MODULE_NAMESPACE_DIVIDER.join([pyro_name, param_name]) -def module_from_param_with_module_name(param_name): +def module_from_param_with_module_name(param_name: str) -> str: return param_name.split(_MODULE_NAMESPACE_DIVIDER)[0] -def user_param_name(param_name): +def user_param_name(param_name: str) -> str: if _MODULE_NAMESPACE_DIVIDER in param_name: return param_name.split(_MODULE_NAMESPACE_DIVIDER)[1] return param_name -def normalize_param_name(name): +def normalize_param_name(name: str) -> str: return name.replace(_MODULE_NAMESPACE_DIVIDER, ".") diff --git a/pyro/poutine/util.py b/pyro/poutine/util.py index e90c2a5917..d1f8250474 100644 --- a/pyro/poutine/util.py +++ b/pyro/poutine/util.py @@ -38,7 +38,7 @@ def prune_subsample_sites(trace): return trace -def enum_extend(trace, msg, num_samples=None): +def enum_extend(trace, msg: str, num_samples: Optional[int]=None) -> List: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -65,7 +65,7 @@ def enum_extend(trace, msg, num_samples=None): return extended_traces -def mc_extend(trace, msg, num_samples=None): +def mc_extend(trace, msg: str, num_samples:Optional[int] = None) -> List: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -90,7 +90,7 @@ def mc_extend(trace, msg, num_samples=None): return extended_traces -def discrete_escape(trace, msg): +def discrete_escape(trace, msg: str) -> bool: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -109,7 +109,7 @@ def discrete_escape(trace, msg): ) -def all_escape(trace, msg): +def all_escape(trace, msg: str) -> bool: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site diff --git a/pyro/util.py b/pyro/util.py index 1acc5ebb1e..64e4f5f230 100644 --- a/pyro/util.py +++ b/pyro/util.py @@ -11,6 +11,7 @@ from collections import defaultdict from contextlib import contextmanager from itertools import zip_longest +from typing import Dict, Optional import numpy as np import torch @@ -18,7 +19,7 @@ from pyro.poutine.util import site_is_subsample -def set_rng_seed(rng_seed): +def set_rng_seed(rng_seed : int) -> None: """ Sets seeds of `torch` and `torch.cuda` (if available). @@ -29,15 +30,12 @@ def set_rng_seed(rng_seed): np.random.seed(rng_seed) -def get_rng_state(): - return { - "torch": torch.get_rng_state(), - "random": random.getstate(), - "numpy": np.random.get_state(), - } +def get_rng_state() -> Dict: + return {'torch': torch.get_rng_state(), 'random': random.getstate(), 'numpy': np.random.get_state()} -def set_rng_state(state): + +def set_rng_state(state:Dict) -> Dict: torch.set_rng_state(state["torch"]) random.setstate(state["random"]) if "numpy" in state: @@ -46,7 +44,7 @@ def set_rng_state(state): np.random.set_state(state["numpy"]) -def torch_isnan(x): +def torch_isnan(x : torch.Tensor) -> torch.Tensor: """ A convenient function to check if a Tensor contains any nan; also works with numbers """ @@ -55,7 +53,7 @@ def torch_isnan(x): return torch.isnan(x).any() -def torch_isinf(x): +def torch_isinf(x : torch.Tensor) -> torch.Tensor: """ A convenient function to check if a Tensor contains any +inf; also works with numbers """ @@ -64,7 +62,7 @@ def torch_isinf(x): return (x == math.inf).any() or (x == -math.inf).any() -def warn_if_nan(value, msg="", *, filename=None, lineno=None): +def warn_if_nan(value, msg : str="", *, filename :Optional[str] = None, lineno: Optional[bool]=None) -> torch.Tensor: """ A convenient function to warn if a Tensor or its grad contains any nan, also works with numbers. @@ -97,9 +95,9 @@ def warn_if_nan(value, msg="", *, filename=None, lineno=None): return value -def warn_if_inf( - value, msg="", allow_posinf=False, allow_neginf=False, *, filename=None, lineno=None -): + +def warn_if_inf(value:torch.Tensor, msg : str="", allow_posinf:bool=False, allow_neginf:bool=False, *, + filename : str =None, lineno=None) -> torch.Tensor: """ A convenient function to warn if a Tensor or its grad contains any inf, also works with numbers. diff --git a/setup.cfg b/setup.cfg index 4e98b0f2fc..9590d19749 100644 --- a/setup.cfg +++ b/setup.cfg @@ -78,9 +78,6 @@ warn_unused_ignores = True [mypy-pyro.optm.*] warn_unused_ignores = True -[mypy-pyro.params.*] -ignore_errors = True -warn_unused_ignores = True [mypy-pyro.poutine.*] ignore_errors = True