From 9f942bac60e3f5ae69b11f12638523c9c239ba71 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 25 Jun 2024 09:36:22 +0100
Subject: [PATCH] [Feature] from and to_pytree (#832)

---
Review notes (text between the "---" separator and the diffstat is ignored by git-am):
- Re-flowed to one diff line per line; indentation and blank context lines were
  reconstructed from the code structure and the hunk line counts, so confirm
  whitespace against upstream before applying.
- from_pytree: the namedtuple branch now checks batch_size (not batch_dims)
  before assigning result.batch_size.
- from_pytree / to_pytree: the prefix marking non-NestedKey dict keys was an
  empty string, which makes key.startswith(...) true for every key.
  "<NON_NESTED>" is assumed to be the lost marker -- TODO confirm against upstream.
- from_pytree docstring: fixed the "NOte" typo, the "surjective" wording and the
  doctest continuation prompts.
- All edits are line-count neutral, so hunk headers and the diffstat still hold;
  the "index" blob hashes describe the original content.

 tensordict/base.py      | 147 ++++++++++++++++++++++++++++++++++++++--
 tensordict/utils.py     |  37 +++++++++-
 test/test_tensordict.py |  27 ++++++++
 3 files changed, 203 insertions(+), 8 deletions(-)

diff --git a/tensordict/base.py b/tensordict/base.py
index 8978f2b7d..430a092b7 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -13,6 +13,7 @@
 import importlib
 import json
 import numbers
+import uuid
 import weakref
 from collections.abc import MutableMapping
 
@@ -65,6 +66,8 @@
     IndexType,
     infer_size_impl,
     int_generator,
+    is_namedtuple,
+    is_namedtuple_class,
     is_non_tensor,
     lazy_legacy,
     lock_blocked,
@@ -795,6 +798,134 @@ def from_dict_instance(
         """
         ...
 
+    @classmethod
+    def from_pytree(
+        cls,
+        pytree,
+        *,
+        batch_size: torch.Size | None = None,
+        auto_batch_size: bool = False,
+        batch_dims: int | None = None,
+    ):
+        """Converts a pytree to a TensorDict instance.
+
+        This method is designed to keep the pytree nested structure as much as possible.
+
+        Additional non-tensor keys are added to keep track of each level's identity, providing
+        a built-in pytree-to-tensordict bijective transform API.
+
+        Accepted classes currently include lists, tuples, named tuples and dict.
+
+        .. note:: for dictionaries, non-NestedKey keys are registered separately as :class:`~tensordict.NonTensorData`
+            instances.
+
+        .. note:: Tensor-castable types (such as int, float or np.ndarray) will be converted to torch.Tensor instances.
+            Note that this transformation is lossy: transforming back the tensordict to a pytree will not
+            recover the original types.
+
+        Examples:
+            >>> # Create a pytree with tensor leaves, and one "weird"-looking dict key
+            >>> class WeirdLookingClass:
+            ...     pass
+            ...
+            >>> weird_key = WeirdLookingClass()
+            >>> # Make a pytree with tuple, lists, dict and namedtuple
+            >>> pytree = (
+            ...     [torch.randint(10, (3,)), torch.zeros(2)],
+            ...     {
+            ...         "tensor": torch.randn(
+            ...             2,
+            ...         ),
+            ...         "td": TensorDict({"one": 1}),
+            ...         weird_key: torch.randint(10, (2,)),
+            ...         "list": [1, 2, 3],
+            ...     },
+            ...     {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()},
+            ... )
+            >>> # Build a TensorDict from that pytree
+            >>> td = TensorDict.from_pytree(pytree)
+            >>> # Recover the pytree
+            >>> pytree_recon = td.to_pytree()
+            >>> # Check that the leaves match
+            >>> def check(v1, v2):
+            ...     assert (v1 == v2).all()
+            ...
+            >>> torch.utils._pytree.tree_map(check, pytree, pytree_recon)
+            >>> assert weird_key in pytree_recon[1]
+
+        """
+        if is_tensor_collection(pytree):
+            return pytree
+        if isinstance(pytree, (torch.Tensor,)):
+            return pytree
+
+        from tensordict._td import TensorDict
+
+        result = None
+        if is_namedtuple(pytree):
+            result = TensorDict.from_namedtuple(named_tuple=pytree)
+            if batch_size is not None:
+                result.batch_size = batch_size
+            result["_pytree_type"] = type(pytree)
+        elif isinstance(pytree, (list, tuple)):
+            source = {str(i): cls.from_pytree(elt) for i, elt in enumerate(pytree)}
+            source["_pytree_type"] = type(pytree)
+            result = TensorDict(source, batch_size=batch_size)
+        elif isinstance(pytree, dict):
+            source = {}
+            for key, item in pytree.items():
+                if isinstance(key, NestedKey):
+                    source[key] = cls.from_pytree(item)
+                else:
+                    subs_key = "<NON_NESTED>" + str(uuid.uuid1())
+                    source[subs_key] = TensorDict(
+                        {"value": cls.from_pytree(item), "key": key}
+                    )
+            source["_pytree_type"] = type(pytree)
+            result = TensorDict(source, batch_size=batch_size)
+        if result is not None:
+            if auto_batch_size:
+                result.auto_batch_size_(batch_dims)
+            return result
+        if isinstance(pytree, (int, float, np.ndarray)):
+            return torch.as_tensor(pytree)
+        raise NotImplementedError(f"Unknown type {type(pytree)}.")
+
+    def to_pytree(self):
+        """Converts a tensordict to a PyTree.
+
+        If the tensordict was not created from a pytree, this method just returns ``self`` without modification.
+
+        See :meth:`~.from_pytree` for more information and examples.
+
+        """
+        _pytree_type = self._get_str("_pytree_type", default=None)
+        if _pytree_type is None:
+            return self
+        _pytree_type = _pytree_type.data
+        items = {key: val for (key, val) in self.items() if key != "_pytree_type"}
+        items = {
+            key: val if not is_tensor_collection(val) else val.to_pytree()
+            for key, val in items.items()
+        }
+        if _pytree_type in (list, tuple):
+            return _pytree_type((items[str(i)] for i in range(len(items))))
+        if _pytree_type is dict:
+            items = dict(
+                (
+                    (val["key"], val["value"])
+                    if key.startswith("<NON_NESTED>")
+                    else (key, val)
+                    for (key, val) in items.items()
+                )
+            )
+            return items
+        if is_namedtuple_class(_pytree_type):
+            from tensordict._td import TensorDict
+
+            return TensorDict(items).to_namedtuple(dest_cls=_pytree_type)
+        raise NotImplementedError(f"unknown type {_pytree_type}")
+
     @classmethod
     def from_h5(cls, filename, mode="r"):
         """Creates a PersistentTensorDict from a h5 file.
@@ -6788,9 +6919,12 @@ def to_numpy(x):
 
         return torch.utils._pytree.tree_map(to_numpy, as_dict)
 
-    def to_namedtuple(self):
+    def to_namedtuple(self, dest_cls: type | None = None):
         """Converts a tensordict to a namedtuple.
 
+        Args:
+            dest_cls (Type, optional): an optional namedtuple class to use.
+
         Examples:
             >>> from tensordict import TensorDict
             >>> import torch
@@ -6806,9 +6940,12 @@ def dict_to_namedtuple(dictionary):
             for key, value in dictionary.items():
                 if isinstance(value, dict):
                     dictionary[key] = dict_to_namedtuple(value)
-            return collections.namedtuple("GenericDict", dictionary.keys())(
-                **dictionary
+            cls = (
+                collections.namedtuple("GenericDict", dictionary.keys())
+                if dest_cls is None
+                else dest_cls
             )
+            return cls(**dictionary)
 
         return dict_to_namedtuple(self.to_dict())
 
@@ -6847,10 +6984,6 @@ def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False):
         """
         from tensordict import TensorDict
 
-        def is_namedtuple(obj):
-            """Check if obj is a namedtuple."""
-            return isinstance(obj, tuple) and hasattr(obj, "_fields")
-
         def namedtuple_to_dict(namedtuple_obj):
             if is_namedtuple(namedtuple_obj):
                 namedtuple_obj = namedtuple_obj._asdict()
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 67cb80474..db9ff922f 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+import abc
 import collections
 import concurrent.futures
 import inspect
@@ -129,7 +130,30 @@ def dims(self, *args, **kwargs):
 
 IndexType = Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]]
 DeviceType = Union[torch.device, str, int]
