Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chainerrl/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from chainerrl.misc.draw_computational_graph import draw_computational_graph # NOQA
from chainerrl.misc.draw_computational_graph import is_graphviz_available # NOQA
from chainerrl.misc import env_modifiers # NOQA
from chainerrl.misc.namedpersistent import namedpersistent # NOQA
from chainerrl.misc.is_return_code_zero import is_return_code_zero # NOQA
from chainerrl.misc.random_seed import set_random_seed # NOQA
55 changes: 53 additions & 2 deletions chainerrl/misc/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import chainer
import numpy as np

import chainerrl
from chainerrl.misc import random_seed


Expand All @@ -32,19 +33,56 @@ def ensure_initialized_update_rule(param):
u.init_state(param)


def _set_persistent_values_recursively(link, persistent_name, shared_array):
if persistent_name.startswith('/'):
persistent_name = persistent_name[1:]
if hasattr(link, persistent_name):
attr_name = persistent_name
attr = getattr(link, attr_name)
if isinstance(attr, np.ndarray):
setattr(link, persistent_name, np.frombuffer(
shared_array, dtype=attr.dtype).reshape(attr.shape))
else:
assert np.isscalar(attr)
# We wrap scalars with np.ndarray because
# multiprocessing.RawValue cannot be used as a scalar, while
# np.ndarray can be.
typecode = np.asarray(attr).dtype.char
setattr(link, attr_name, np.frombuffer(
shared_array, dtype=typecode).reshape(()))
else:
assert isinstance(link, (chainer.Chain, chainer.ChainList))
assert '/' in persistent_name
child_name, remaining = persistent_name.split('/', maxsplit=1)
if isinstance(link, chainer.Chain):
_set_persistent_values_recursively(
getattr(link, child_name), remaining, shared_array)
else:
_set_persistent_values_recursively(
link._children[int(child_name)], remaining, shared_array)
Comment thread
muupan marked this conversation as resolved.
Outdated


def set_shared_params(a, b):
"""Set shared params to a link.
"""Set shared params (and persistent values) to a link.

Args:
a (chainer.Link): link whose params are to be replaced
b (dict): dict that consists of (param_name, multiprocessing.Array)
"""
assert isinstance(a, chainer.Link)
remaining_keys = set(b.keys())
for param_name, param in a.namedparams():
if param_name in b:
shared_param = b[param_name]
param.array = np.frombuffer(
shared_param, dtype=param.dtype).reshape(param.shape)
remaining_keys.remove(param_name)
for persistent_name, _ in chainerrl.misc.namedpersistent(a):
if persistent_name in b:
_set_persistent_values_recursively(
a, persistent_name, b[persistent_name])
remaining_keys.remove(persistent_name)
assert not remaining_keys


def make_params_not_shared(a):
Expand Down Expand Up @@ -85,7 +123,20 @@ def extract_params_as_shared_arrays(link):
assert isinstance(link, chainer.Link)
shared_arrays = {}
for param_name, param in link.namedparams():
shared_arrays[param_name] = mp.RawArray('f', param.array.ravel())
typecode = param.array.dtype.char
shared_arrays[param_name] = mp.RawArray(typecode, param.array.ravel())

for persistent_name, persistent in chainerrl.misc.namedpersistent(link):
if isinstance(persistent, np.ndarray):
typecode = persistent.dtype.char
shared_arrays[persistent_name] = mp.RawArray(
typecode, persistent.ravel())
else:
assert np.isscalar(persistent)
persistent_as_array = np.asarray([persistent])
Comment thread
muupan marked this conversation as resolved.
typecode = persistent_as_array.dtype.char
shared_arrays[persistent_name] = mp.RawArray(
typecode, persistent_as_array)
return shared_arrays


Expand Down
37 changes: 37 additions & 0 deletions chainerrl/misc/namedpersistent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA

import chainer


def namedpersistent(link):
"""Return a generator of all (path, persistent) pairs for a given link.

This function is adopted from https://github.com/chainer/chainer/pull/6788.
Once it is merged into Chainer, we should use the property instead.

Args:
link (chainer.Link): Link.

