Skip to content

Commit

Permalink
[Feature] from and to_pytree (#832)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 25, 2024
1 parent 8522bb6 commit 9f942ba
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 8 deletions.
147 changes: 140 additions & 7 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import importlib
import json
import numbers
import uuid
import weakref
from collections.abc import MutableMapping

Expand Down Expand Up @@ -65,6 +66,8 @@
IndexType,
infer_size_impl,
int_generator,
is_namedtuple,
is_namedtuple_class,
is_non_tensor,
lazy_legacy,
lock_blocked,
Expand Down Expand Up @@ -795,6 +798,134 @@ def from_dict_instance(
"""
...

@classmethod
def from_pytree(
cls,
pytree,
*,
batch_size: torch.Size | None = None,
auto_batch_size: bool = False,
batch_dims: int | None = None,
):
"""Converts a pytree to a TensorDict instance.
This method is designed to keep the pytree nested structure as much as possible.
Additional non-tensor keys are added to keep track of each level's identity, providing
a built-in pytree-to-tensordict bijective transform API.
Accepted classes currently include lists, tuples, named tuples and dict.
.. note:: for dictionaries, non-NestedKey keys are registered separately as :class:`~tensordict.NonTensorData`
instances.
.. note:: Tensor-castable types (such as int, float or np.ndarray) will be converted to torch.Tensor instances.
NOte that this transformation is surjective: transforming back the tensordict to a pytree will not
recover the original types.
Examples:
>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key
>>> class WeirdLookingClass:
... pass
...
>>> weird_key = WeirdLookingClass()
>>> # Make a pytree with tuple, lists, dict and namedtuple
>>> pytree = (
... [torch.randint(10, (3,)), torch.zeros(2)],
... {
... "tensor": torch.randn(
... 2,
... ),
... "td": TensorDict({"one": 1}),
... weird_key: torch.randint(10, (2,)),
... "list": [1, 2, 3],
... },
... {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()},
... )
>>> # Build a TensorDict from that pytree
>>> td = TensorDict.from_pytree(pytree)
>>> # Recover the pytree
>>> pytree_recon = td.to_pytree()
>>> # Check that the leaves match
>>> def check(v1, v2):
>>> assert (v1 == v2).all()
>>>
>>> torch.utils._pytree.tree_map(check, pytree, pytree_recon)
>>> assert weird_key in pytree_recon[1]
"""
if is_tensor_collection(pytree):
return pytree
if isinstance(pytree, (torch.Tensor,)):
return pytree

from tensordict._td import TensorDict

result = None
if is_namedtuple(pytree):
result = TensorDict.from_namedtuple(named_tuple=pytree)
if batch_dims is not None:
result.batch_size = batch_size
result["_pytree_type"] = type(pytree)
elif isinstance(pytree, (list, tuple)):
source = {str(i): cls.from_pytree(elt) for i, elt in enumerate(pytree)}
source["_pytree_type"] = type(pytree)
result = TensorDict(source, batch_size=batch_size)
elif isinstance(pytree, dict):
source = {}
for key, item in pytree.items():
if isinstance(key, NestedKey):
source[key] = cls.from_pytree(item)
else:
subs_key = "<NON_NESTED>" + str(uuid.uuid1())
source[subs_key] = TensorDict(
{"value": cls.from_pytree(item), "key": key}
)
source["_pytree_type"] = type(pytree)
result = TensorDict(source, batch_size=batch_size)
if result is not None:
if auto_batch_size:
result.auto_batch_size_(batch_dims)
return result
if isinstance(pytree, (int, float, np.ndarray)):
return torch.as_tensor(pytree)
raise NotImplementedError(f"Unknown type {type(pytree)}.")

def to_pytree(self):
"""Converts a tensordict to a PyTree.
If the tensordict was not created from a pytree, this method just returns ``self`` without modification.
See :meth:`~.from_pytree` for more information and examples.
"""
_pytree_type = self._get_str("_pytree_type", default=None)
if _pytree_type is None:
return self
_pytree_type = _pytree_type.data
items = {key: val for (key, val) in self.items() if key != "_pytree_type"}
items = {
key: val if not is_tensor_collection(val) else val.to_pytree()
for key, val in items.items()
}
if _pytree_type in (list, tuple):
return _pytree_type((items[str(i)] for i in range(len(items))))
if _pytree_type is dict:
items = dict(
(
(val["key"], val["value"])
if key.startswith("<NON_NESTED>")
else (key, val)
for (key, val) in items.items()
)
)
return items
if is_namedtuple_class(_pytree_type):
from tensordict._td import TensorDict

return TensorDict(items).to_namedtuple(dest_cls=_pytree_type)
raise NotImplementedError(f"unknown type {_pytree_type}")

@classmethod
def from_h5(cls, filename, mode="r"):
"""Creates a PersistentTensorDict from a h5 file.
Expand Down Expand Up @@ -6788,9 +6919,12 @@ def to_numpy(x):

return torch.utils._pytree.tree_map(to_numpy, as_dict)

def to_namedtuple(self):
def to_namedtuple(self, dest_cls: type | None = None):
"""Converts a tensordict to a namedtuple.
Args:
dest_cls (Type, optional): an optional namedtuple class to use.
Examples:
>>> from tensordict import TensorDict
>>> import torch
Expand All @@ -6806,9 +6940,12 @@ def dict_to_namedtuple(dictionary):
for key, value in dictionary.items():
if isinstance(value, dict):
dictionary[key] = dict_to_namedtuple(value)
return collections.namedtuple("GenericDict", dictionary.keys())(
**dictionary
cls = (
collections.namedtuple("GenericDict", dictionary.keys())
if dest_cls is None
else dest_cls
)
return cls(**dictionary)

return dict_to_namedtuple(self.to_dict())

Expand Down Expand Up @@ -6847,10 +6984,6 @@ def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False):
"""
from tensordict import TensorDict

def is_namedtuple(obj):
"""Check if obj is a namedtuple."""
return isinstance(obj, tuple) and hasattr(obj, "_fields")

def namedtuple_to_dict(namedtuple_obj):
if is_namedtuple(namedtuple_obj):
namedtuple_obj = namedtuple_obj._asdict()
Expand Down
37 changes: 36 additions & 1 deletion tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import abc
import collections
import concurrent.futures
import inspect
Expand Down Expand Up @@ -129,7 +130,30 @@ def dims(self, *args, **kwargs):

IndexType = Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]]
DeviceType = Union[torch.device, str, int]
NestedKey = Union[str, Tuple[str, ...]]


