First draft for modular Hindsight Experience Replay Transform #2667

Draft: wants to merge 1 commit into base: main
163 changes: 163 additions & 0 deletions torchrl/envs/transforms/transforms.py
Contributor:
maybe let's create a dedicated file for these?

@@ -40,6 +40,7 @@
TensorDictBase,
unravel_key,
unravel_key_list,
pad_sequence,
)
from tensordict.nn import dispatch, TensorDictModuleBase
from tensordict.utils import (
@@ -9264,3 +9265,165 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
high=torch.iinfo(torch.int64).max,
)
return super().transform_observation_spec(observation_spec)


class HERSubGoalSampler(Transform):
"""Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index.
Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states.
Args:
num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx".
Comment on lines +9271 to +9275
Contributor:
Suggested change
"""Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index.
Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states.
Args:
num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx".
"""Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index.
Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states.
Args:
num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
==> NOT PRESENT: out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx".
subgoal_idx_key: TODO
strategy: TODO

Contributor:
I would add a .. seealso:: with other related classes.
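
For illustration, a cross-reference block of the kind suggested could look like this inside the docstring (class paths assumed, not part of the diff):

```python
    .. seealso::
        :class:`~torchrl.envs.transforms.HERSubGoalAssigner`,
        :class:`~torchrl.envs.transforms.HERRewardTransform` and
        :class:`~torchrl.envs.transforms.HindsightExperienceReplayTransform`.
```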

"""
def __init__(
self,
num_samples: int = 4,
subgoal_idx_key: str = "subgoal_idx",
strategy: str = "future"
Comment on lines +9280 to +9281
Contributor:
Suggested change
subgoal_idx_key: str = "subgoal_idx",
strategy: str = "future"
subgoal_idx_key: NestedKey = "subgoal_idx",
strategy: NestedKey = "future"

):
super().__init__(
in_keys=None,
in_keys_inv=None,
out_keys_inv=None,
)
self.num_samples = num_samples
self.subgoal_idx_key = subgoal_idx_key
self.strategy = strategy

def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
if len(trajectories.shape) == 1:
Contributor:
if 0 or greater than 2, raise an error?
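
A minimal sketch of the shape check being suggested here (hypothetical, not part of the diff):

```python
if trajectories.ndim == 0 or trajectories.ndim > 2:
    raise ValueError(
        f"Expected trajectories with 1 or 2 batch dimensions, got {trajectories.ndim}."
    )
if trajectories.ndim == 1:
    trajectories = trajectories.unsqueeze(0)
```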

trajectories = trajectories.unsqueeze(0)

batch_size, trajectory_len = trajectories.shape
Contributor:
maybe

Suggested change
- batch_size, trajectory_len = trajectories.shape
+ *batch_size, trajectory_len = trajectories.shape

to account for batch size > 2


if self.strategy == "last":
return TensorDict({"subgoal_idx": torch.full((batch_size, 1), -1)}, batch_size=batch_size)
Contributor:
Suggested change
return TensorDict({"subgoal_idx": torch.full((batch_size, 1), -1)}, batch_size=batch_size)
return TensorDict({self.subgoal_idx_key: torch.full((batch_size, 1), -1)}, batch_size=batch_size)

also missing device and dtype
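
A sketch of that constant tensor with explicit dtype and device, assuming the index dtype should be torch.int64 and the device should follow the input trajectories:

```python
return TensorDict(
    {
        self.subgoal_idx_key: torch.full(
            (batch_size, 1), -1, dtype=torch.int64, device=trajectories.device
        )
    },
    batch_size=batch_size,
)
```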


else:
subgoal_idxs = []
for i in range(batch_size):
Contributor:
I guess

Suggested change
- for i in range(batch_size):
+ for i in range(batch_size.numel()):

for batch_size with more than one dim

subgoal_idxs.append(
TensorDict(
{"subgoal_idx": (torch.randperm(trajectory_len-2)+1)[:self.num_samples]},
Contributor:
Suggested change
{"subgoal_idx": (torch.randperm(trajectory_len-2)+1)[:self.num_samples]},
{self.subgoal_idx_key: (torch.randperm(trajectory_len-2)+1)[:self.num_samples]},

also missing dtype and device

batch_size=torch.Size(),
)
)
return pad_sequence(subgoal_idxs, pad_dim=0, return_mask=True)


class HERSubGoalAssigner(Transform):
"""This module assigns the subgoal to the trajectory according to a given subgoal index.
Args:
subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx".
subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal".
"""
Comment on lines +9314 to +9318
Contributor:
Suggested change
"""This module assigns the subgoal to the trajectory according to a given subgoal index.
Args:
subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx".
subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal".
"""
"""This module assigns the subgoal to the trajectory according to a given subgoal index.
Args:
SHOULD BE achieved_goal_key??? ===> subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx".
SHOULD BE desired_goal_key?? ===> subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal".
"""

Contributor:
I would add a .. seealso:: with other related classes.

def __init__(
self,
achieved_goal_key: str = "achieved_goal",
desired_goal_key: str = "desired_goal",
):
Comment on lines +9319 to +9323
Contributor:
Suggested change
- def __init__(
-     self,
-     achieved_goal_key: str = "achieved_goal",
-     desired_goal_key: str = "desired_goal",
- ):
+ def __init__(
+     self,
+     achieved_goal_key: NestedKey = "achieved_goal",
+     desired_goal_key: NestedKey = "desired_goal",
+ ):

self.achieved_goal_key = achieved_goal_key
self.desired_goal_key = desired_goal_key

def forward(self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor) -> TensorDictBase:
batch_size, trajectory_len = trajectories.shape
for i in range(batch_size):
Contributor:
I wonder if there's a vectorized version of this? The ops seem simple enough to be executed in a vectorized way
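
One possible vectorized formulation, sketched with plain tensor indexing and assuming `subgoals_idxs` holds one index per trajectory (shape `[batch_size]`); illustrative only, not part of the diff:

```python
batch_range = torch.arange(batch_size, device=trajectories.device)
# Gather the achieved goal at the subgoal timestep for every trajectory at once.
subgoal = trajectories[self.achieved_goal_key][batch_range, subgoals_idxs]
# Broadcast it over the time dimension as the new desired goal.
desired_goal = trajectories[self.desired_goal_key]
trajectories[self.desired_goal_key] = subgoal.unsqueeze(1).expand_as(desired_goal).clone()
# Mark the subgoal timestep as terminal.
done = trajectories["next", "done"].clone()
done[batch_range, subgoals_idxs] = True
trajectories["next", "done"] = done
```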

subgoal = trajectories[i][subgoals_idxs[i]][self.achieved_goal_key]
desired_goal_shape = trajectories[i][self.desired_goal_key].shape
trajectories[i][self.desired_goal_key] = subgoal.expand(desired_goal_shape)
trajectories[i][subgoals_idxs[i]]["next", "done"] = True
Comment on lines +9332 to +9333
Contributor:
if we keep the loop, I'd rather have trajectories.unbind(0) than indexing every element along dim 0, it will be faster
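
If the loop is kept, the unbind variant could look roughly like this (a sketch, assuming the unbound tensordicts share storage with the original so that in-place writes propagate):

```python
for traj, idx in zip(trajectories.unbind(0), subgoals_idxs.unbind(0)):
    subgoal = traj[self.achieved_goal_key][idx]
    # Write in place so the change is reflected in the parent tensordict.
    traj[self.desired_goal_key].copy_(subgoal.expand_as(traj[self.desired_goal_key]))
    traj["next", "done"][idx] = True
```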

# trajectories[i][subgoals_idxs[i]+1:]["truncated"] = True

return trajectories


class HERRewardTransform(Transform):
"""This module assigns the reward to the trajectory according to the new subgoal.
Args:
reward_name (str): The key to the reward. Defaults to "reward".
Comment on lines +9341 to +9342
Contributor:
there's no arg

Contributor:
I would add a .. seealso:: with other related classes.

"""
def __init__(
self
):
pass

def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
return trajectories


class HindsightExperienceReplayTransform(Transform):
Contributor:
Don't we need to modify the specs?
Does this work with replay buffer (static data) or only envs? If the latter, we should not be using forward.

If you look at Compose, there are a bunch of things that need to be implemented when nesting transforms, like clone, cache eraser etc.

Perhaps we could inherit from Compose and rewrite forward, _apply_transform, _call, _reset etc such that the logic hold but the extra features are included automatically?
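
A rough skeleton of the Compose-based approach this comment describes, purely illustrative and with the HER-specific logic elided:

```python
class HindsightExperienceReplayTransform(Compose):
    """Sketch only: nest the sub-transforms so Compose provides clone,
    cache erasure and related machinery for free."""

    def __init__(self, subgoal_sampler, subgoal_assigner, reward_transform):
        super().__init__(subgoal_sampler, subgoal_assigner, reward_transform)

    def _inv_call(self, tensordict):
        # HER relabeling happens on the way into storage.
        augmented = self.her_augmentation(tensordict)
        return torch.cat([tensordict, augmented], dim=0)

    def forward(self, tensordict):
        # No-op on the env-side path.
        return tensordict
```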

"""Hindsight Experience Replay (HER) is a technique that allows to learn from failure by creating new experiences from the failed ones.
This module is a wrapper that includes the following modules:
- SubGoalSampler: Creates new trajectories by sampling future subgoals from the same trajectory.
- SubGoalAssigner: Assigns the subgoal to the trajectory according to a given subgoal index.
- RewardTransform: Assigns the reward to the trajectory according to the new subgoal.
Args:
SubGoalSampler (Transform):
SubGoalAssigner (Transform):
RewardTransform (Transform):
"""
Comment on lines +9354 to +9363
Contributor:
Suggested change
"""Hindsight Experience Replay (HER) is a technique that allows to learn from failure by creating new experiences from the failed ones.
This module is a wrapper that includes the following modules:
- SubGoalSampler: Creates new trajectories by sampling future subgoals from the same trajectory.
- SubGoalAssigner: Assigns the subgoal to the trajectory according to a given subgoal index.
- RewardTransform: Assigns the reward to the trajectory according to the new subgoal.
Args:
SubGoalSampler (Transform):
SubGoalAssigner (Transform):
RewardTransform (Transform):
"""
"""Hindsight Experience Replay (HER) is a technique that allows to learn from failure by creating new experiences from the failed ones.
This module is a wrapper that includes the following modules:
- SubGoalSampler: Creates new trajectories by sampling future subgoals from the same trajectory.
- SubGoalAssigner: Assigns the subgoal to the trajectory according to a given subgoal index.
- RewardTransform: Assigns the reward to the trajectory according to the new subgoal.
Args:
SubGoalSampler (Transform): TODO
SubGoalAssigner (Transform): TODO
RewardTransform (Transform): TODO
"""

Contributor:
I would add a .. seealso:: with other related classes.

def __init__(
self,
SubGoalSampler: Transform = HERSubGoalSampler(),
SubGoalAssigner: Transform = HERSubGoalAssigner(),
RewardTransform: Transform = HERRewardTransform(),
assign_subgoal_idxs: bool = False,
):
super().__init__(
in_keys=None,
in_keys_inv=None,
out_keys_inv=None,
)
Comment on lines +9366 to +9375
Contributor:
Suggested change
- SubGoalSampler: Transform = HERSubGoalSampler(),
- SubGoalAssigner: Transform = HERSubGoalAssigner(),
- RewardTransform: Transform = HERRewardTransform(),
- assign_subgoal_idxs: bool = False,
- ):
-     super().__init__(
-         in_keys=None,
-         in_keys_inv=None,
-         out_keys_inv=None,
-     )
+ SubGoalSampler: Transform | None = None,
+ SubGoalAssigner: Transform | None = None,
+ RewardTransform: Transform | None = None,
+ assign_subgoal_idxs: bool = False,
+ ):
+     if SubGoalSampler is None:
+         SubGoalSampler = HERSubGoalSampler()
+     if SubGoalAssigner is None:
+         SubGoalAssigner = HERSubGoalAssigner()
+     if RewardTransform is None:
+         RewardTransform = HERRewardTransform()
+     super().__init__(
+         in_keys=None,
+         in_keys_inv=None,
+         out_keys_inv=None,
+     )

Contributor:
No PascalCase but snake_case for instantiated classes

self.SubGoalSampler = SubGoalSampler
self.SubGoalAssigner = SubGoalAssigner
self.RewardTransform = RewardTransform
Comment on lines +9376 to +9378
Contributor:
ditto

self.assign_subgoal_idxs = assign_subgoal_idxs

def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
augmentation_td = self.her_augmentation(tensordict)
return torch.cat([tensordict, augmentation_td], dim=0)

def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor:
return self.her_augmentation(tensordict)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return tensordict

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
raise ValueError(self.ENV_ERR)

def her_augmentation(self, trajectories: TensorDictBase):
if len(trajectories.shape) == 1:
trajectories = trajectories.unsqueeze(0)
batch_size, trajectory_length = trajectories.shape
Contributor:
maybe

Suggested change
- batch_size, trajectory_length = trajectories.shape
+ *batch_size, trajectory_length = trajectories.shape


new_trajectories = trajectories.clone(True)

# Sample subgoal indices
subgoal_idxs = self.SubGoalSampler(new_trajectories)

# Create new trajectories
augmented_trajectories = []
list_idxs = []
for i in range(batch_size):
Contributor:
Suggested change
- for i in range(batch_size):
+ for i in range(batch_size.numel()):

which also works with batch_size=torch.Size([])!

idxs = subgoal_idxs[i][self.SubGoalSampler.subgoal_idx_key]

if "masks" in subgoal_idxs.keys():
idxs = idxs[subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key]]

list_idxs.append(idxs.unsqueeze(-1))
new_traj = new_trajectories[i].expand((idxs.numel(),trajectory_length)).clone(True)

if self.assign_subgoal_idxs:
new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze(-1).repeat(1, trajectory_length)

augmented_trajectories.append(new_traj)
augmented_trajectories = torch.cat(augmented_trajectories, dim=0)
associated_idxs = torch.cat(list_idxs, dim=0)

# Assign subgoals to the new trajectories
augmented_trajectories = self.SubGoalAssigner.forward(augmented_trajectories, associated_idxs)

# Adjust the rewards based on the new subgoals
augmented_trajectories = self.RewardTransform.forward(augmented_trajectories)

return augmented_trajectories
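
For context, the intended entry point appears to be the inverse path, e.g. when extending a replay buffer; a hypothetical usage sketch (buffer setup assumed, not part of the diff):

```python
from torchrl.data import LazyTensorStorage, ReplayBuffer

# `trajectories` is a TensorDict of shape [batch, time] carrying "achieved_goal",
# "desired_goal" and ("next", "done") entries.
rb = ReplayBuffer(
    storage=LazyTensorStorage(100_000),
    transform=HindsightExperienceReplayTransform(),
)
# extend() triggers the transform's inverse call, so both the original and the
# HER-relabeled trajectories end up in storage.
rb.extend(trajectories)
```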