diff --git a/tensordict/_pytree.py b/tensordict/_pytree.py
index 1006631c5..c89cf972c 100644
--- a/tensordict/_pytree.py
+++ b/tensordict/_pytree.py
@@ -4,15 +4,18 @@
 # LICENSE file in the root directory of this source tree.
 from typing import Any, Dict, List, Tuple
 
+import torch
 from tensordict import (
     LazyStackedTensorDict,
     PersistentTensorDict,
     SubTensorDict,
     TensorDict,
+    TensorDictBase,
 )
+from tensordict.utils import implement_for
 try:
-    from torch.utils._pytree import Context, register_pytree_node
+    from torch.utils._pytree import Context, MappingKey, register_pytree_node
 except ImportError:
     from torch.utils._pytree import (
         _register_pytree_node as register_pytree_node,
     )
@@ -80,8 +83,16 @@ def _str_to_tensordictdict(str_spec: str) -> Tuple[List[str], str]:
 
 
 def _tensordict_flatten(d: TensorDict) -> Tuple[List[Any], Context]:
-    return list(d.values()), {
-        "keys": list(d.keys()),
+    items = tuple(d.items())
+    if items:
+        keys, values = zip(*d.items())
+        keys = list(keys)
+        values = list(values)
+    else:
+        keys = []
+        values = []
+    return values, {
+        "keys": keys,
         "batch_size": d.batch_size,
         "names": d.names,
         "device": d.device,
@@ -96,17 +107,59 @@ def _tensordictdict_unflatten(values: List[Any], context: Context) -> Dict[Any,
         if all(val.device == device for val in values if hasattr(val, "device"))
         else None
     )
+    batch_size = context["batch_size"]
+    names = context["names"]
+    keys = context["keys"]
+    batch_dims = len(batch_size)
+    if any(tensor.shape[:batch_dims] != batch_size for tensor in values):
+        batch_size = torch.Size([])
+        names = None
     return TensorDict(
-        dict(zip(context["keys"], values)),
-        context["batch_size"],
-        names=context["names"],
+        dict(zip(keys, values)),
+        batch_size=batch_size,
+        names=names,
         device=device,
+        _run_checks=False,
     )
 
 
-for cls in PYTREE_REGISTERED_TDS:
+def _td_flatten_with_keys(
+    d: TensorDictBase,
+):
+    items = tuple(d.items())
+    if items:
+        keys, values = zip(*d.items())
+        keys = list(keys)
+        values = list(values)
+    else:
+        keys = []
+        values = []
+    return [(MappingKey(k), v) for k, v in zip(keys, values)], {
+        "keys": keys,
+        "batch_size": d.batch_size,
+        "names": d.names,
+        "device": d.device,
+    }
+
+
+@implement_for("torch", None, "2.3")
+def _register_td_node(cls):
     register_pytree_node(
         cls,
         _tensordict_flatten,
         _tensordictdict_unflatten,
     )
+
+
+@implement_for("torch", "2.3")
+def _register_td_node(cls):  # noqa: F811
+    register_pytree_node(
+        cls,
+        _tensordict_flatten,
+        _tensordictdict_unflatten,
+        flatten_with_keys_fn=_td_flatten_with_keys,
+    )
+
+
+for cls in PYTREE_REGISTERED_TDS:
+    _register_td_node(cls)
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 1f9ac3d0f..1384b62e4 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -822,9 +822,10 @@ def __init__(
         implement_for._setters.append(self)
 
     @staticmethod
-    def check_version(version, from_version, to_version):
-        return (from_version is None or parse(version) >= parse(from_version)) and (
-            to_version is None or parse(version) < parse(to_version)
+    def check_version(version: str, from_version: str | None, to_version: str | None):
+        version = parse(".".join([str(v) for v in parse(version).release]))
+        return (from_version is None or version >= parse(from_version)) and (
+            to_version is None or version < parse(to_version)
         )
 
     @staticmethod
diff --git a/test/test_functorch.py b/test/test_functorch.py
index 8f8396dc5..4091f19a8 100644
--- a/test/test_functorch.py
+++ b/test/test_functorch.py
@@ -8,7 +8,7 @@
 
 import pytest
 import torch
-from _utils_internal import expand_list, TestTensorDictsBase
+from _utils_internal import expand_list, get_available_devices, TestTensorDictsBase
 from tensordict import LazyStackedTensorDict, TensorDict
 from tensordict.nn import TensorDictModule, TensorDictSequential
 
@@ -17,6 +17,7 @@
     make_functional,
     repopulate_module,
 )
+from tensordict.utils import implement_for
 from torch import nn
 from torch.nn import Linear
 from torch.utils._pytree import tree_map
@@ -633,3 +634,32 @@ def test_pytree_vs_apply(self):
         for v1, v2 in zip(td_pytree.values(True), td_apply.values(True)):
             # recursively checks the shape, including for the nested tensordicts
             assert v1.shape == v2.shape
+
+    @implement_for("torch", "2.3")
+    def test_map_with_path(self):
+        def assert_path(path, tensor):
+            assert path[0].key == "a"
+            assert path[1].key == "b"
+            assert path[2].key == "c"
+            return tensor
+
+        td = TensorDict({"a": {"b": {"c": [1]}}}, [1])
+        torch.utils._pytree.tree_map_with_path(assert_path, td)
+
+    @implement_for("torch", None, "2.3")
+    def test_map_with_path(self):  # noqa: F811
+        pytest.skip(reason="tree_map_with_path not implemented")
+
+    @pytest.mark.parametrize("dest", get_available_devices())
+    def test_device_map(self, dest):
+        td = TensorDict({"a": {"b": {"c": [1]}, "d": [2]}}, [1], device="cpu")
+        td_device = tree_map(lambda x: x.to(dest), td)
+        if dest == torch.device("cpu"):
+            assert td_device.device == torch.device("cpu")
+        else:
+            assert td_device.device is None
+
+    def test_shape_map(self):
+        td = TensorDict({"a": {"b": {"c": [1]}, "d": [2]}}, [1])
+        td_no_shape = tree_map(lambda x: x.squeeze(), td)
+        assert td_no_shape.shape == torch.Size([])
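
Note (illustration only, not part of the patch): with this change applied and torch >= 2.3, the keypath registration lets the pytree utilities visit TensorDict leaves together with their full key paths, and a shape-changing map falls back to an empty batch size instead of producing an inconsistent TensorDict. A minimal sketch, assuming a tensordict build that includes this diff:

    import torch
    import torch.utils._pytree as pytree
    from tensordict import TensorDict

    td = TensorDict({"a": {"b": {"c": torch.ones(1)}}}, [1])

    def show_path(path, leaf):
        # each path entry is a MappingKey; .key recovers the TensorDict key
        print([p.key for p in path], leaf.shape)  # ['a', 'b', 'c'] torch.Size([1])
        return leaf

    pytree.tree_map_with_path(show_path, td)

    # a shape-changing map drops the batch size instead of erroring out
    td_squeezed = pytree.tree_map(lambda x: x.squeeze(), td)
    assert td_squeezed.shape == torch.Size([])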
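
Note (illustration only, not part of the patch): the check_version change compares only the release component of the installed version, so pre-release and dev builds (for example torch nightlies) are no longer routed to the older implementation. The version strings below are made-up examples:

    from packaging.version import parse

    nightly = "2.3.0.dev20240101"  # hypothetical nightly version string
    release_only = ".".join(str(v) for v in parse(nightly).release)  # -> "2.3.0"

    # PEP 440 sorts dev releases below the final release, so the raw string
    # fails a ">= 2.3" check ...
    assert parse(nightly) < parse("2.3")
    # ... while the release-only form passes it, which is what check_version
    # now relies on.
    assert parse(release_only) >= parse("2.3")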