-NestedKey = Union[str, Tuple[str, ...]]
+
+
+class _NestedKeyMeta(abc.ABCMeta):
+    def __instancecheck__(self, instance):
+        return isinstance(instance, str) or (
+            isinstance(instance, tuple)
+            and len(instance)
+            and all(isinstance(subkey, NestedKey) for subkey in instance)
+        )
+
+
+class NestedKey(metaclass=_NestedKeyMeta):
+    """An abstract class for nested keys.
+
+    Nested keys are the generic key type accepted by TensorDict.
+
+    A nested key is either a string or a non-empty tuple of NestedKeys instances.
+
+    The NestedKey class supports instance checks.
+
+    """
+
+    pass
+
 
 _KEY_ERROR = 'key "{}" not found in {} with ' "keys {}"
 _LOCK_ERROR = (
@@ -2274,3 +2298,14 @@ def __missing__(self, key):
         value = self.fun(key)
         self[key] = value
         return value
+
+
+def is_namedtuple(obj):
+    """Check if obj is a namedtuple."""
+    return isinstance(obj, tuple) and hasattr(obj, "_fields")
+
+
+def is_namedtuple_class(cls):
+    """Check if a class is a namedtuple class."""
+    base_attrs = {"_fields", "_replace", "_asdict"}
+    return all(hasattr(cls, attr) for attr in base_attrs)
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index af5e3035d..e7f729b3c 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -891,6 +891,33 @@ def get_leaf(leaf):
             assert p.grad is None
         assert all(param.grad is not None for param in params.values(True, True))
 
+    def test_from_pytree(self):
+        class WeirdLookingClass:
+            pass
+
+        weird_key = WeirdLookingClass()
+
+        pytree = (
+            [torch.randint(10, (3,)), torch.zeros(2)],
+            {
+                "tensor": torch.randn(
+                    2,
+                ),
+                "td": TensorDict({"one": 1}),
+                weird_key: torch.randint(10, (2,)),
+                "list": [1, 2, 3],
+            },
+            {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()},
+        )
+        td = TensorDict.from_pytree(pytree)
+        pytree_recon = td.to_pytree()
+
+        def check(v1, v2):
+            assert (v1 == v2).all()
+
+        torch.utils._pytree.tree_map(check, pytree, pytree_recon)
+        assert weird_key in pytree_recon[1]
+
     @pytest.mark.parametrize(
         "idx",
         [