From 5dbbd2665654e66024ecd858d25b9a6bf4c5bc97 Mon Sep 17 00:00:00 2001
From: Dimitrios Tsaras
Date: Thu, 19 Dec 2024 20:15:55 +0800
Subject: [PATCH] Added the necessary transforms for Hindsight Experience Replay

---
 torchrl/envs/transforms/transforms.py | 163 ++++++++++++++++++++++++++
 1 file changed, 163 insertions(+)

diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index f3329d085df..d9200208843 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -40,6 +40,7 @@
     TensorDictBase,
     unravel_key,
     unravel_key_list,
+    pad_sequence,
 )
 from tensordict.nn import dispatch, TensorDictModuleBase
 from tensordict.utils import (
@@ -9264,3 +9265,165 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
             high=torch.iinfo(torch.int64).max,
         )
         return super().transform_observation_spec(observation_spec)
+
+
+class HERSubGoalSampler(Transform):
+    """Samples subgoal indices from each trajectory for Hindsight Experience Replay.
+
+    Returns a TensorDict with a key ``subgoal_idx`` of shape ``[batch_size, num_samples]``
+    representing the subgoal indices. Available strategies are ``"last"`` and ``"future"``:
+    the ``"last"`` strategy assigns the last state as the subgoal, while the ``"future"``
+    strategy samples up to ``num_samples`` subgoals from the future states of the same
+    trajectory.
+
+    Args:
+        num_samples (int): number of subgoals to sample from each trajectory. Defaults to 4.
+        subgoal_idx_key (str): the key under which the subgoal indices are stored. Defaults to ``"subgoal_idx"``.
+        strategy (str): the sampling strategy, either ``"last"`` or ``"future"``. Defaults to ``"future"``.
+    """
+
+    def __init__(
+        self,
+        num_samples: int = 4,
+        subgoal_idx_key: str = "subgoal_idx",
+        strategy: str = "future",
+    ):
+        super().__init__(
+            in_keys=None,
+            in_keys_inv=None,
+            out_keys_inv=None,
+        )
+        self.num_samples = num_samples
+        self.subgoal_idx_key = subgoal_idx_key
+        self.strategy = strategy
+
+    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
+        if len(trajectories.shape) == 1:
+            trajectories = trajectories.unsqueeze(0)
+
+        batch_size, trajectory_len = trajectories.shape
+
+        if self.strategy == "last":
+            return TensorDict(
+                {self.subgoal_idx_key: torch.full((batch_size, 1), -1)},
+                batch_size=batch_size,
+            )
+        else:
+            subgoal_idxs = []
+            for i in range(batch_size):
+                subgoal_idxs.append(
+                    TensorDict(
+                        {
+                            self.subgoal_idx_key: (torch.randperm(trajectory_len - 2) + 1)[
+                                : self.num_samples
+                            ]
+                        },
+                        batch_size=torch.Size(),
+                    )
+                )
+            return pad_sequence(subgoal_idxs, pad_dim=0, return_mask=True)
+
+
+class HERSubGoalAssigner(Transform):
+    """Assigns a new goal to each trajectory according to a given subgoal index.
+
+    The goal achieved at the subgoal index is written into the desired-goal entry of
+    every step of the trajectory.
+
+    Args:
+        achieved_goal_key (str): the key of the goal achieved at each step. Defaults to ``"achieved_goal"``.
+        desired_goal_key (str): the key of the desired goal that is overwritten with the subgoal. Defaults to ``"desired_goal"``.
+    """
+
+    def __init__(
+        self,
+        achieved_goal_key: str = "achieved_goal",
+        desired_goal_key: str = "desired_goal",
+    ):
+        super().__init__()
+        self.achieved_goal_key = achieved_goal_key
+        self.desired_goal_key = desired_goal_key
+
+    def forward(
+        self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor
+    ) -> TensorDictBase:
+        batch_size, trajectory_len = trajectories.shape
+        for i in range(batch_size):
+            subgoal = trajectories[i][subgoals_idxs[i]][self.achieved_goal_key]
+            desired_goal_shape = trajectories[i][self.desired_goal_key].shape
+            trajectories[i][self.desired_goal_key] = subgoal.expand(desired_goal_shape)
+            trajectories[i][subgoals_idxs[i]]["next", "done"] = True
+            # trajectories[i][subgoals_idxs[i] + 1 :]["truncated"] = True
+
+        return trajectories
+
+
+class HERRewardTransform(Transform):
+    """Adjusts the reward of a trajectory according to the newly assigned subgoal.
+
+    The default implementation is a no-op that returns the trajectories unchanged.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
+        return trajectories
+
+
+class HindsightExperienceReplayTransform(Transform):
+    """Hindsight Experience Replay (HER) augmentation of trajectories.
+
+    HER is a technique that allows learning from failure by creating new experiences
+    out of failed ones. This module is a wrapper around the following modules:
+
+    - ``SubGoalSampler``: creates new trajectories by sampling future subgoals from the same trajectory.
+    - ``SubGoalAssigner``: assigns the subgoal to the trajectory according to a given subgoal index.
+    - ``RewardTransform``: assigns the reward to the trajectory according to the new subgoal.
+
+    Args:
+        SubGoalSampler (Transform): the module that samples subgoal indices from each trajectory. Defaults to ``HERSubGoalSampler()``.
+        SubGoalAssigner (Transform): the module that writes the sampled subgoal into the trajectory. Defaults to ``HERSubGoalAssigner()``.
+        RewardTransform (Transform): the module that recomputes the reward given the new subgoal. Defaults to ``HERRewardTransform()``.
+        assign_subgoal_idxs (bool): if ``True``, the sampled subgoal index is stored in the augmented trajectories. Defaults to ``False``.
+    """
+
+    ENV_ERR = (
+        "HindsightExperienceReplayTransform is only meant to augment data written to "
+        "a replay buffer and cannot be used as an environment transform."
+    )
+
+    def __init__(
+        self,
+        SubGoalSampler: Transform = HERSubGoalSampler(),
+        SubGoalAssigner: Transform = HERSubGoalAssigner(),
+        RewardTransform: Transform = HERRewardTransform(),
+        assign_subgoal_idxs: bool = False,
+    ):
+        super().__init__(
+            in_keys=None,
+            in_keys_inv=None,
+            out_keys_inv=None,
+        )
+        self.SubGoalSampler = SubGoalSampler
+        self.SubGoalAssigner = SubGoalAssigner
+        self.RewardTransform = RewardTransform
+        self.assign_subgoal_idxs = assign_subgoal_idxs
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Augment the incoming batch and store both the original and the
+        # relabelled trajectories.
+        augmentation_td = self.her_augmentation(tensordict)
+        return torch.cat([tensordict, augmentation_td], dim=0)
+
+    def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor:
+        return self.her_augmentation(tensordict)
+
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        return tensordict
+
+    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        raise ValueError(self.ENV_ERR)
+
+    def her_augmentation(self, trajectories: TensorDictBase) -> TensorDictBase:
+        if len(trajectories.shape) == 1:
+            trajectories = trajectories.unsqueeze(0)
+        batch_size, trajectory_length = trajectories.shape
+
+        new_trajectories = trajectories.clone(True)
+
+        # Sample subgoal indices
+        subgoal_idxs = self.SubGoalSampler(new_trajectories)
+
+        # Create new trajectories
+        augmented_trajectories = []
+        list_idxs = []
+        for i in range(batch_size):
+            idxs = subgoal_idxs[i][self.SubGoalSampler.subgoal_idx_key]
+
+            # Discard padded indices when the sampler returned a padded batch.
+            if "masks" in subgoal_idxs.keys():
+                idxs = idxs[subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key]]
+
+            list_idxs.append(idxs.unsqueeze(-1))
+            new_traj = new_trajectories[i].expand((idxs.numel(), trajectory_length)).clone(True)
+
+            if self.assign_subgoal_idxs:
+                new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze(-1).repeat(
+                    1, trajectory_length
+                )
+
+            augmented_trajectories.append(new_traj)
+        augmented_trajectories = torch.cat(augmented_trajectories, dim=0)
+        associated_idxs = torch.cat(list_idxs, dim=0)
+
+        # Assign subgoals to the new trajectories
+        augmented_trajectories = self.SubGoalAssigner.forward(augmented_trajectories, associated_idxs)
+
+        # Adjust the rewards based on the new subgoals
+        augmented_trajectories = self.RewardTransform.forward(augmented_trajectories)
+
+        return augmented_trajectories
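
Note (not part of the patch): a minimal usage sketch under stated assumptions. It assumes the new classes are importable from torchrl.envs.transforms, that `trajectories` is a hypothetical goal-conditioned TensorDict of shape [batch, time] holding "achieved_goal", "desired_goal" and ("next", "done") entries, and that, as with torchrl's Reward2GoTransform, the replay buffer applies the transform's inv() when data is written.

    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.envs.transforms import (
        HERSubGoalSampler,
        HindsightExperienceReplayTransform,
    )

    her = HindsightExperienceReplayTransform(
        SubGoalSampler=HERSubGoalSampler(num_samples=4, strategy="future"),
    )
    rb = ReplayBuffer(storage=LazyTensorStorage(10_000), transform=her)

    # Writing trajectories triggers _inv_call, which concatenates the original
    # batch with the HER-relabelled trajectories before storage.
    rb.extend(trajectories)
    batch = rb.sample(32)  # samples both original and relabelled transitions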