diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 4f6002dedd3..78f89cc8a38 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3015,6 +3015,52 @@ def add_truncated_keys(self) -> EnvBase: self.__dict__["_done_keys"] = None return self + def step_mdp(self, next_tensordict: TensorDictBase) -> TensorDictBase: + """Advances the environment state by one step using the provided `next_tensordict`. + + This method updates the environment's state by transitioning from the current + state to the next, as defined by the `next_tensordict`. The resulting tensordict + includes updated observations and any other relevant state information, with + keys managed according to the environment's specifications. + + Internally, this method utilizes a precomputed :class:`~torchrl.envs.utils._StepMDP` instance to efficiently + handle the transition of state, observation, action, reward, and done keys. The + :class:`~torchrl.envs.utils._StepMDP` class optimizes the process by precomputing the keys to include and + exclude, reducing runtime overhead during repeated calls. The :class:`~torchrl.envs.utils._StepMDP` instance + is created with `exclude_action=False`, meaning that action keys are retained in + the root tensordict. + + Args: + next_tensordict (TensorDictBase): A tensordict containing the state of the + environment at the next time step. This tensordict should include keys + for observations, actions, rewards, and done flags, as defined by the + environment's specifications. + + Returns: + TensorDictBase: A new tensordict representing the environment state after + advancing by one step. + + .. note:: The method ensures that the environment's key specifications are validated + against the provided `next_tensordict`, issuing warnings if discrepancies + are found. + + .. note:: This method is designed to work efficiently with environments that have + consistent key specifications, leveraging the `_StepMDP` class to minimize + overhead. + + Example: + >>> from torchrl.envs import GymEnv + >>> env = GymEnv("Pendulum-1") + >>> data = env.reset() + >>> for i in range(10): + ... # compute action + ... env.rand_action(data) + ... # Perform action + ... next_data = env.step(reset_data) + ... data = env.step_mdp(next_data) + """ + return self._step_mdp(next_tensordict) + @property def _step_mdp(self): step_func = self.__dict__.get("_step_mdp_value")