[Feature] multiagent data standardization: PPO advantages #2677

Open · wants to merge 3 commits into base: main
67 changes: 66 additions & 1 deletion torchrl/_utils.py
@@ -24,7 +24,7 @@
from distutils.util import strtobool
from functools import wraps
from importlib import import_module
from typing import Any, Callable, cast, Dict, TypeVar, Union
from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union

import numpy as np
import torch
@@ -872,6 +872,71 @@ def set_mode(self, type: Any | None) -> None:
self._mode = type


def _standardize(
    input, exclude_dims: Tuple[int] = (), mean=None, std=None, eps: float = None
):
    """Standardizes the input tensor, with the possibility of excluding specific dims from the statistics.

    Useful when processing multi-agent data to keep the agent dimensions independent.

    Args:
        input (Tensor): the input tensor to be standardized.
        exclude_dims (Sequence[int]): dimensions to exclude from the statistics; negative values are allowed. Default: ().
        mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
        std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
        eps (float): epsilon to be used for numerical stability. Default: float32 resolution.

    """
    if eps is None:
        eps = torch.finfo(torch.float).resolution

    input_shape = input.shape
    exclude_dims = [
        d if d >= 0 else d + len(input_shape) for d in exclude_dims
    ]  # Make negative dims positive

    if len(set(exclude_dims)) != len(exclude_dims):
        raise ValueError("exclude_dims has repeating elements")
    if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
        raise ValueError(
            f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
        )
    if len(exclude_dims) == len(input_shape):
        warnings.warn(
            "standardize called but all dims were excluded from the statistics, returning unprocessed input"
        )
        return input

    if len(exclude_dims):
        # Put all excluded dims at the beginning
        permutation = list(range(len(input_shape)))
        for dim in exclude_dims:
            permutation.insert(0, permutation.pop(permutation.index(dim)))
        permuted_input = input.permute(*permutation)
    else:
        permuted_input = input
    normalized_shape_len = len(input_shape) - len(exclude_dims)

    if mean is None:
        mean = torch.mean(
            permuted_input, keepdim=True, dim=tuple(range(-normalized_shape_len, 0))
        )
    if std is None:
        std = torch.std(
            permuted_input, keepdim=True, dim=tuple(range(-normalized_shape_len, 0))
        )
    output = (permuted_input - mean) / std.clamp_min(eps)

    # Reverse the permutation
    if len(exclude_dims):
        inv_permutation = torch.argsort(
            torch.tensor(permutation, dtype=torch.long, device=input.device)
        ).tolist()
        output = torch.permute(output, inv_permutation)

    return output


@wraps(torch.compile)
def compile_with_warmup(*args, warmup: int = 1, **kwargs):
"""Compile a model with warm-up.
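For context, a minimal usage sketch of the new `_standardize` helper (not part of the diff; it assumes the function is importable from torchrl._utils as added above):

# Usage sketch only: standardize a [batch, n_agents] advantage tensor.
import torch
from torchrl._utils import _standardize

advantage = torch.randn(64, 3)  # [batch, n_agents]

# Exclude the agent dim (-1): each agent is standardized with its own mean/std.
per_agent = _standardize(advantage, exclude_dims=(-1,))

# No excluded dims: statistics are shared across the whole tensor.
shared = _standardize(advantage)

# Each agent column is now approximately zero-mean, unit-std.
print(per_agent.mean(dim=0), per_agent.std(dim=0))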
58 changes: 49 additions & 9 deletions torchrl/objectives/ppo.py
@@ -27,6 +27,7 @@
from tensordict.utils import NestedKey
from torch import distributions as d

from torchrl._utils import _standardize
from torchrl.objectives.common import LossModule

from torchrl.objectives.utils import (
@@ -46,6 +47,7 @@
TDLambdaEstimator,
VTrace,
)
import warnings


class PPOLoss(LossModule):
@@ -87,6 +89,9 @@ class PPOLoss(LossModule):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multi-agent (or multi-objective) settings
where the agent (or objective) dimension should be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -311,6 +316,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
advantage_key: str = None,
@@ -381,6 +387,8 @@ def __init__(
self.critic_coef = None
self.loss_critic_type = loss_critic_type
self.normalize_advantage = normalize_advantage
self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims

if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._set_deprecated_ctor_keys(
@@ -606,9 +614,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
advantage = tensordict.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
    loc = advantage.mean()
    scale = advantage.std().clamp_min(1e-6)
    advantage = (advantage - loc) / scale
    if advantage.numel() > tensordict.batch_size.numel() and not len(
        self.normalize_advantage_exclude_dims
    ):
        warnings.warn(
            "You requested advantage normalization, but the advantage tensor has more dimensions"
            " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
            "if you want to keep any dimension independent while computing the normalization statistics. "
            "This is strongly recommended in multi-agent/multi-objective settings."
        )
    advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict)
if is_tensor_collection(log_weight):
@@ -711,6 +726,9 @@ class ClipPPOLoss(PPOLoss):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multi-agent (or multi-objective) settings
where the agent (or objective) dimension should be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -802,6 +820,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
reduction: str = None,
@@ -821,6 +840,7 @@
critic_coef=critic_coef,
loss_critic_type=loss_critic_type,
normalize_advantage=normalize_advantage,
normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
gamma=gamma,
separate_losses=separate_losses,
reduction=reduction,
@@ -871,9 +891,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
advantage = tensordict.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
    loc = advantage.mean()
    scale = advantage.std().clamp_min(1e-6)
    advantage = (advantage - loc) / scale
    if advantage.numel() > tensordict.batch_size.numel() and not len(
        self.normalize_advantage_exclude_dims
    ):
        warnings.warn(
            "You requested advantage normalization, but the advantage tensor has more dimensions"
            " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
            "if you want to keep any dimension independent while computing the normalization statistics. "
            "This is strongly recommended in multi-agent/multi-objective settings."
        )
    advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict)
# ESS for logging
@@ -955,6 +982,9 @@ class KLPENPPOLoss(PPOLoss):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multi-agent (or multi-objective) settings
where the agent (or objective) dimension should be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -1048,6 +1078,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
reduction: str = None,
@@ -1063,6 +1094,7 @@
critic_coef=critic_coef,
loss_critic_type=loss_critic_type,
normalize_advantage=normalize_advantage,
normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
gamma=gamma,
separate_losses=separate_losses,
reduction=reduction,
@@ -1151,9 +1183,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
)
advantage = tensordict_copy.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
    loc = advantage.mean()
    scale = advantage.std().clamp_min(1e-6)
    advantage = (advantage - loc) / scale
    if advantage.numel() > tensordict.batch_size.numel() and not len(
        self.normalize_advantage_exclude_dims
    ):
        warnings.warn(
            "You requested advantage normalization, but the advantage tensor has more dimensions"
            " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
            "if you want to keep any dimension independent while computing the normalization statistics. "
            "This is strongly recommended in multi-agent/multi-objective settings."
        )
    advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict_copy)
neg_loss = log_weight.exp() * advantage

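To illustrate how the new argument is meant to be wired in, a hedged sketch of a multi-agent ClipPPOLoss setup (the `actor` and `critic` modules and the agent dimension index are illustrative assumptions, not part of the diff):

# Illustrative sketch: `actor` and `critic` are assumed to be pre-built
# TensorDictModules whose outputs carry an agent dimension at -2
# (e.g. advantages of shape [..., n_agents, 1]).
from torchrl.objectives import ClipPPOLoss

loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    normalize_advantage=True,
    # Keep the agent dimension out of the normalization statistics so each
    # agent's advantage is standardized independently.
    normalize_advantage_exclude_dims=(-2,),
)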