diff --git a/chainerrl/distribution.py b/chainerrl/distribution.py index 0cf8e4ce8..cda0dc9b2 100644 --- a/chainerrl/distribution.py +++ b/chainerrl/distribution.py @@ -19,13 +19,6 @@ from chainerrl.functions import mellowmax -def _wrap_by_variable(x): - if isinstance(x, chainer.Variable): - return x - else: - return chainer.Variable(x) - - def _unwrap_variable(x): if isinstance(x, chainer.Variable): return x.array @@ -263,8 +256,8 @@ class GaussianDistribution(Distribution): """Gaussian distribution.""" def __init__(self, mean, var): - self.mean = _wrap_by_variable(mean) - self.var = _wrap_by_variable(var) + self.mean = chainer.as_variable(mean) + self.var = chainer.as_variable(var) self.ln_var = F.log(var) @property @@ -324,7 +317,7 @@ class ContinuousDeterministicDistribution(Distribution): """ def __init__(self, x): - self.x = _wrap_by_variable(x) + self.x = chainer.as_variable(x) @cached_property def entropy(self):