diff --git a/chainerrl/agents/dqn.py b/chainerrl/agents/dqn.py
index a59e5cad2..fd8e454c9 100644
--- a/chainerrl/agents/dqn.py
+++ b/chainerrl/agents/dqn.py
@@ -352,19 +352,39 @@ def _compute_loss(self, exp_batch, gamma, errors_out=None):
         return compute_value_loss(y, t, clip_delta=self.clip_delta,
                                   batch_accumulator=self.batch_accumulator)
 
-    def act(self, obs):
+    def act_with_exploration(self, obs):
+
         with chainer.using_config('train', False):
             with chainer.no_backprop_mode():
                 action_value = self.model(
                     self.batch_states([obs], self.xp, self.phi))
                 q = float(action_value.max.data)
-                action = cuda.to_cpu(action_value.greedy_actions.data)[0]
-
+                greedy_action = cuda.to_cpu(action_value.greedy_actions.data)[0]
+
+        action = self.explorer.select_action(self.t,
+                                             lambda: greedy_action,
+                                             action_value=action_value)
+        self.t += 1
         # Update stats
         self.average_q *= self.average_q_decay
         self.average_q += (1 - self.average_q_decay) * q
 
         self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value)
+        return action
+
+    def act(self, obs):
+        with chainer.using_config('train', False):
+            with chainer.no_backprop_mode():
+                action_value = self.model(
+                    self.batch_states([obs], self.xp, self.phi))
+                q = float(action_value.max.data)
+                action = cuda.to_cpu(action_value.greedy_actions.data)[0]
+
+        # Update stats
+        self.average_q *= self.average_q_decay
+        self.average_q += (1 - self.average_q_decay) * q
+
+        self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value)
         return action
 
     def act_and_train(self, obs, reward):