[BugFix] Compatibility of tensordict primers with batched envs (specifically for LSTM and GRU) #2668

Merged: 3 commits merged on Dec 20, 2024
2 changes: 1 addition & 1 deletion sota-implementations/decision_transformer/utils.py
@@ -109,7 +109,7 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
)

# copy action from the input tensordict to the output
transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec))
transformed_env.append_transform(TensorDictPrimer(base_env.full_action_spec))

transformed_env.append_transform(DoubleToFloat())
obsnorm = ObservationNorm(
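The primer for the decision-transformer env is now built from the env's full_action_spec, a Composite that carries the action key structure and the env's batch dims, instead of a single leaf spec bound to the action keyword. A minimal sketch of the two constructions (not part of the diff; the env used here is only illustrative):

# Sketch: old vs. new primer construction; GymEnv("Pendulum-v1") is illustrative only.
from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import TensorDictPrimer

base_env = GymEnv("Pendulum-v1")
env = TransformedEnv(base_env)

# before: a single leaf spec bound to the "action" key
# env.append_transform(TensorDictPrimer(action=base_env.action_spec))

# after: the full Composite action spec, key structure and batch dims included
env.append_transform(TensorDictPrimer(base_env.full_action_spec))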
105 changes: 80 additions & 25 deletions test/test_tensordictmodules.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import functools
import os

import pytest
@@ -12,6 +13,7 @@
import torchrl.modules
from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
from tensordict.utils import assert_close
from torch import nn
from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
from torchrl.envs import (
@@ -938,10 +940,12 @@ def test_multi_consecutive(self, shape, python_based):
@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
@pytest.mark.parametrize("within", [False, True])
def test_lstm_parallel_env(self, python_based, parallel, heterogeneous, within):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)
num_envs = 3
device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
lstm_module = LSTMModule(
@@ -958,25 +962,36 @@ def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
else:
cls = SerialEnv

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
if within:

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

else:
create_transformed_env = functools.partial(
DiscreteActionVecMockEnv,
categorical_action_encoding=True,
device=device,
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env) for _ in range(num_envs)
]
env = cls(
create_env_fn=create_transformed_env,
num_workers=2,
num_workers=num_envs,
)
if not within:
env = env.append_transform(InitTracker())
env.append_transform(lstm_module.make_tensordict_primer())

mlp = TensorDictModule(
MLP(
@@ -1002,6 +1017,19 @@ def create_transformed_env():
data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
assert (data.get(("next", "recurrent_state_c")) != 0.0).all()
assert (data.get("recurrent_state_c") != 0.0).any()
return data

@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_lstm_parallel_within(self, python_based, parallel, heterogeneous):
out_within = self.test_lstm_parallel_env(
python_based, parallel, heterogeneous, within=True
)
out_not_within = self.test_lstm_parallel_env(
python_based, parallel, heterogeneous, within=False
)
assert_close(out_within, out_not_within)

@pytest.mark.skipif(
not _has_functorch, reason="vmap can only be used with functorch"
@@ -1330,10 +1358,12 @@ def test_multi_consecutive(self, shape, python_based):
@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
@pytest.mark.parametrize("within", [False, True])
def test_gru_parallel_env(self, python_based, parallel, heterogeneous, within):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)
num_workers = 3

device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
@@ -1347,30 +1377,42 @@ def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
python_based=python_based,
)

def create_transformed_env():
primer = gru_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
if within:

def create_transformed_env():
primer = gru_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

else:
create_transformed_env = functools.partial(
DiscreteActionVecMockEnv,
categorical_action_encoding=True,
device=device,
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

if parallel:
cls = ParallelEnv
else:
cls = SerialEnv
if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env) for _ in range(num_workers)
]

env = cls(
env: ParallelEnv | SerialEnv = cls(
create_env_fn=create_transformed_env,
num_workers=2,
num_workers=num_workers,
)
if not within:
primer = gru_module.make_tensordict_primer()
env = env.append_transform(InitTracker())
env.append_transform(primer)

mlp = TensorDictModule(
MLP(
@@ -1396,6 +1438,19 @@ def create_transformed_env():
data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
assert (data.get("recurrent_state") != 0.0).any()
assert (data.get(("next", "recurrent_state")) != 0.0).all()
return data

@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_gru_parallel_within(self, python_based, parallel, heterogeneous):
out_within = self.test_gru_parallel_env(
python_based, parallel, heterogeneous, within=True
)
out_not_within = self.test_gru_parallel_env(
python_based, parallel, heterogeneous, within=False
)
assert_close(out_within, out_not_within)

@pytest.mark.skipif(
not _has_functorch, reason="vmap can only be used with functorch"
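The new "within" parameter exercises the two placements this fix reconciles: the recurrent primer appended inside each worker env, or appended once on the batched env. A rough sketch of both layouts (lstm_module stands for an LSTMModule and make_base_env for an env factory; both are placeholders, not names from the diff):

# Sketch of the two layouts compared above; lstm_module / make_base_env are placeholders.
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

def make_env_with_primer():
    # primer lives inside each worker env ("within")
    env = TransformedEnv(make_base_env())
    env.append_transform(InitTracker())
    env.append_transform(lstm_module.make_tensordict_primer())
    return env

env_within = ParallelEnv(3, make_env_with_primer)

# primer appended on the batched env itself ("not within")
env_outer = ParallelEnv(3, make_base_env)
env_outer = env_outer.append_transform(InitTracker())
env_outer.append_transform(lstm_module.make_tensordict_primer())

# With this fix, both layouts are expected to produce matching rollouts
# (see the assert_close checks above).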
45 changes: 36 additions & 9 deletions test/test_transforms.py
@@ -7408,7 +7408,7 @@ def make_env():
def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
env = TransformedEnv(
maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=Unbounded([2, 4])),
TensorDictPrimer(mykey=Unbounded([4])),
)
try:
check_env_specs(env)
@@ -7423,11 +7423,39 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
pass

@pytest.mark.parametrize("spec_shape", [[4], [2, 4]])
def test_trans_serial_env_check(self, spec_shape):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=Unbounded(spec_shape)),
)
@pytest.mark.parametrize("expand_specs", [True, False, None])
def test_trans_serial_env_check(self, spec_shape, expand_specs):
if expand_specs is None:
with pytest.warns(FutureWarning, match=""):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
mykey=Unbounded(spec_shape), expand_specs=expand_specs
),
)
env.observation_spec
elif expand_specs is True:
shape = spec_shape[:-1]
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
Composite(mykey=Unbounded(spec_shape), shape=shape),
expand_specs=expand_specs,
),
)
else:
# If we don't expand, we can't use [4]
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
mykey=Unbounded(spec_shape), expand_specs=expand_specs
),
)
if spec_shape == [4]:
with pytest.raises(ValueError):
env.observation_spec
return

check_env_specs(env)
assert "mykey" in env.reset().keys()
r = env.rollout(3)
@@ -10310,9 +10338,8 @@ def _make_transform_env(self, out_key, base_env):
transform = KLRewardTransform(actor, out_keys=out_key)
return Compose(
TensorDictPrimer(
primers={
"sample_log_prob": Unbounded(shape=base_env.action_spec.shape[:-1])
}
sample_log_prob=Unbounded(shape=base_env.action_spec.shape[:-1]),
shape=base_env.shape,
),
transform,
)
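The new expand_specs flag makes the intended behaviour explicit when the primer spec lacks the batch dims of a batched env. A short sketch mirroring the serial-env test above (ContinuousActionVecMockEnv is the mock env used by the test suite):

# Sketch of expand_specs against a batched env with batch_size=[2].
from mocking_classes import ContinuousActionVecMockEnv  # test-suite mock env
from torchrl.data import Composite, Unbounded
from torchrl.envs import SerialEnv, TransformedEnv
from torchrl.envs.transforms import TensorDictPrimer

base = SerialEnv(2, ContinuousActionVecMockEnv)

# expand_specs=True: a shape-[4] leaf is expanded to the env batch, i.e. [2, 4]
env = TransformedEnv(
    base,
    TensorDictPrimer(Composite(mykey=Unbounded([4])), expand_specs=True),
)

# expand_specs=False: the primer must already carry the batch dim ([2, 4]);
# a bare [4] raises a ValueError once the specs are assembled.
# expand_specs=None: current default; emits a FutureWarning and falls back to
# "try to set the shape, expand if that fails".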
36 changes: 32 additions & 4 deletions torchrl/envs/batched_envs.py
@@ -1744,14 +1744,39 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# We keep track of which keys are present to let the worker know what
# should be passed to the env (we don't want to pass done states for instance)
next_td_keys = list(next_td_passthrough.keys(True, True))
next_shared_tensordict_parent = shared_tensordict_parent.get("next")

# We separate keys that are and are not present in the buffer here and not in step_and_maybe_reset.
# The reason we do that is that the policy may write stuff in 'next' that is not part of the specs of
# the batched env but part of the specs of a transformed batched env.
# If that is the case, `update_` will fail to find the entries to update.
# What we do instead is keeping the tensors on the side and putting them back after completing _step.
keys_to_update, keys_to_copy = zip(
*[
(key, None)
if key in next_shared_tensordict_parent.keys(True, True)
else (None, key)
for key in next_td_keys
]
)
keys_to_update = [key for key in keys_to_update if key is not None]
keys_to_copy = [key for key in keys_to_copy if key is not None]
data = [
{"next_td_passthrough_keys": next_td_keys}
{"next_td_passthrough_keys": keys_to_update}
for _ in range(self.num_workers)
]
shared_tensordict_parent.get("next").update_(
next_td_passthrough, non_blocking=self.non_blocking
)
if keys_to_update:
next_shared_tensordict_parent.update_(
next_td_passthrough,
non_blocking=self.non_blocking,
keys_to_update=keys_to_update,
)
if keys_to_copy:
next_td_passthrough = next_td_passthrough.select(*keys_to_copy)
else:
next_td_passthrough = None
else:
next_td_passthrough = None
data = [{} for _ in range(self.num_workers)]

if self._non_tensor_keys:
@@ -1807,6 +1832,9 @@ def select_and_clone(name, tensor):
LazyStackedTensorDict(*non_tensor_tds),
keys_to_update=self._non_tensor_keys,
)
if next_td_passthrough is not None:
out.update(next_td_passthrough)

self._sync_w2m()
if partial_steps is not None:
result = out.new_zeros(tensordict_save.shape)
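The batched-env change splits the passthrough keys into entries the shared "next" buffer already holds (updated in place) and entries it does not (kept aside and merged into the step output). A self-contained toy of that partition, not the actual _step code:

# Toy illustration of the keys_to_update / keys_to_copy split (not _step itself).
import torch
from tensordict import TensorDict

shared_next = TensorDict({"obs": torch.zeros(2, 3)}, batch_size=[2])  # shared buffer
passthrough = TensorDict(
    {"obs": torch.ones(2, 3), "extra": torch.ones(2, 1)},  # what a transform wrote under "next"
    batch_size=[2],
)

buffer_keys = set(shared_next.keys(True, True))
passthrough_keys = list(passthrough.keys(True, True))
keys_to_update = [k for k in passthrough_keys if k in buffer_keys]
keys_to_copy = [k for k in passthrough_keys if k not in buffer_keys]

shared_next.update_(passthrough, keys_to_update=keys_to_update)  # in-place, buffer entries only
side = passthrough.select(*keys_to_copy) if keys_to_copy else None

out = shared_next.clone()
if side is not None:
    out.update(side)  # re-attach the entries the buffer does not hold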
37 changes: 30 additions & 7 deletions torchrl/envs/transforms/transforms.py
@@ -4984,6 +4984,7 @@ def __init__(
| Dict[NestedKey, float]
| Dict[NestedKey, Callable] = None,
reset_key: NestedKey | None = None,
expand_specs: bool = None,
**kwargs,
):
self.device = kwargs.pop("device", None)
@@ -4995,8 +4996,16 @@
)
kwargs = primers
if not isinstance(kwargs, Composite):
kwargs = Composite(kwargs)
self.primers = kwargs
shape = kwargs.pop("shape", None)
device = kwargs.pop("device", None)
if "batch_size" in kwargs.keys():
extra_kwargs = {"batch_size": kwargs.pop("batch_size")}
else:
extra_kwargs = {}
primers = Composite(kwargs, device=device, shape=shape, **extra_kwargs)
self.primers = primers
self.expand_specs = expand_specs

if random and default_value:
raise ValueError(
"Setting random to True and providing a default_value are incompatible."
@@ -5089,12 +5098,26 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
)

if self.primers.shape != observation_spec.shape:
try:
# We try to set the primer shape to the observation spec shape
self.primers.shape = observation_spec.shape
except ValueError:
# If we fail, we expand them to that shape
if self.expand_specs:
self.primers = self._expand_shape(self.primers)
elif self.expand_specs is None:
warnings.warn(
f"expand_specs wasn't specified in the {type(self).__name__} constructor. "
f"The current behaviour is that the transform will attempt to set the shape of the composite "
f"spec, and if this can't be done it will be expanded. "
f"From v0.8, a mismatched shape between the spec of the transform and the env's batch_size "
f"will raise an exception.",
category=FutureWarning,
)
try:
# We try to set the primer shape to the observation spec shape
self.primers.shape = observation_spec.shape
except ValueError:
# If we fail, we expand them to that shape
self.primers = self._expand_shape(self.primers)
else:
self.primers.shape = observation_spec.shape

device = observation_spec.device
observation_spec.update(self.primers.clone().to(device))
return observation_spec
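On the transform side, the keyword path of the constructor now honours shape, device and batch_size when building the inner Composite, which the KLRewardTransform test above relies on. A rough sketch of that path; the equivalence shown is an assumption read off the diff, not a guarantee:

# Sketch of the reworked kwargs path of the TensorDictPrimer constructor.
from torchrl.data import Composite, Unbounded
from torchrl.envs.transforms import TensorDictPrimer

primer = TensorDictPrimer(
    sample_log_prob=Unbounded(shape=(2,)),  # leaf spec whose leading dim matches the env batch
    shape=(2,),                             # batch shape of the underlying Composite
)
# roughly equivalent to passing the Composite directly:
# TensorDictPrimer(Composite(sample_log_prob=Unbounded(shape=(2,)), shape=(2,)))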