[WIP] Implements Hindsight Experience Replay #361
Open
prabhatnagarajan wants to merge 62 commits into chainer:master from prabhatnagarajan:her
Changes from 59 commits
100ba3b adds an empty HER class (prabhatnagarajan)
a43dac1 fixes flake (prabhatnagarajan)
e5d8c74 Merge branch 'master' into her (prabhatnagarajan)
e15a4e1 Merge branch 'master' into her (prabhatnagarajan)
8399f10 adds a train_her_example file (prabhatnagarajan)
98565f0 modifies phi function to remove bug (prabhatnagarajan)
4633be3 uses hindsightbuffer instead of replay buffer (hindsight is a shell c… (prabhatnagarajan)
87cfb8d Merge branch 'master' into her (prabhatnagarajan)
6bb6990 small modifications for her (prabhatnagarajan)
0e5b063 adds current episode variable and assert that n != 1 to hindsight rep… (prabhatnagarajan)
0329358 adds part of HER transition-storing loop (prabhatnagarajan)
d65e6b4 minor changes to support future goals: (prabhatnagarajan)
0796caa makes HindsightBuffer and EpisodicBuffer (prabhatnagarajan)
3899264 implements future sampling (prabhatnagarajan)
4a20443 updates the update frequency (prabhatnagarajan)
e65e784 changes default gamma to be 0.98, to match paper (prabhatnagarajan)
119e2ed changes buffer size to avoid error (prabhatnagarajan)
38582aa adds some starter code for the HER explorer (prabhatnagarajan)
5595404 implements HER exploration (prabhatnagarajan)
e353632 adds normalization to DDPG and HER (prabhatnagarajan)
07575ad adds a clip threshold argument (prabhatnagarajan)
160831c Merge branch 'master' into her (prabhatnagarajan)
7602139 Merge branch 'master' into her (prabhatnagarajan)
466cae5 gets rid of batch normalization option in her example (prabhatnagarajan)
26ae4e3 sets eval interval to match paper (prabhatnagarajan)
8a10d1a adds wrapper class to get success rate (prabhatnagarajan)
17804df makes some fixes to normalization code (prabhatnagarajan)
6fd78ba implement clipped critic target to her (prabhatnagarajan)
4afb548 updates target updat einterval to match paper (prabhatnagarajan)
13aae0f sets polyak averaging parameter (prabhatnagarajan)
67b83fc uses chainer functions clip (prabhatnagarajan)
e325dbd Merge branch 'master' into her (prabhatnagarajan)
fdabdcb changes actor learning rate to match baselines and paper (prabhatnagarajan)
c8e546d cleans up some redundant code and removes reward filter as it's not used (prabhatnagarajan)
9377ef5 removes reward scaling from env wrapper (prabhatnagarajan)
e25a320 Merge branch 'master' into her (prabhatnagarajan)
caebfa1 adds some tests and comments to HER (prabhatnagarajan)
e809d6a fixes merge conflict (prabhatnagarajan)
7fc391d makes a batch HER agent (prabhatnagarajan)
0d43949 Merge branch 'master' into her (prabhatnagarajan)
73dc692 merges and addresses flakes (prabhatnagarajan)
02fa906 Merge branch 'master' into her (prabhatnagarajan)
d9b110b reverts epsilon to match original paper (prabhatnagarajan)
280571f Merge branch 'master' into her (prabhatnagarajan)
b4b2cea merges with master (prabhatnagarajan)
620be35 adds args to Hindsight Buffer (prabhatnagarajan)
25e563b adds action penalty to ddpg (prabhatnagarajan)
753dc70 only switches goals if achieved goal is not none (prabhatnagarajan)
a98e20c Merge branch 'master' into her (prabhatnagarajan)
3c7e4e0 adds observation normalization for batch training (prabhatnagarajan)
289e551 fixes incorrect naming (prabhatnagarajan)
d63beca fixes more bugs (prabhatnagarajan)
56eb3dc Merge branch 'master' into her (prabhatnagarajan)
db24b8e fixes a minor bug (prabhatnagarajan)
23e8edb merges with master, moves HER to replay_buffers directory (prabhatnagarajan)
52b538e refactors hindsight code, addresses flakes, conforms to new master (prabhatnagarajan)
dd57c94 Merge branch 'master' into her (prabhatnagarajan)
1d102fd adds HER to readme (prabhatnagarajan)
159b2af changes reward structure for HER (prabhatnagarajan)
e4f7a2e Merge branch 'master' into her (prabhatnagarajan)
a87cbce fixes ation penalty bug and actually uses action penalty (prabhatnagarajan)
2a3207a makes default exploration match paper (prabhatnagarajan)
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -75,6 +75,8 @@ class DDPG(AttributeSavingMixin, BatchAgent): | |||||
| logger (Logger): Logger used | ||||||
| batch_states (callable): method which makes a batch of observations. | ||||||
| default is `chainerrl.misc.batch_states.batch_states` | ||||||
| clip_critic_tgt (tuple or None): tuple containing (min, max) to clip | ||||||
| the target of the critic. If None, target will not be clipped. | ||||||
| burnin_action_func (callable or None): If not None, this callable | ||||||
| object is used to select actions before the model is updated | ||||||
| one or more times during training. | ||||||
|
|
@@ -83,10 +85,11 @@ class DDPG(AttributeSavingMixin, BatchAgent): | |||||
| saved_attributes = ('model', | ||||||
| 'target_model', | ||||||
| 'actor_optimizer', | ||||||
| 'critic_optimizer') | ||||||
| 'critic_optimizer', | ||||||
| 'obs_normalizer') | ||||||
|
|
||||||
| def __init__(self, model, actor_optimizer, critic_optimizer, replay_buffer, | ||||||
| gamma, explorer, | ||||||
| gamma, explorer, obs_normalizer=None, | ||||||
| gpu=None, replay_start_size=50000, | ||||||
| minibatch_size=32, update_interval=1, | ||||||
| target_update_interval=10000, | ||||||
|
|
@@ -99,14 +102,19 @@ def __init__(self, model, actor_optimizer, critic_optimizer, replay_buffer, | |||||
| episodic_update_len=None, | ||||||
| logger=getLogger(__name__), | ||||||
| batch_states=batch_states, | ||||||
| l2_action_penalty=None, | ||||||
| clip_critic_tgt=None, | ||||||
| burnin_action_func=None, | ||||||
| ): | ||||||
|
|
||||||
| self.model = model | ||||||
| self.obs_normalizer = obs_normalizer | ||||||
|
|
||||||
| if gpu is not None and gpu >= 0: | ||||||
| cuda.get_device(gpu).use() | ||||||
| self.model.to_gpu(device=gpu) | ||||||
| if self.obs_normalizer is not None: | ||||||
| self.obs_normalizer.to_gpu(device=gpu) | ||||||
|
|
||||||
| self.xp = self.model.xp | ||||||
| self.replay_buffer = replay_buffer | ||||||
|
|
@@ -137,6 +145,8 @@ def __init__(self, model, actor_optimizer, critic_optimizer, replay_buffer, | |||||
| update_interval=update_interval, | ||||||
| ) | ||||||
| self.batch_states = batch_states | ||||||
| self.clip_critic_tgt = clip_critic_tgt | ||||||
| self.l2_action_penalty = l2_action_penalty | ||||||
| self.burnin_action_func = burnin_action_func | ||||||
|
|
||||||
| self.t = 0 | ||||||
|
|
@@ -204,6 +214,10 @@ def compute_critic_loss(self, batch): | |||||
|
|
||||||
| target_q = batch_rewards + self.gamma * \ | ||||||
| (1.0 - batch_terminal) * F.reshape(next_q, (batchsize,)) | ||||||
| if self.clip_critic_tgt: | ||||||
| target_q = F.clip(target_q, | ||||||
| self.clip_critic_tgt[0], | ||||||
| self.clip_critic_tgt[1]) | ||||||
|
|
||||||
| # Estimated Q-function observes s_t and a_t | ||||||
| predict_q = F.reshape( | ||||||
|
|
@@ -251,6 +265,9 @@ def compute_actor_loss(self, batch): | |||||
|
|
||||||
| # Since we want to maximize Q, loss is negation of Q | ||||||
| loss = - F.sum(q) / batch_size | ||||||
| if self.l2_action_penalty: | ||||||
| loss += self.l2_action_penalty \ | ||||||
| * F.square(onpolicy_actions) / batch_size | ||||||
|
|
||||||
| # Update stats | ||||||
| self.average_actor_loss *= self.average_loss_decay | ||||||
|
|
@@ -260,8 +277,11 @@ def compute_actor_loss(self, batch): | |||||
|
|
||||||
| def update(self, experiences, errors_out=None): | ||||||
| """Update the model from experiences""" | ||||||
|
|
||||||
| batch = batch_experiences(experiences, self.xp, self.phi, self.gamma) | ||||||
| if self.obs_normalizer: | ||||||
| batch['state'] = self.obs_normalizer(batch['state'], update=False) | ||||||
| batch['next_state'] = self.obs_normalizer(batch['next_state'], | ||||||
| update=False) | ||||||
| self.critic_optimizer.update(lambda: self.compute_critic_loss(batch)) | ||||||
| self.actor_optimizer.update(lambda: self.compute_actor_loss(batch)) | ||||||
|
|
||||||
|
|
@@ -280,6 +300,11 @@ def update_from_episodes(self, episodes, errors_out=None): | |||||
| transitions.append([ep[i]]) | ||||||
| batch = batch_experiences( | ||||||
| transitions, xp=self.xp, phi=self.phi, gamma=self.gamma) | ||||||
| if self.obs_normalizer: | ||||||
| batch['state'] = self.obs_normalizer(batch['state'], | ||||||
| update=False) | ||||||
| batch['next_state'] = self.obs_normalizer(batch['state'], | ||||||
|
Member
Shouldn't this be
Suggested change
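A minimal sketch of what this comment appears to point at, offered as an assumption rather than the reviewer's actual suggestion: the line above passes `batch['state']` to the normalizer a second time and stores the result as `batch['next_state']`, whereas `update()` normalizes `batch['next_state']`. `MeanStdNormalizer` and `normalize_batch` are hypothetical stand-ins written in plain Python, not chainerrl classes.

```python
class MeanStdNormalizer(object):
    """Hypothetical stand-in for an observation normalizer (fixed statistics)."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, xs, update=False):
        # `update` is accepted only to mirror the call sites in the diff;
        # this toy version never changes its statistics.
        return [(x - self.mean) / self.std for x in xs]


def normalize_batch(batch, normalizer):
    """Normalize each key from its own source array."""
    out = dict(batch)
    out['state'] = normalizer(batch['state'], update=False)
    # Presumed intent: read from batch['next_state'], not batch['state'].
    out['next_state'] = normalizer(batch['next_state'], update=False)
    return out


batch = {'state': [1.0, 2.0], 'next_state': [3.0, 4.0]}
normalized = normalize_batch(batch, MeanStdNormalizer(mean=2.0, std=2.0))
```

With the line as written in the diff, `normalized['next_state']` would instead equal `normalized['state']`, so the critic target would be computed from the current observations.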
|
||||||
| update=False) | ||||||
| batches.append(batch) | ||||||
|
|
||||||
| with self.model.state_reset(), self.target_model.state_reset(): | ||||||
|
|
@@ -330,6 +355,9 @@ def act_and_train(self, obs, reward): | |||||
| next_state=obs, | ||||||
| next_action=action, | ||||||
| is_state_terminal=False) | ||||||
| # Add to Normalizer | ||||||
| if self.obs_normalizer: | ||||||
| self.obs_normalizer.experience([obs]) | ||||||
|
|
||||||
| self.last_state = obs | ||||||
| self.last_action = action | ||||||
|
|
@@ -339,9 +367,10 @@ def act_and_train(self, obs, reward): | |||||
| return self.last_action | ||||||
|
|
||||||
| def act(self, obs): | ||||||
|
|
||||||
| with chainer.using_config('train', False): | ||||||
| s = self.batch_states([obs], self.xp, self.phi) | ||||||
| if self.obs_normalizer: | ||||||
| s = self.obs_normalizer(s, update=False) | ||||||
| action = self.policy(s).sample() | ||||||
| # Q is not needed here, but log it just for information | ||||||
| q = self.q_function(s, action) | ||||||
|
|
@@ -363,9 +392,10 @@ def batch_act(self, batch_obs): | |||||
| Returns: | ||||||
| Sequence of ~object: Actions. | ||||||
| """ | ||||||
|
|
||||||
| with chainer.using_config('train', False), chainer.no_backprop_mode(): | ||||||
| batch_xs = self.batch_states(batch_obs, self.xp, self.phi) | ||||||
| if self.obs_normalizer: | ||||||
| batch_xs = self.obs_normalizer(batch_xs, update=False) | ||||||
| batch_action = self.policy(batch_xs).sample() | ||||||
| # Q is not needed here, but log it just for information | ||||||
| q = self.q_function(batch_xs, batch_action) | ||||||
|
|
@@ -398,7 +428,12 @@ def batch_act_and_train(self, batch_obs): | |||||
| self.explorer.select_action( | ||||||
| self.t, lambda: batch_greedy_action[i]) | ||||||
| for i in range(len(batch_greedy_action))] | ||||||
|
|
||||||
| # Add to Normalizer | ||||||
| if self.obs_normalizer: | ||||||
| self.obs_normalizer.experience( | ||||||
| self.batch_states(batch_obs, | ||||||
| self.xp, | ||||||
| self.phi)) | ||||||
| self.batch_last_obs = list(batch_obs) | ||||||
| self.batch_last_action = list(batch_action) | ||||||
|
|
||||||
|
|
@@ -459,7 +494,11 @@ def stop_episode_and_train(self, state, reward, done=False): | |||||
| next_state=state, | ||||||
| next_action=self.last_action, | ||||||
| is_state_terminal=done) | ||||||
|
|
||||||
| # Add to Normalizer | ||||||
| if self.obs_normalizer: | ||||||
| self.obs_normalizer(self.batch_states([state], | ||||||
| self.xp, | ||||||
| self.phi)) | ||||||
| self.stop_episode() | ||||||
|
|
||||||
| def stop_episode(self): | ||||||
|
|
||||||
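The clipped critic target added in `compute_critic_loss` can be illustrated with plain scalars. This is a sketch, not the agent's code: `clipped_td_target` is a hypothetical helper, and the range `(-50.0, 0.0)` is an assumed example (roughly -1/(1 - gamma) for gamma = 0.98 with rewards in {-1, 0}), not a value taken from this diff. The sketch tests `clip_range is not None`, while the diff relies on the truthiness of `self.clip_critic_tgt`.

```python
def clipped_td_target(reward, next_q, terminal, gamma, clip_range=None):
    """One-step TD target, optionally clipped to (min, max)."""
    target = reward + gamma * (1.0 - terminal) * next_q
    if clip_range is not None:
        low, high = clip_range
        target = min(max(target, low), high)
    return target


# Non-terminal step whose bootstrapped value falls below the assumed range.
unclipped = clipped_td_target(-1.0, -60.0, 0.0, gamma=0.98)
clipped = clipped_td_target(-1.0, -60.0, 0.0, gamma=0.98,
                            clip_range=(-50.0, 0.0))
# Terminal step: the bootstrap term vanishes, so only the reward remains.
terminal_target = clipped_td_target(-1.0, -60.0, 1.0, gamma=0.98,
                                    clip_range=(-50.0, 0.0))
```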
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| from __future__ import print_function | ||
| from __future__ import unicode_literals | ||
| from __future__ import division | ||
| from __future__ import absolute_import | ||
| from builtins import * # NOQA | ||
| from future import standard_library | ||
| standard_library.install_aliases() # NOQA | ||
|
|
||
| import copy | ||
|
|
||
| import numpy as np | ||
|
|
||
| from chainerrl import replay_buffer | ||
| from chainerrl.replay_buffers.episodic import EpisodicReplayBuffer # NOQA | ||
|
|
||
|
|
||
| class HindsightReplayBuffer(EpisodicReplayBuffer): | ||
| """Hindsight Replay Buffer | ||
|
|
||
| https://arxiv.org/abs/1707.01495 | ||
|
|
||
| We currently do not support N-step transitions for the | ||
|
|
||
| Hindsight Buffer. | ||
|
|
||
| Args: | ||
| reward_function (callable): maps (achieved_goal, desired_goal) to a reward | ||
| capacity (int): Capacity of the replay buffer | ||
| future_k (int): number of future goals to sample per true sample | ||
| """ | ||
|
|
||
| def __init__(self, reward_function, | ||
| capacity=None, | ||
| future_k=0): | ||
| super(HindsightReplayBuffer, self).__init__(capacity) | ||
| self.reward_function = reward_function | ||
| # probability of sampling a future goal instead of a | ||
| # true goal | ||
| self.future_prob = 1.0 - 1.0/(float(future_k) + 1) | ||
|
|
||
| def _replace_goal(self, transition, future_transition): | ||
| transition = copy.deepcopy(transition) | ||
| future_state = future_transition['next_state'] | ||
| assert future_state['achieved_goal'] is not None | ||
| new_goal = future_state['achieved_goal'] | ||
| transition['state']['desired_goal'] = new_goal | ||
| transition['next_state']['desired_goal'] = new_goal | ||
| transition['reward'] = self.reward_function( | ||
| transition['next_state']['achieved_goal'], | ||
| new_goal) | ||
| return transition | ||
|
|
||
| def sample(self, n): | ||
| assert len(self.memory) >= n | ||
| # Select n episodes | ||
| episodes = self.sample_episodes(n) | ||
| # Select timesteps from each episode | ||
| episode_lens = np.array([len(episode) for episode in episodes]) | ||
| timesteps = np.array( | ||
| [np.random.randint(episode_lens[i]) for i in range(n)]) | ||
| # Select samples for which a future goal replaces the true goal | ||
|
|
||
| do_replace = np.random.uniform(size=n) < self.future_prob | ||
| # Randomly select offsets of future goals | ||
| future_offset = np.random.uniform(size=n) * (episode_lens - timesteps) | ||
| future_offset = future_offset.astype(int) | ||
| future_times = timesteps + future_offset | ||
| batch = [] | ||
| # Go through episodes | ||
| for episode, timestep, future_timestep, replace in zip( | ||
| episodes, timesteps, future_times, do_replace): | ||
| transition = episode[timestep] | ||
| if replace: | ||
| future_transition = episode[future_timestep] | ||
| transition = self._replace_goal(transition, future_transition) | ||
| batch.append([transition]) | ||
| return batch | ||
|
|
||
| def sample_episodes(self, n_episodes, max_len=None): | ||
| episodes = self.episodic_memory.sample_with_replacement(n_episodes) | ||
| if max_len is not None: | ||
| return [replay_buffer.random_subseq(ep, max_len) | ||
| for ep in episodes] | ||
| else: | ||
| return episodes |
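The "future" relabeling done by `sample` and `_replace_goal` can be sketched in plain Python. This is an illustration under stated assumptions, not the buffer itself: `future_prob`, `sample_future_index`, `relabel`, and the toy episode are hypothetical, and `sparse_reward` (0 on success, -1 otherwise) is an assumed reward function, since the PR takes `reward_function` as an argument.

```python
import copy
import random


def future_prob(future_k):
    """Probability of relabeling a sampled transition with a future goal."""
    return 1.0 - 1.0 / (float(future_k) + 1)


def sample_future_index(timestep, episode_len, rng=random):
    """Uniformly pick an index in [timestep, episode_len - 1]."""
    offset = int(rng.random() * (episode_len - timestep))
    return timestep + offset


def sparse_reward(achieved_goal, desired_goal):
    # Assumed reward structure: 0 on success, -1 otherwise.
    return 0.0 if achieved_goal == desired_goal else -1.0


def relabel(transition, future_transition, reward_function):
    """Copy `transition` and swap in a goal achieved later in the episode."""
    new_goal = future_transition['next_state']['achieved_goal']
    relabeled = copy.deepcopy(transition)
    relabeled['state']['desired_goal'] = new_goal
    relabeled['next_state']['desired_goal'] = new_goal
    relabeled['reward'] = reward_function(
        relabeled['next_state']['achieved_goal'], new_goal)
    return relabeled


# Toy 3-step episode on a line: the agent moves 0 -> 1 -> 2 -> 3; goal is 9.
episode = [
    {'state': {'achieved_goal': i, 'desired_goal': 9},
     'next_state': {'achieved_goal': i + 1, 'desired_goal': 9},
     'reward': -1.0}
    for i in range(3)
]
rng = random.Random(0)
t = 0
future_t = sample_future_index(t, len(episode), rng)
relabeled = relabel(episode[t], episode[future_t], sparse_reward)
```

With `future_k=4`, four out of five sampled transitions are relabeled on average; with the default `future_k=0`, none are. Because the offset can be zero, a transition may be relabeled with the goal it achieves itself, which yields a success reward under the assumed `sparse_reward`.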
Should this also include an F.sum term around the F.square?
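A sketch of the distinction this comment raises, assuming `onpolicy_actions` has shape (batch_size, action_dim): squaring alone yields one value per action component, so some reduction (a sum here) is needed before the penalty is a scalar comparable to `F.sum(q) / batch_size`. Plain Python lists stand in for Chainer variables, and `elementwise_square` and `l2_action_penalty` are hypothetical helpers, not part of the PR.

```python
def elementwise_square(actions):
    """Square every action component; the result keeps the batch shape."""
    return [[a * a for a in row] for row in actions]


def l2_action_penalty(actions, coeff):
    """Scalar penalty: coeff * (sum of squared components) / batch_size."""
    batch_size = len(actions)
    squared = elementwise_square(actions)
    total = sum(sum(row) for row in squared)
    return coeff * total / batch_size


actions = [[1.0, 2.0], [3.0, 0.0]]  # batch of 2 actions, 2 dimensions each
squared = elementwise_square(actions)  # still 2 x 2, not a scalar
penalty = l2_action_penalty(actions, coeff=0.5)  # a single number
```

Whether the reduction should be a sum or a mean over action dimensions is a separate design choice that this sketch does not settle.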