Skip to content

Commit

Permalink
fix: autoreset wrappers (#223)
Browse files Browse the repository at this point in the history
Co-authored-by: Clément Bonnet <[email protected]>
  • Loading branch information
sash-a and clement-bonnet authored Mar 8, 2024
1 parent d053a2d commit ce8b873
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 14 deletions.
43 changes: 37 additions & 6 deletions jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,18 +360,40 @@ def render(self, state: State) -> Any:
return super().render(state_0)


# Key in `timestep.extras` under which the auto-reset wrappers store an observation
# (the observation of the terminal TimeStep when an auto-reset happens).
NEXT_OBS_KEY_IN_EXTRAS = "next_obs"


def _obs_in_extras(
    state: State, timestep: TimeStep[Observation]
) -> Tuple[State, TimeStep[Observation]]:
    """Place the observation in timestep.extras[NEXT_OBS_KEY_IN_EXTRAS].

    Used when auto-resetting to store the observation from the terminal TimeStep.

    The extras mapping of the incoming timestep is copied rather than modified in
    place, so the `timestep` object passed in by the caller is left untouched.

    Args:
        state: State object containing the dynamics of the environment.
        timestep: TimeStep object containing the timestep returned by the environment.

    Returns:
        (state, timestep): where the observation is placed in
            timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]. `state` is returned unchanged.
    """
    # Build a fresh dict instead of assigning into `timestep.extras`: item assignment
    # would also alter the extras held by the timestep object the caller passed in.
    extras = {**timestep.extras, NEXT_OBS_KEY_IN_EXTRAS: timestep.observation}
    return state, timestep.replace(extras=extras)  # type: ignore


class AutoResetWrapper(Wrapper):
"""Automatically resets environments that are done. Once the terminal state is reached,
the state, observation, and step_type are reset. The observation and step_type of the
terminal TimeStep are reset to the reset observation and StepType.LAST, respectively.
The reward, discount, and extras are retrieved from the transition to the terminal state.
NOTE: The observation from the terminal TimeStep is stored in timestep.extras["next_obs"].
WARNING: do not `jax.vmap` the wrapped environment (e.g. do not use with the `VmapWrapper`),
which would lead to inefficient computation due to both the `step` and `reset` functions
being processed each time `step` is called. Please use the `VmapAutoResetWrapper` instead.
"""

def _auto_reset(
self, state: State, timestep: TimeStep
self, state: State, timestep: TimeStep[Observation]
) -> Tuple[State, TimeStep[Observation]]:
"""Reset the state and overwrite `timestep.observation` with the reset observation
if the episode has terminated.
Expand All @@ -387,13 +409,17 @@ def _auto_reset(
key, _ = jax.random.split(state.key) # type: ignore
state, reset_timestep = self._env.reset(key)

# Place original observation in extras.
state, timestep = _obs_in_extras(state, timestep)

# Replace observation with reset observation.
timestep = timestep.replace( # type: ignore
observation=reset_timestep.observation
)
timestep = timestep.replace(observation=reset_timestep.observation) # type: ignore

return state, timestep

def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Reset the wrapped environment and copy the first observation into extras.

    Args:
        key: random key passed on to the parent class's `reset`.

    Returns:
        (state, timestep): the result of the parent `reset`, with the observation
            also stored in timestep.extras[NEXT_OBS_KEY_IN_EXTRAS].
    """
    initial_state, initial_timestep = super().reset(key)
    return _obs_in_extras(initial_state, initial_timestep)

def step(
self, state: State, action: chex.Array
) -> Tuple[State, TimeStep[Observation]]:
Expand All @@ -404,7 +430,7 @@ def step(
state, timestep = jax.lax.cond(
timestep.last(),
self._auto_reset,
lambda *x: x,
_obs_in_extras,
state,
timestep,
)
Expand All @@ -421,6 +447,7 @@ class VmapAutoResetWrapper(Wrapper):
- Homogeneous computation: call step function on all environments in the batch.
- Heterogeneous computation: conditional auto-reset (call reset function for some environments
within the batch because they have terminated).
NOTE: The observation from the terminal TimeStep is stored in timestep.extras["next_obs"].
"""

def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
Expand All @@ -441,6 +468,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
environments,
"""
state, timestep = jax.vmap(self._env.reset)(key)
state, timestep = _obs_in_extras(state, timestep)
return state, timestep

def step(
Expand Down Expand Up @@ -487,6 +515,9 @@ def _auto_reset(
key, _ = jax.random.split(state.key)
state, reset_timestep = self._env.reset(key)

# Place original observation in extras.
state, timestep = _obs_in_extras(state, timestep)

# Replace observation with reset observation.
timestep = timestep.replace( # type: ignore
observation=reset_timestep.observation
Expand All @@ -501,7 +532,7 @@ def _maybe_reset(
state, timestep = jax.lax.cond(
timestep.last(),
self._auto_reset,
lambda *x: x,
_obs_in_extras,
state,
timestep,
)
Expand Down
68 changes: 60 additions & 8 deletions jumanji/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from jumanji.testing.pytrees import assert_trees_are_different
from jumanji.types import StepType, TimeStep
from jumanji.wrappers import (
NEXT_OBS_KEY_IN_EXTRAS,
AutoResetWrapper,
JumanjiToDMEnvWrapper,
JumanjiToGymWrapper,
Expand Down Expand Up @@ -535,6 +536,8 @@ def test_auto_reset_wrapper__auto_reset(
state, timestep
)
chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation)
# Expect that non-reset timestep obs and extras are the same.
assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS])

def test_auto_reset_wrapper__step_no_reset(
self, fake_auto_reset_environment: AutoResetWrapper, key: chex.PRNGKey
Expand All @@ -556,6 +559,8 @@ def test_auto_reset_wrapper__step_no_reset(
assert timestep.step_type == StepType.MID
assert_trees_are_different(timestep, first_timestep)
chex.assert_trees_all_equal(timestep.reward, 0)
# no reset so expect extras and obs to be the same.
assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS])

def test_auto_reset_wrapper__step_reset(
self,
Expand All @@ -566,18 +571,30 @@ def test_auto_reset_wrapper__step_reset(
"""Validates that the auto-reset is done correctly by the step function
of the AutoResetWrapper when the terminal timestep is reached.
"""
state, first_timestep = fake_auto_reset_environment.reset(key)
state, first_timestep = fake_auto_reset_environment.reset(key) # type: ignore

fake_environment.time_limit = 5

# Loop across time_limit so auto-reset occurs
timestep = first_timestep
for _ in range(fake_environment.time_limit):
for _ in range(fake_environment.time_limit - 1):
action = fake_auto_reset_environment.action_spec().generate_value()
state, timestep = jax.jit(fake_auto_reset_environment.step)(state, action)
assert jnp.all(
timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]
)

state, final_timestep = jax.jit(fake_auto_reset_environment.step)(state, action)

assert timestep.step_type == StepType.LAST
chex.assert_trees_all_equal(timestep.observation, first_timestep.observation)
assert final_timestep.step_type == StepType.LAST
chex.assert_trees_all_equal(
final_timestep.observation, first_timestep.observation
)
assert not jnp.all(
final_timestep.observation == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]
)
assert jnp.all(
(timestep.observation + 1) == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]
)


class TestVmapAutoResetWrapper:
Expand Down Expand Up @@ -614,6 +631,8 @@ def test_vmap_auto_reset_wrapper__reset(
assert timestep.observation.shape[0] == keys.shape[0]
assert timestep.reward.shape == (keys.shape[0],)
assert timestep.discount.shape == (keys.shape[0],)
# only reset so expect extras and obs to be the same.
assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS])

def test_vmap_auto_reset_wrapper__auto_reset(
self,
Expand All @@ -627,6 +646,10 @@ def test_vmap_auto_reset_wrapper__auto_reset(
(state, timestep),
)
chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation)
# expect reset timestep.extras to have the same obs as the original timestep
assert jnp.all(
timestep.observation == reset_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]
)

def test_vmap_auto_reset_wrapper__maybe_reset(
self,
Expand All @@ -640,6 +663,10 @@ def test_vmap_auto_reset_wrapper__maybe_reset(
(state, timestep),
)
chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation)
# expect reset timestep.extras to have the same obs as the original timestep
assert jnp.all(
timestep.observation == reset_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]
)

