diff --git a/brax/envs/wrappers/training.py b/brax/envs/wrappers/training.py index e6383c6b3..78a3a54f2 100644 --- a/brax/envs/wrappers/training.py +++ b/brax/envs/wrappers/training.py @@ -98,10 +98,20 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: def f(state, _): nstate = self.env.step(state, action) - return nstate, nstate.reward + return nstate, (nstate.reward, nstate.metrics) - state, rewards = jax.lax.scan(f, state, (), self.action_repeat) + state, (rewards, all_metrics) = jax.lax.scan( + f, state, (), self.action_repeat + ) state = state.replace(reward=jp.sum(rewards, axis=0)) + # Sum per-step metrics across action-repeat sub-steps so that + # sparse or per-step metrics (e.g. action-change penalties) are + # correctly accumulated instead of only reflecting the last + # sub-step. See #610. + summed_metrics = jax.tree_util.tree_map( + lambda m: jp.sum(m, axis=0), all_metrics + ) + state = state.replace(metrics=summed_metrics) steps = state.info['steps'] + self.action_repeat one = jp.ones_like(state.done) zero = jp.zeros_like(state.done) @@ -257,4 +267,4 @@ def step(sys, s, a): res = jax.vmap(step, in_axes=[self._in_axes, 0, 0])( self._sys_v, state, action ) - return res + return res \ No newline at end of file