class _NestedKeyMeta(abc.ABCMeta):
def __instancecheck__(self, instance):
return isinstance(instance, str) or (
isinstance(instance, tuple)
and len(instance)
and all(isinstance(subkey, NestedKey) for subkey in instance)
)


class NestedKey(metaclass=_NestedKeyMeta):
"""An abstract class for nested keys.
Nested keys are the generic key type accepted by TensorDict.
A nested key is either a string or a non-empty tuple of NestedKeys instances.
The NestedKey class supports instance checks.
"""

pass


_KEY_ERROR = 'key "{}" not found in {} with ' "keys {}"
_LOCK_ERROR = (
Expand Down Expand Up @@ -2274,3 +2298,14 @@ def __missing__(self, key):
value = self.fun(key)
self[key] = value
return value


def is_namedtuple(obj):
"""Check if obj is a namedtuple."""
return isinstance(obj, tuple) and hasattr(obj, "_fields")


def is_namedtuple_class(cls):
"""Check if a class is a namedtuple class."""
base_attrs = {"_fields", "_replace", "_asdict"}
return all(hasattr(cls, attr) for attr in base_attrs)
27 changes: 27 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,33 @@ def get_leaf(leaf):
assert p.grad is None
assert all(param.grad is not None for param in params.values(True, True))

def test_from_pytree(self):
class WeirdLookingClass:
pass

weird_key = WeirdLookingClass()

pytree = (
[torch.randint(10, (3,)), torch.zeros(2)],
{
"tensor": torch.randn(
2,
),
"td": TensorDict({"one": 1}),
weird_key: torch.randint(10, (2,)),
"list": [1, 2, 3],
},
{"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()},
)
td = TensorDict.from_pytree(pytree)
pytree_recon = td.to_pytree()

def check(v1, v2):
assert (v1 == v2).all()

torch.utils._pytree.tree_map(check, pytree, pytree_recon)
assert weird_key in pytree_recon[1]

@pytest.mark.parametrize(
"idx",
[
Expand Down

0 comments on commit 9f942ba

Please sign in to comment.