diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index 98d1ec91d..406612a38 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -360,18 +360,40 @@ def render(self, state: State) -> Any: return super().render(state_0) +NEXT_OBS_KEY_IN_EXTRAS = "next_obs" + + +def _obs_in_extras( + state: State, timestep: TimeStep[Observation] +) -> Tuple[State, TimeStep[Observation]]: + """Place the observation in timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]. + Used when auto-resetting to store the observation from the terminal TimeStep. + + Args: + state: State object containing the dynamics of the environment. + timestep: TimeStep object containing the timestep returned by the environment. + + Returns: + (state, timestep): where the observation is placed in timestep.extras["next_obs"]. + """ + extras = timestep.extras + extras[NEXT_OBS_KEY_IN_EXTRAS] = timestep.observation + return state, timestep.replace(extras=extras) # type: ignore + + class AutoResetWrapper(Wrapper): """Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep is reset to the reset observation and StepType.LAST, respectively. The reward, discount, and extras retrieved from the transition to the terminal state. + NOTE: The observation from the terminal TimeStep is stored in timestep.extras["next_obs"]. WARNING: do not `jax.vmap` the wrapped environment (e.g. do not use with the `VmapWrapper`), which would lead to inefficient computation due to both the `step` and `reset` functions being processed each time `step` is called. Please use the `VmapAutoResetWrapper` instead. """ def _auto_reset( - self, state: State, timestep: TimeStep + self, state: State, timestep: TimeStep[Observation] ) -> Tuple[State, TimeStep[Observation]]: """Reset the state and overwrite `timestep.observation` with the reset observation if the episode has terminated. @@ -387,13 +409,17 @@ def _auto_reset( key, _ = jax.random.split(state.key) # type: ignore state, reset_timestep = self._env.reset(key) + # Place original observation in extras. + state, timestep = _obs_in_extras(state, timestep) + # Replace observation with reset observation. - timestep = timestep.replace( # type: ignore - observation=reset_timestep.observation - ) + timestep = timestep.replace(observation=reset_timestep.observation) # type: ignore return state, timestep + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: + return _obs_in_extras(*super().reset(key)) + def step( self, state: State, action: chex.Array ) -> Tuple[State, TimeStep[Observation]]: @@ -404,7 +430,7 @@ def step( state, timestep = jax.lax.cond( timestep.last(), self._auto_reset, - lambda *x: x, + _obs_in_extras, state, timestep, ) @@ -421,6 +447,7 @@ class VmapAutoResetWrapper(Wrapper): - Homogeneous computation: call step function on all environments in the batch. - Heterogeneous computation: conditional auto-reset (call reset function for some environments within the batch because they have terminated). + NOTE: The observation from the terminal TimeStep is stored in timestep.extras["next_obs"]. """ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: @@ -441,6 +468,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: environments, """ state, timestep = jax.vmap(self._env.reset)(key) + state, timestep = _obs_in_extras(state, timestep) return state, timestep def step( @@ -487,6 +515,9 @@ def _auto_reset( key, _ = jax.random.split(state.key) state, reset_timestep = self._env.reset(key) + # Place original observation in extras. + state, timestep = _obs_in_extras(state, timestep) + # Replace observation with reset observation. timestep = timestep.replace( # type: ignore observation=reset_timestep.observation @@ -501,7 +532,7 @@ def _maybe_reset( state, timestep = jax.lax.cond( timestep.last(), self._auto_reset, - lambda *x: x, + _obs_in_extras, state, timestep, ) diff --git a/jumanji/wrappers_test.py b/jumanji/wrappers_test.py index 4dc339694..7e8b8912f 100644 --- a/jumanji/wrappers_test.py +++ b/jumanji/wrappers_test.py @@ -32,6 +32,7 @@ from jumanji.testing.pytrees import assert_trees_are_different from jumanji.types import StepType, TimeStep from jumanji.wrappers import ( + NEXT_OBS_KEY_IN_EXTRAS, AutoResetWrapper, JumanjiToDMEnvWrapper, JumanjiToGymWrapper, @@ -535,6 +536,8 @@ def test_auto_reset_wrapper__auto_reset( state, timestep ) chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation) + # Expect that non-reset timestep obs and extras are the same. + assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_auto_reset_wrapper__step_no_reset( self, fake_auto_reset_environment: AutoResetWrapper, key: chex.PRNGKey @@ -556,6 +559,8 @@ def test_auto_reset_wrapper__step_no_reset( assert timestep.step_type == StepType.MID assert_trees_are_different(timestep, first_timestep) chex.assert_trees_all_equal(timestep.reward, 0) + # no reset so expect extras and obs to be the same. + assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_auto_reset_wrapper__step_reset( self, @@ -566,18 +571,30 @@ def test_auto_reset_wrapper__step_reset( """Validates that the auto-reset is done correctly by the step function of the AutoResetWrapper when the terminal timestep is reached. """ - state, first_timestep = fake_auto_reset_environment.reset(key) + state, first_timestep = fake_auto_reset_environment.reset(key) # type: ignore fake_environment.time_limit = 5 # Loop across time_limit so auto-reset occurs - timestep = first_timestep - for _ in range(fake_environment.time_limit): + for _ in range(fake_environment.time_limit - 1): action = fake_auto_reset_environment.action_spec().generate_value() state, timestep = jax.jit(fake_auto_reset_environment.step)(state, action) + assert jnp.all( + timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) + + state, final_timestep = jax.jit(fake_auto_reset_environment.step)(state, action) - assert timestep.step_type == StepType.LAST - chex.assert_trees_all_equal(timestep.observation, first_timestep.observation) + assert final_timestep.step_type == StepType.LAST + chex.assert_trees_all_equal( + final_timestep.observation, first_timestep.observation + ) + assert not jnp.all( + final_timestep.observation == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) + assert jnp.all( + (timestep.observation + 1) == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) class TestVmapAutoResetWrapper: @@ -614,6 +631,8 @@ def test_vmap_auto_reset_wrapper__reset( assert timestep.observation.shape[0] == keys.shape[0] assert timestep.reward.shape == (keys.shape[0],) assert timestep.discount.shape == (keys.shape[0],) + # only reset so expect extras and obs to be the same. + assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_vmap_auto_reset_wrapper__auto_reset( self, @@ -627,6 +646,10 @@ def test_vmap_auto_reset_wrapper__auto_reset( (state, timestep), ) chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation) + # expect rest timestep.extras to have the same obs as the original timestep + assert jnp.all( + timestep.observation == reset_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) def test_vmap_auto_reset_wrapper__maybe_reset( self, @@ -640,6 +663,10 @@ def test_vmap_auto_reset_wrapper__maybe_reset( (state, timestep), ) chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation) + # expect rest timestep.extras to have the same obs as the original timestep + assert jnp.all( + timestep.observation == reset_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) def test_vmap_auto_reset_wrapper__step_no_reset( self, @@ -657,6 +684,13 @@ def test_vmap_auto_reset_wrapper__step_no_reset( assert_trees_are_different(timestep, first_timestep) chex.assert_trees_all_equal(timestep.reward, 0) + # no reset so expect extras and obs to be the same. + # and the first timestep should have different obs in extras. + assert not jnp.all( + first_timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) + assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) + def test_vmap_auto_reset_wrapper__step_reset( self, fake_environment: Environment, @@ -671,13 +705,27 @@ def test_vmap_auto_reset_wrapper__step_reset( fake_vmap_auto_reset_environment.unwrapped.time_limit = 5 # type: ignore # Loop across time_limit so auto-reset occurs - for _ in range(fake_vmap_auto_reset_environment.time_limit): + for _ in range(fake_vmap_auto_reset_environment.time_limit - 1): state, timestep = jax.jit(fake_vmap_auto_reset_environment.step)( state, action ) + assert jnp.all( + timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) - assert jnp.all(timestep.step_type == StepType.LAST) - chex.assert_trees_all_equal(timestep.observation, first_timestep.observation) + state, final_timestep = jax.jit(fake_vmap_auto_reset_environment.step)( + state, action + ) + assert jnp.all(final_timestep.step_type == StepType.LAST) + chex.assert_trees_all_equal( + final_timestep.observation, first_timestep.observation + ) + assert not jnp.all( + final_timestep.observation == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) + assert jnp.all( + (timestep.observation + 1) == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) def test_vmap_auto_reset_wrapper__step( self, @@ -696,6 +744,10 @@ def test_vmap_auto_reset_wrapper__step( assert next_timestep.reward.shape == (keys.shape[0],) assert next_timestep.discount.shape == (keys.shape[0],) assert next_timestep.observation.shape[0] == keys.shape[0] + # expect observation and extras to be the same, since no reset + assert jnp.all( + next_timestep.observation == next_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) def test_vmap_auto_reset_wrapper__render( self, fake_vmap_auto_reset_environment: VmapAutoResetWrapper, keys: chex.PRNGKey