-
Notifications
You must be signed in to change notification settings - Fork 226
Share persistent values among processes #486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
9a26e2a
5c120ba
fbafd1a
ae1efd1
70a2df1
03f46cc
2b9066d
ae87f41
ae4b6be
6ad630d
bafbfad
5ff9d8d
5755ccb
5256c93
4aaf766
b82b710
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| from __future__ import unicode_literals | ||
| from __future__ import print_function | ||
| from __future__ import division | ||
| from __future__ import absolute_import | ||
| from builtins import * # NOQA | ||
| from future import standard_library | ||
| standard_library.install_aliases() # NOQA | ||
|
|
||
| import chainer | ||
|
|
||
|
|
||
| def namedpersistent(link): | ||
| """Return a generator of all (path, persistent) pairs for a given link. | ||
|
|
||
| This function is adopted from https://github.com/chainer/chainer/pull/6788. | ||
| Once it is merged into Chainer, we should use the property instead. | ||
|
|
||
| Args: | ||
| link (chainer.Link): Link. | ||
|
|
||
| Returns: | ||
| A generator object that generates all (path, persistent) pairs. | ||
| The paths are relative from this link. | ||
| """ | ||
| d = link.__dict__ | ||
| for name in sorted(link._persistent): | ||
| yield '/' + name, d[name] | ||
| if isinstance(link, chainer.Chain): | ||
| for name in sorted(link._children): | ||
| prefix = '/' + name | ||
| for path, persistent in namedpersistent(d[name]): | ||
| yield prefix + path, persistent | ||
| elif isinstance(link, chainer.ChainList): | ||
| for idx, link in enumerate(link._children): | ||
| prefix = '/{}'.format(idx) | ||
| for path, persistent in namedpersistent(link): | ||
| yield prefix + path, persistent | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. def _namedchildren(link):
...
elif isinstance(link, chainer.ChainList):
for idx, child in enumerate(link._children):
yield str(idx), childcould avoid repeating the recursion logic.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure how. Can you give me some more detail? chainer/chainer#6788 is also implemented with recursion, and I don't think we need to deviate from it.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably I meant that a mutual recursion between
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Understood, thanks! I modified accordingly, so can you check again? |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| from __future__ import unicode_literals | ||
| from __future__ import print_function | ||
| from __future__ import division | ||
| from __future__ import absolute_import | ||
| from builtins import * # NOQA | ||
| from future import standard_library | ||
| standard_library.install_aliases() # NOQA | ||
|
|
||
| import chainer | ||
| import numpy | ||
|
|
||
| import chainerrl | ||
|
|
||
|
|
||
| def test_namedpersistent(): | ||
| # This test case is adopted from | ||
| # https://github.com/chainer/chainer/pull/6788 | ||
|
|
||
| l1 = chainer.Link() | ||
| with l1.init_scope(): | ||
| l1.x = chainer.Parameter(shape=(2, 3)) | ||
|
|
||
| l2 = chainer.Link() | ||
| with l2.init_scope(): | ||
| l2.x = chainer.Parameter(shape=2) | ||
| l2.add_persistent( | ||
| 'l2_a', numpy.array([1, 2, 3], dtype=numpy.float32)) | ||
|
|
||
| l3 = chainer.Link() | ||
| with l3.init_scope(): | ||
| l3.x = chainer.Parameter() | ||
| l3.add_persistent( | ||
| 'l3_a', numpy.array([1, 2, 3], dtype=numpy.float32)) | ||
|
|
||
| c1 = chainer.Chain() | ||
| with c1.init_scope(): | ||
| c1.l1 = l1 | ||
| c1.add_link('l2', l2) | ||
| c1.add_persistent( | ||
| 'c1_a', numpy.array([1, 2, 3], dtype=numpy.float32)) | ||
|
|
||
| c2 = chainer.Chain() | ||
| with c2.init_scope(): | ||
| c2.c1 = c1 | ||
| c2.l3 = l3 | ||
| c2.add_persistent( | ||
| 'c2_a', numpy.array([1, 2, 3], dtype=numpy.float32)) | ||
| namedpersistent = list(chainerrl.misc.namedpersistent(c2)) | ||
| assert ( | ||
| [(name, id(p)) for name, p in namedpersistent] == | ||
| [('/c2_a', id(c2.c2_a)), ('/c1/c1_a', id(c2.c1.c1_a)), | ||
| ('/c1/l2/l2_a', id(c2.c1.l2.l2_a)), ('/l3/l3_a', id(c2.l3.l3_a))]) |
Uh oh!
There was an error while loading. Please reload this page.