-
Notifications
You must be signed in to change notification settings - Fork 327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
First draft for modular Hindsight Experience Replay Transform #2667
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -40,6 +40,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||
TensorDictBase, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
unravel_key, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
unravel_key_list, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
pad_sequence, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensordict.nn import dispatch, TensorDictModuleBase | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensordict.utils import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -9264,3 +9265,165 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite: | |||||||||||||||||||||||||||||||||||||||||||||||||||||
high=torch.iinfo(torch.int64).max, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return super().transform_observation_spec(observation_spec) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class HERSubGoalSampler(Transform): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx". | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+9271
to
+9275
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would add a |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
num_samples: int = 4, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
subgoal_idx_key: str = "subgoal_idx", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
strategy: str = "future" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+9280
to
+9281
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
super().__init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
in_keys=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
in_keys_inv=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
out_keys_inv=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.num_samples = num_samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.subgoal_idx_key = subgoal_idx_key | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.strategy = strategy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def forward(self, trajectories: TensorDictBase) -> TensorDictBase: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if len(trajectories.shape) == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if 0 or greater than 2, raise an error? |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
trajectories = trajectories.unsqueeze(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_size, trajectory_len = trajectories.shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe
Suggested change
to account for batch size > 2 |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.strategy == "last": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TensorDict({"subgoal_idx": torch.full((batch_size, 1), -1)}, batch_size=batch_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
also missing device and dtype |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
subgoal_idxs = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for i in range(batch_size): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess
Suggested change
for batch_size with more than one dim |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
subgoal_idxs.append( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
TensorDict( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{"subgoal_idx": (torch.randperm(trajectory_len-2)+1)[:self.num_samples]}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
also missing dtype and device |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_size=torch.Size(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return pad_sequence(subgoal_idxs, pad_dim=0, return_mask=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class HERSubGoalAssigner(Transform): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""This module assigns the subgoal to the trajectory according to a given subgoal index. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx". | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal". | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+9314
to
+9318
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would add a |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
achieved_goal_key: str = "achieved_goal", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
desired_goal_key: str = "desired_goal", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+9319
to
+9323
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.achieved_goal_key = achieved_goal_key | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.desired_goal_key = desired_goal_key | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def forward(self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor) -> TensorDictBase: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_size, trajectory_len = trajectories.shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for i in range(batch_size): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if there's a vectorized version of this? The ops seem simple enough to be executed in a vectorized way |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
subgoal = trajectories[i][subgoals_idxs[i]][self.achieved_goal_key] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
desired_goal_shape = trajectories[i][self.desired_goal_key].shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
trajectories[i][self.desired_goal_key] = subgoal.expand(desired_goal_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
trajectories[i][subgoals_idxs[i]]["next", "done"] = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+9332
to
+9333
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we keep the loop, I'd rather have |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# trajectories[i][subgoals_idxs[i]+1:]["truncated"] = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
return trajectories | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class HERRewardTransform(Transform): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""This module assigns the reward to the trajectory according to the new subgoal. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
reward_name (str): The key to the reward. Defaults to "reward". | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+9341
to
+9342
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there's no arg There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would add a |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
pass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def forward(self, trajectories: TensorDictBase) -> TensorDictBase: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return trajectories | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class HindsightExperienceReplayTransform(Transform): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't we need to modify the specs? If you look at Perhaps we could inherit from Compose and rewrite |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Hindsight Experience Replay (HER) is a technique that allows to learn from failure by creating new experiences from the failed ones. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
This module is a wrapper that includes the following modules: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
- SubGoalSampler: Creates new trajectories by sampling future subgoals from the same trajectory. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
- SubGoalAssigner: Assigns the subgoal to the trajectory according to a given subgoal index. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
- RewardTransform: Assigns the reward to the trajectory according to the new subgoal. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
SubGoalSampler (Transform): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
SubGoalAssigner (Transform): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
RewardTransform (Transform): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+9354
to
+9363
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would add a |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
SubGoalSampler: Transform = HERSubGoalSampler(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
SubGoalAssigner: Transform = HERSubGoalAssigner(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
RewardTransform: Transform = HERRewardTransform(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
assign_subgoal_idxs: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
super().__init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
in_keys=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
in_keys_inv=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
out_keys_inv=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+9366
to
+9375
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.SubGoalSampler = SubGoalSampler | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.SubGoalAssigner = SubGoalAssigner | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.RewardTransform = RewardTransform | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+9376
to
+9378
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.assign_subgoal_idxs = assign_subgoal_idxs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
augmentation_td = self.her_augmentation(tensordict) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return torch.cat([tensordict, augmentation_td], dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self.her_augmentation(tensordict) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return tensordict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _call(self, tensordict: TensorDictBase) -> TensorDictBase: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError(self.ENV_ERR) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def her_augmentation(self, trajectories: TensorDictBase): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if len(trajectories.shape) == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
trajectories = trajectories.unsqueeze(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_size, trajectory_length = trajectories.shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
new_trajectories = trajectories.clone(True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Sample subgoal indices | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
subgoal_idxs = self.SubGoalSampler(new_trajectories) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Create new trajectories | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
augmented_trajectories = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
list_idxs = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for i in range(batch_size): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
which also works with |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
idxs = subgoal_idxs[i][self.SubGoalSampler.subgoal_idx_key] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if "masks" in subgoal_idxs.keys(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
idxs = idxs[subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key]] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
list_idxs.append(idxs.unsqueeze(-1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
new_traj = new_trajectories[i].expand((idxs.numel(),trajectory_length)).clone(True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.assign_subgoal_idxs: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze(-1).repeat(1, trajectory_length) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
augmented_trajectories.append(new_traj) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
augmented_trajectories = torch.cat(augmented_trajectories, dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
associated_idxs = torch.cat(list_idxs, dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Assign subgoals to the new trajectories | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
augmented_trajectories = self.SubGoalAssigner.forward(augmented_trajectories, associated_idxs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Adjust the rewards based on the new subgoals | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
augmented_trajectories = self.RewardTransform.forward(augmented_trajectories) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
return augmented_trajectories |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe let's create a dedicated file for these?