Returns:
A generator object that generates all (path, persistent) pairs.
The paths are relative from this link.
"""
d = link.__dict__
for name in sorted(link._persistent):
yield '/' + name, d[name]
if isinstance(link, chainer.Chain):
for name in sorted(link._children):
prefix = '/' + name
for path, persistent in namedpersistent(d[name]):
yield prefix + path, persistent
elif isinstance(link, chainer.ChainList):
for idx, link in enumerate(link._children):
prefix = '/{}'.format(idx)
for path, persistent in namedpersistent(link):
yield prefix + path, persistent
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def _namedchildren(link):
    ...
    elif isinstance(link, chainer.ChainList):
        for idx, child in enumerate(link._children):
            yield str(idx), child

could avoid repeating the recursion logic.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how. Can you give me some more detail? chainer/chainer#6788 is also implemented with recursion, and I don't think we need to deviate from it.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably I meant that a mutual recursion between namedpersistent and _namedchildren looks cleaner. It can be a matter of taste. Sorry that my previous comment is not readable (to me neither at first sight).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood, thanks! I modified accordingly, so can you check again?

151 changes: 129 additions & 22 deletions tests/misc_tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,51 +19,160 @@
import copy
import numpy as np

import chainerrl
from chainerrl.misc import async_


def _assert_same_pointers_to_persistent_values(a, b):
assert isinstance(a, chainer.Link)
assert isinstance(b, chainer.Link)
a_persistents = dict(chainerrl.misc.namedpersistent(a))
b_persistents = dict(chainerrl.misc.namedpersistent(b))
assert set(a_persistents.keys()) == set(b_persistents.keys())
for key in a_persistents:
a_persistent = a_persistents[key]
b_persistent = b_persistents[key]
assert isinstance(a_persistent, np.ndarray)
assert isinstance(b_persistent, np.ndarray)
assert a_persistent.ctypes.data == b_persistent.ctypes.data


def _assert_same_pointers_to_param_data(a, b):
assert isinstance(a, chainer.Link)
assert isinstance(b, chainer.Link)
a_params = dict(a.namedparams())
b_params = dict(b.namedparams())
assert set(a_params.keys()) == set(b_params.keys())
for key in a_params.keys():
assert isinstance(a_params[key], chainer.Variable)
assert isinstance(b_params[key], chainer.Variable)
assert (a_params[key].array.ctypes.data
== b_params[key].array.ctypes.data)


def _assert_different_pointers_to_param_grad(a, b):
assert isinstance(a, chainer.Link)
assert isinstance(b, chainer.Link)
a_params = dict(a.namedparams())
b_params = dict(b.namedparams())
assert set(a_params.keys()) == set(b_params.keys())
for key in a_params.keys():
assert isinstance(a_params[key], chainer.Variable)
assert isinstance(b_params[key], chainer.Variable)
assert (a_params[key].grad.ctypes.data
!= b_params[key].grad.ctypes.data)


class TestAsync(unittest.TestCase):

def setUp(self):
pass

def test_share_params(self):
def test_share_params_linear(self):

# A's params are shared with B and C so that all the three share the
# same parameter arrays

model_a = L.Linear(2, 2)

arrays = async_.share_params_as_shared_arrays(model_a)
assert isinstance(arrays, dict)
assert set(arrays.keys()) == {'/W', '/b'}

model_b = L.Linear(2, 2)
model_c = L.Linear(2, 2)

async_.set_shared_params(model_b, arrays)
async_.set_shared_params(model_c, arrays)

a_params = dict(model_a.namedparams())
b_params = dict(model_b.namedparams())
c_params = dict(model_c.namedparams())
# Pointers to parameters must be the same
_assert_same_pointers_to_param_data(model_a, model_b)
_assert_same_pointers_to_param_data(model_a, model_c)
# Pointers to gradients must be different
_assert_different_pointers_to_param_grad(model_a, model_b)
_assert_different_pointers_to_param_grad(model_a, model_c)
_assert_different_pointers_to_param_grad(model_b, model_c)
# Pointers to persistent values must be the same
_assert_same_pointers_to_persistent_values(model_a, model_b)
_assert_same_pointers_to_persistent_values(model_a, model_c)

def test_share_params_batch_normalization(self):

# A's params and persistent values are all shared with B and C

model_a = L.BatchNormalization(3)

arrays = async_.share_params_as_shared_arrays(model_a)
assert isinstance(arrays, dict)
assert set(arrays.keys()) == {
'/gamma', '/beta', '/avg_mean', '/avg_var', '/N'}

def assert_same_pointers_to_data(a, b):
self.assertEqual(a['/W'].array.ctypes.data,
b['/W'].array.ctypes.data)
self.assertEqual(a['/b'].array.ctypes.data,
b['/b'].array.ctypes.data)
model_b = L.BatchNormalization(3)
model_c = L.BatchNormalization(3)

def assert_different_pointers_to_grad(a, b):
self.assertNotEqual(a['/W'].grad.ctypes.data,
b['/W'].grad.ctypes.data)
self.assertNotEqual(a['/b'].grad.ctypes.data,
b['/b'].grad.ctypes.data)
async_.set_shared_params(model_b, arrays)
async_.set_shared_params(model_c, arrays)

# Pointers to parameters must be the same
_assert_same_pointers_to_param_data(model_a, model_b)
_assert_same_pointers_to_param_data(model_a, model_c)
# Pointers to gradients must be different
_assert_different_pointers_to_param_grad(model_a, model_b)
_assert_different_pointers_to_param_grad(model_a, model_c)
_assert_different_pointers_to_param_grad(model_b, model_c)
# Pointers to persistent values must be the same
_assert_same_pointers_to_persistent_values(model_a, model_b)
_assert_same_pointers_to_persistent_values(model_a, model_c)

# Check if N is shared correctly among links
assert model_a.N == 0
assert model_b.N == 0
assert model_c.N == 0
test_input = np.random.normal(size=(2, 3)).astype(np.float32)
model_a(test_input, finetune=True)
assert model_a.N == 1
assert model_b.N == 1
assert model_c.N == 1
model_c(test_input, finetune=True)
assert model_a.N == 2
assert model_b.N == 2
assert model_c.N == 2

def test_share_params_chain_list(self):

model_a = chainer.ChainList(
L.BatchNormalization(3),
chainer.ChainList(L.Linear(3, 5)),
)

arrays = async_.share_params_as_shared_arrays(model_a)
assert isinstance(arrays, dict)
assert set(arrays.keys()) == {
'/0/gamma', '/0/beta', '/0/avg_mean', '/0/avg_var', '/0/N',
'/1/0/W', '/1/0/b'}

model_b = chainer.ChainList(
L.BatchNormalization(3),
chainer.ChainList(L.Linear(3, 5)),
)
model_c = chainer.ChainList(
L.BatchNormalization(3),
chainer.ChainList(L.Linear(3, 5)),
)

async_.set_shared_params(model_b, arrays)
async_.set_shared_params(model_c, arrays)

# Pointers to parameters must be the same
assert_same_pointers_to_data(a_params, b_params)
assert_same_pointers_to_data(a_params, c_params)
_assert_same_pointers_to_param_data(model_a, model_b)
_assert_same_pointers_to_param_data(model_a, model_c)
# Pointers to gradients must be different
assert_different_pointers_to_grad(a_params, b_params)
assert_different_pointers_to_grad(a_params, c_params)
_assert_different_pointers_to_param_grad(model_a, model_b)
_assert_different_pointers_to_param_grad(model_a, model_c)
_assert_different_pointers_to_param_grad(model_b, model_c)
# Pointers to persistent values must be the same
_assert_same_pointers_to_persistent_values(model_a, model_b)
_assert_same_pointers_to_persistent_values(model_a, model_c)

def test_share_states(self):

Expand Down Expand Up @@ -114,10 +223,8 @@ def test_shared_link(self):
model_a = chainer.ChainList(head.copy(), L.Linear(2, 3))
model_b = chainer.ChainList(head.copy(), L.Linear(2, 4))

a_arrays = async_.extract_params_as_shared_arrays(
chainer.ChainList(model_a))
b_arrays = async_.extract_params_as_shared_arrays(
chainer.ChainList(model_b))
a_arrays = async_.extract_params_as_shared_arrays(model_a)
b_arrays = async_.extract_params_as_shared_arrays(model_b)

print(('model_a shared_arrays', a_arrays))
print(('model_b shared_arrays', b_arrays))
Expand Down
52 changes: 52 additions & 0 deletions tests/misc_tests/test_namedpersistent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA

import chainer
import numpy

import chainerrl


def test_namedpersistent():
# This test case is adopted from
# https://github.com/chainer/chainer/pull/6788

l1 = chainer.Link()
with l1.init_scope():
l1.x = chainer.Parameter(shape=(2, 3))

l2 = chainer.Link()
with l2.init_scope():
l2.x = chainer.Parameter(shape=2)
l2.add_persistent(
'l2_a', numpy.array([1, 2, 3], dtype=numpy.float32))

l3 = chainer.Link()
with l3.init_scope():
l3.x = chainer.Parameter()
l3.add_persistent(
'l3_a', numpy.array([1, 2, 3], dtype=numpy.float32))

c1 = chainer.Chain()
with c1.init_scope():
c1.l1 = l1
c1.add_link('l2', l2)
c1.add_persistent(
'c1_a', numpy.array([1, 2, 3], dtype=numpy.float32))

c2 = chainer.Chain()
with c2.init_scope():
c2.c1 = c1
c2.l3 = l3
c2.add_persistent(
'c2_a', numpy.array([1, 2, 3], dtype=numpy.float32))
namedpersistent = list(chainerrl.misc.namedpersistent(c2))
assert (
[(name, id(p)) for name, p in namedpersistent] ==
[('/c2_a', id(c2.c2_a)), ('/c1/c1_a', id(c2.c1.c1_a)),
('/c1/l2/l2_a', id(c2.c1.l2.l2_a)), ('/l3/l3_a', id(c2.l3.l3_a))])