def test_vmap_auto_reset_wrapper__step_no_reset(
self,
Expand All @@ -657,6 +684,13 @@ def test_vmap_auto_reset_wrapper__step_no_reset(
assert_trees_are_different(timestep, first_timestep)
chex.assert_trees_all_equal(timestep.reward, 0)

# no reset so expect extras and obs to be the same.
# and the first timestep should have different obs in extras.
assert not jnp.all(
first_timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]
)
assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS])

def test_vmap_auto_reset_wrapper__step_reset(
self,
fake_environment: Environment,
Expand All @@ -671,13 +705,27 @@ def test_vmap_auto_reset_wrapper__step_reset(
fake_vmap_auto_reset_environment.unwrapped.time_limit = 5 # type: ignore

# Loop across time_limit so auto-reset occurs
for _ in range(fake_vmap_auto_reset_environment.time_limit):
for _ in range(fake_vmap_auto_reset_environment.time_limit - 1):
state, timestep = jax.jit(fake_vmap_auto_reset_environment.step)(
state, action
)
assert jnp.all(
timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]
)

assert jnp.all(timestep.step_type == StepType.LAST)
chex.assert_trees_all_equal(timestep.observation, first_timestep.observation)
state, final_timestep = jax.jit(fake_vmap_auto_reset_environment.step)(
state, action
)
assert jnp.all(final_timestep.step_type == StepType.LAST)
chex.assert_trees_all_equal(
final_timestep.observation, first_timestep.observation
)
assert not jnp.all(
final_timestep.observation == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]
)
assert jnp.all(
(timestep.observation + 1) == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]
)

def test_vmap_auto_reset_wrapper__step(
self,
Expand All @@ -696,6 +744,10 @@ def test_vmap_auto_reset_wrapper__step(
assert next_timestep.reward.shape == (keys.shape[0],)
assert next_timestep.discount.shape == (keys.shape[0],)
assert next_timestep.observation.shape[0] == keys.shape[0]
# expect observation and extras to be the same, since no reset
assert jnp.all(
next_timestep.observation == next_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]
)

def test_vmap_auto_reset_wrapper__render(
self, fake_vmap_auto_reset_environment: VmapAutoResetWrapper, keys: chex.PRNGKey
Expand Down

0 comments on commit ce8b873

Please sign in to comment.