From ec808a7e445acf6d573b4cd21668fe2108609993 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 21 May 2024 19:54:44 +0100 Subject: [PATCH 01/28] init --- tensordict/nn/common.py | 2 +- tensordict/nn/sequence.py | 2 +- tensordict/persistent.py | 2 +- tensordict/utils.py | 61 ++++++++++++++++++++++----------------- 4 files changed, 37 insertions(+), 30 deletions(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 79f99a224..6ed269393 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1142,7 +1142,7 @@ def _call_module( return out @dispatch(auto_batch_size=False) - @set_skip_existing(None) + # @set_skip_existing(None) def forward( self, tensordict: TensorDictBase, diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index eb7e14cef..33c7cc01c 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -416,7 +416,7 @@ def _run_module( return tensordict @dispatch(auto_batch_size=False) - @set_skip_existing(None) + # @set_skip_existing(None) def forward( self, tensordict: TensorDictBase, diff --git a/tensordict/persistent.py b/tensordict/persistent.py index b1cf5eb60..ac7b68d46 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -27,7 +27,7 @@ NO_DEFAULT, TensorDict, ) -from tensordict.base import _default_is_leaf, is_tensor_collection, T, TensorDictBase +from tensordict.base import _default_is_leaf, is_tensor_collection, T, TensorDictBase, _is_leaf_nontensor from tensordict.memmap import MemoryMappedTensor from tensordict.utils import ( _CloudpickleWrapper, diff --git a/tensordict/utils.py b/tensordict/utils.py index 76dbf7cc4..9fe04f4f5 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -49,7 +49,7 @@ from packaging.version import parse from tensordict._contextlib import _DecoratorContextManager from tensordict._tensordict import ( # noqa: F401 - _unravel_key_to_tuple, + # _unravel_key_to_tuple, unravel_key, unravel_key_list, unravel_keys, @@ -1129,32 +1129,32 @@ def __contains__(self, item): item = unravel_item[0] return super().__contains__(item) - -class _StringOnlyDict(dict): - """A dict class where contains is restricted to strings.""" - - # kept here for debugging - # def __setitem__(self, key, value): - # if not isinstance(key, str): - # raise RuntimeError - # return super().__setitem__(key, value) - - def __contains__(self, item): - if not isinstance(item, str): - try: - unravel_item = _unravel_key_to_tuple(item) - if not unravel_item: # catch errors during unravel - raise TypeError - except Exception: - raise TypeError(_NON_STR_KEY_ERR) - if len(unravel_item) > 1: - raise TypeError(_NON_STR_KEY_TUPLE_ERR) - else: - item = unravel_item[0] - return super().__contains__(item) - - def keys(self): - return _StringKeys(self) +_StringOnlyDict = dict +# class _StringOnlyDict(dict): +# """A dict class where contains is restricted to strings.""" +# +# # kept here for debugging +# # def __setitem__(self, key, value): +# # if not isinstance(key, str): +# # raise RuntimeError +# # return super().__setitem__(key, value) +# +# def __contains__(self, item): +# if not isinstance(item, str): +# try: +# unravel_item = _unravel_key_to_tuple(item) +# if not unravel_item: # catch errors during unravel +# raise TypeError +# except Exception: +# raise TypeError(_NON_STR_KEY_ERR) +# if len(unravel_item) > 1: +# raise TypeError(_NON_STR_KEY_TUPLE_ERR) +# else: +# item = unravel_item[0] +# return super().__contains__(item) +# +# def keys(self): +# return _StringKeys(self) def lock_blocked(func): @@ -2248,3 
+2248,10 @@ def __missing__(self, key): value = self.fun(key) self[key] = value return value + +def _unravel_key_to_tuple(key): + if isinstance(key, str): + return (key,) + if not isinstance(key, tuple): + return () + return tuple(subk for k in key for subk in _unravel_key_to_tuple(k)) From 34c9107f7dc5ed3b70c8dc649a1f8f00e6cdba76 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 21 May 2024 20:39:34 +0100 Subject: [PATCH 02/28] amend --- tensordict/_td.py | 17 ++++++++++- tensordict/base.py | 2 +- tensordict/nn/common.py | 59 +++++++++++++++++++++++++-------------- tensordict/nn/sequence.py | 27 ++++++++++++++---- tensordict/utils.py | 2 ++ 5 files changed, 78 insertions(+), 29 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index e66acbc99..b269805af 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2644,7 +2644,7 @@ def _clone(self, recurse: bool = True) -> T: source={key: _clone_value(value, recurse) for key, value in self.items()}, batch_size=self.batch_size, device=self.device, - names=copy(self._td_dim_names), + names=[name for name in self._td_dim_names] if self._has_names() else None, _run_checks=False, ) # If this is uncommented, a shallow copy of a shared/memmap will be shared and locked too @@ -3784,18 +3784,33 @@ def __init__( is_leaf = _default_is_leaf self.is_leaf = is_leaf + # self.items = [] + # self._iter() + # + # def __getitem__(self, i): + # return self.items[i] + def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]: + # yield from self.items + # + # def _iter(self): if not self.include_nested: if self.leaves_only: for key in self._keys(): target_class = self.tensordict.entry_class(key) if _is_tensor_collection(target_class): continue + # self.items.append(key) yield key else: yield from self._keys() + # self.items += self._keys() else: yield from ( + # key if len(key) > 1 else key[0] + # for key in self._iter_helper(self.tensordict) + # ) + # self.items += ( key if len(key) > 1 else key[0] for key in self._iter_helper(self.tensordict) ) diff --git a/tensordict/base.py b/tensordict/base.py index 51c7bc519..4d3093282 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -90,7 +90,7 @@ def __bool__(self): return False -NO_DEFAULT = _NoDefault() +NO_DEFAULT = "-TO_REPLACE-" # _NoDefault() class _NestedTensorsAsLists: diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 6ed269393..415fe6f9c 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -501,6 +501,11 @@ def __new__(cls, *args, **kwargs): out = super().__new__(cls) return out + @staticmethod + def is_tdmodule_compatible(module): + """Checks if a module is compatible with TensorDictModule API.""" + return hasattr(module, "in_keys") and hasattr(module, "out_keys") + @property def in_keys(self): return self._in_keys @@ -1242,17 +1247,17 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(\n{fields})" - def __getattr__(self, name: str) -> Any: - try: - return super().__getattr__(name) - except AttributeError as err1: - if not name.startswith("_"): - # no fallback for private attributes - try: - return getattr(super().__getattr__("module"), name) - except Exception as err2: - raise err2 from err1 - raise + # def __getattr__(self, name: str) -> Any: + # try: + # return super().__getattr__(name) + # except AttributeError as err1: + # if not name.startswith("_"): + # # no fallback for private attributes + # try: + # return getattr(super().__getattr__("module"), name) + # except Exception as err2: + # raise err2 from err1 + # raise 
def __getstate__(self): state = self.__dict__.copy() @@ -1285,16 +1290,28 @@ def __init__(self, td_module: TensorDictModule) -> None: for pre_hook in self.td_module._forward_hooks: self.register_forward_hook(self.td_module._forward_hooks[pre_hook]) - def __getattr__(self, name: str) -> Any: - try: - return super().__getattr__(name) - except AttributeError: - if name not in self.__dict__ and not name.startswith("__"): - return getattr(self._modules["td_module"], name) - else: - raise AttributeError( - f"attribute {name} not recognised in {type(self).__name__}" - ) + # def __getattr__(self, name: str) -> Any: + # try: + # return super().__getattr__(name) + # except AttributeError: + # if name not in self.__dict__ and not name.startswith("__"): + # return getattr(self._modules["td_module"], name) + # else: + # raise AttributeError( + # f"attribute {name} not recognised in {type(self).__name__}" + # ) def forward(self, *args: Any, **kwargs: Any) -> TensorDictBase: return self.td_module.forward(*args, **kwargs) + + +class WrapModule(TensorDictModuleBase): + in_keys = [] + out_keys = [] + + def __init__(self, func): + self.func = func + super().__init__() + + def forward(self, data): + return self.func(data) diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index 33c7cc01c..40b9c266f 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -10,7 +10,18 @@ from copy import deepcopy from typing import Any, Iterable +from tensordict._tensordict import unravel_key_list +from tensordict.nn.common import ( + dispatch, + TensorDictModule, + TensorDictModuleBase, + WrapModule, +) + from tensordict.nn.utils import set_skip_existing +from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase +from tensordict.utils import NestedKey +from torch import nn _has_functorch = False try: @@ -26,12 +37,6 @@ FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." 
-from tensordict._tensordict import unravel_key_list -from tensordict.nn.common import dispatch, TensorDictModule -from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase -from tensordict.utils import NestedKey -from torch import nn - __all__ = ["TensorDictSequential"] @@ -158,6 +163,7 @@ def __init__( *modules: TensorDictModule, partial_tolerant: bool = False, ) -> None: + modules = self._convert_modules(modules) in_keys, out_keys = self._compute_in_and_out_keys(modules) super().__init__( @@ -166,6 +172,15 @@ def __init__( self.partial_tolerant = partial_tolerant + @staticmethod + def _convert_modules(modules): + return [ + WrapModule(module) + if not TensorDictModuleBase.is_tdmodule_compatible(module) + else module + for module in modules + ] + def _compute_in_and_out_keys( self, modules: list[TensorDictModule] ) -> tuple[list[NestedKey], list[NestedKey]]: diff --git a/tensordict/utils.py b/tensordict/utils.py index 9fe04f4f5..b63b606a9 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1129,6 +1129,7 @@ def __contains__(self, item): item = unravel_item[0] return super().__contains__(item) + _StringOnlyDict = dict # class _StringOnlyDict(dict): # """A dict class where contains is restricted to strings.""" @@ -2249,6 +2250,7 @@ def __missing__(self, key): self[key] = value return value + def _unravel_key_to_tuple(key): if isinstance(key, str): return (key,) From f26209d4db2f977a629a1b8b9f56783446e6d883 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 22 May 2024 14:59:27 -0400 Subject: [PATCH 03/28] amend --- tensordict/_td.py | 2 +- tensordict/nn/common.py | 42 ++++++++++++++++++++--------------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index b269805af..b6960767f 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2644,7 +2644,7 @@ def _clone(self, recurse: bool = True) -> T: source={key: _clone_value(value, recurse) for key, value in self.items()}, batch_size=self.batch_size, device=self.device, - names=[name for name in self._td_dim_names] if self._has_names() else None, + names=list(self._td_dim_names) if self._has_names() else None, _run_checks=False, ) # If this is uncommented, a shallow copy of a shared/memmap will be shared and locked too diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 415fe6f9c..0c2f0192b 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1247,17 +1247,17 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(\n{fields})" - # def __getattr__(self, name: str) -> Any: - # try: - # return super().__getattr__(name) - # except AttributeError as err1: - # if not name.startswith("_"): - # # no fallback for private attributes - # try: - # return getattr(super().__getattr__("module"), name) - # except Exception as err2: - # raise err2 from err1 - # raise + def __getattr__(self, name: str) -> Any: + try: + return super().__getattr__(name) + except AttributeError as err1: + if not name.startswith("_"): + # no fallback for private attributes + try: + return getattr(super().__getattr__("module"), name) + except Exception as err2: + raise err2 from err1 + raise def __getstate__(self): state = self.__dict__.copy() @@ -1290,16 +1290,16 @@ def __init__(self, td_module: TensorDictModule) -> None: for pre_hook in self.td_module._forward_hooks: self.register_forward_hook(self.td_module._forward_hooks[pre_hook]) - # def __getattr__(self, name: str) -> Any: - # try: - # return super().__getattr__(name) - # except 
AttributeError: - # if name not in self.__dict__ and not name.startswith("__"): - # return getattr(self._modules["td_module"], name) - # else: - # raise AttributeError( - # f"attribute {name} not recognised in {type(self).__name__}" - # ) + def __getattr__(self, name: str) -> Any: + try: + return super().__getattr__(name) + except AttributeError: + if name not in self.__dict__ and not name.startswith("__"): + return getattr(self._modules["td_module"], name) + else: + raise AttributeError( + f"attribute {name} not recognised in {type(self).__name__}" + ) def forward(self, *args: Any, **kwargs: Any) -> TensorDictBase: return self.td_module.forward(*args, **kwargs) From 2c644d3e4af7db144e9456063edab48218f9aea4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 22 May 2024 18:16:16 -0400 Subject: [PATCH 04/28] amend --- tensordict/_lazy.py | 3 ++- tensordict/nn/common.py | 1 - tensordict/nn/sequence.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 473dcae25..66e2f0a18 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -33,7 +33,7 @@ _has_funcdim = False from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict -from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list +from tensordict._tensordict import unravel_key_list from tensordict.base import ( _is_tensor_collection, _NESTED_TENSORS_AS_LISTS, @@ -54,6 +54,7 @@ _renamed_inplace_method, _shape, _td_fields, + _unravel_key_to_tuple, as_decorator, cache, convert_ellipsis_to_idx, diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 0c2f0192b..38203aa9b 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -27,7 +27,6 @@ from tensordict.nn.utils import ( _auto_make_functional, _dispatch_td_nn_modules, - set_skip_existing, ) from tensordict.utils import implement_for, NestedKey from torch import nn, Tensor diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index 40b9c266f..a87e63d72 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -18,7 +18,6 @@ WrapModule, ) -from tensordict.nn.utils import set_skip_existing from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import nn From 93ed29a01edb678273abdf17dc2d16209c1a6d6d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 23 May 2024 19:23:01 -0400 Subject: [PATCH 05/28] amend --- tensordict/nn/common.py | 77 ++++++++++++++++++++++++++++----------- tensordict/nn/sequence.py | 3 +- tensordict/nn/utils.py | 51 +++++++++++++++++++++++++- 3 files changed, 108 insertions(+), 23 deletions(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 38203aa9b..61ef5fed6 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -26,7 +26,8 @@ ) from tensordict.nn.utils import ( _auto_make_functional, - _dispatch_td_nn_modules, + _dispatch_td_nn_modules, set_skip_existing, + _set_skip_existing_None, ) from tensordict.utils import implement_for, NestedKey from torch import nn, Tensor @@ -1146,7 +1147,7 @@ def _call_module( return out @dispatch(auto_batch_size=False) - # @set_skip_existing(None) + @_set_skip_existing_None() def forward( self, tensordict: TensorDictBase, @@ -1247,16 +1248,33 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(\n{fields})" def __getattr__(self, name: str) -> Any: - try: - return super().__getattr__(name) - except AttributeError as err1: - if not name.startswith("_"): - # no fallback 
for private attributes - try: - return getattr(super().__getattr__("module"), name) - except Exception as err2: - raise err2 from err1 - raise + __dict__ = self.__dict__ + _parameters = __dict__.get("_parameters") + if _parameters: + # A param can be None so we use False instead to check if the key + # is in the _parameters dict once and only once. + # We use False but any non-None, non-Parameter value would do. + # The alternative `if name in _parameters: return _parameters[name]` + # accesses the value of `name` twice when only one is required + result = _parameters.get(name, False) + if result is not False: + return result + _buffers = __dict__.get("_buffers") + if _buffers: + result = _buffers.get(name, False) + if result is not False: + return result + _modules = __dict__.get("_modules") + if _modules: + result = _modules.get(name, False) + if result is not False: + return result + if not name.startswith("_"): + # no fallback for private attributes + return getattr(super().__getattr__("module"), name) + raise AttributeError( + f"module {self.__class__.__name__} has no attribute named {name}." + ) def __getstate__(self): state = self.__dict__.copy() @@ -1290,15 +1308,32 @@ def __init__(self, td_module: TensorDictModule) -> None: self.register_forward_hook(self.td_module._forward_hooks[pre_hook]) def __getattr__(self, name: str) -> Any: - try: - return super().__getattr__(name) - except AttributeError: - if name not in self.__dict__ and not name.startswith("__"): - return getattr(self._modules["td_module"], name) - else: - raise AttributeError( - f"attribute {name} not recognised in {type(self).__name__}" - ) + __dict__ = self.__dict__ + _parameters = __dict__.get("_parameters") + if _parameters: + # A param can be None so we use False instead to check if the key + # is in the _parameters dict once and only once. + # We use False but any non-None, non-Parameter value would do. 
+ # The alternative `if name in _parameters: return _parameters[name]` + # accesses the value of `name` twice when only one is required + result = _parameters.get(name, False) + if result is not False: + return result + _buffers = __dict__.get("_buffers") + if _buffers: + result = _buffers.get(name, False) + if result is not False: + return result + _modules = __dict__.get("_modules") + if _modules: + result = _modules.get(name, False) + if result is not False: + return result + if name not in self.__dict__ and not name.startswith("__"): + return getattr(self._modules["td_module"], name) + raise AttributeError( + f"attribute {name} not recognised in {type(self).__name__}" + ) def forward(self, *args: Any, **kwargs: Any) -> TensorDictBase: return self.td_module.forward(*args, **kwargs) diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index a87e63d72..8c8ec3323 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -20,6 +20,7 @@ from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from tensordict.utils import NestedKey +from tensordict.nn.utils import _set_skip_existing_None from torch import nn _has_functorch = False @@ -430,7 +431,7 @@ def _run_module( return tensordict @dispatch(auto_batch_size=False) - # @set_skip_existing(None) + @_set_skip_existing_None() def forward( self, tensordict: TensorDictBase, diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 105be6538..3525ab1a5 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -269,7 +269,7 @@ def __enter__(self) -> None: _SKIP_EXISTING = self.mode elif not self._called: raise RuntimeError( - f"It seems you are using {self.__class__.__name__} as a decorator with ``None`` input. " + f"It seems you are using {self.__class__.__name__} as a context manager with ``None`` input. " f"This behaviour is not allowed." ) @@ -277,6 +277,55 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: global _SKIP_EXISTING _SKIP_EXISTING = self.prev +class _set_skip_existing_None(set_skip_existing): + """A version of skip_existing that is constant wrt init inputs (for torch.compile compatibility). + + This class should only be used as a decorator, not a context manager. + """ + def __call__(self, func: Callable): + + self._called = True + + # sanity check + for i, key in enumerate(inspect.signature(func).parameters): + if i == 0: + # skip self + continue + if key != "tensordict": + raise RuntimeError( + "the first argument of the wrapped function must be " + "named 'tensordict'." 
+ ) + break + + @functools.wraps(func) + def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any: + in_keys = getattr(_self, self.in_key_attr) + out_keys = getattr(_self, self.out_key_attr) + # we use skip_existing to allow users to override the mode internally + if ( + skip_existing() + and all(key in tensordict.keys(True) for key in out_keys) + and not any(key in out_keys for key in in_keys) + ): + return tensordict + global _SKIP_EXISTING + self.prev = _SKIP_EXISTING + try: + result = func(_self, tensordict, *args, **kwargs) + finally: + _SKIP_EXISTING = self.prev + return result + + return wrapper + + in_key_attr="in_keys" + out_key_attr="out_keys" + __init__ = object.__init__ + def clone(self) -> _set_skip_existing_None: + # override this method if your children class takes __init__ parameters + out = self.__class__() + return out def skip_existing(): """Returns whether or not existing entries in a tensordict should be re-computed by a module.""" From 8f1eb8e2c2c278fbfb4f80b332e6575a713bda73 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 24 May 2024 10:19:58 -0400 Subject: [PATCH 06/28] amend --- tensordict/__init__.py | 2 + tensordict/_torch_func.py | 19 ++-- tensordict/base.py | 46 +++++++++ tensordict/nn/common.py | 2 +- tensordict/nn/sequence.py | 2 +- tensordict/nn/utils.py | 8 +- tensordict/utils.py | 44 +++++++- test/test_compile.py | 211 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 320 insertions(+), 14 deletions(-) create mode 100644 test/test_compile.py diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 177b595d1..bc91825d8 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -18,6 +18,7 @@ from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass from tensordict.utils import ( assert_allclose_td, + assert_close, is_batchedtensor, is_tensorclass, lazy_legacy, @@ -42,6 +43,7 @@ "TensorDict", "TensorDictBase", "assert_allclose_td", + "assert_close", "dense_stack_tds", "is_batchedtensor", "is_tensor_collection", diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index e205cd102..42bdfa65f 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -5,6 +5,7 @@ from __future__ import annotations +import contextlib import functools from typing import Any, Callable, Sequence, TypeVar @@ -264,10 +265,16 @@ def _cat( if out is None: out = {} for key in keys: - with _ErrorInteceptor( - key, "Attempted to concatenate tensors on different devices at key" - ): - out[key] = torch.cat( + if not torch._dynamo.is_compiling(): + with _ErrorInteceptor( + key, "Attempted to concatenate tensors on different devices at key" + ): + out[key] = torch.cat( + [td._get_str(key, NO_DEFAULT) for td in list_of_tensordicts], + dim, + ) + else: + out[key] = TensorDict.cat( [td._get_str(key, NO_DEFAULT) for td in list_of_tensordicts], dim ) if device is None: @@ -295,7 +302,7 @@ def _cat( for key in keys: with _ErrorInteceptor( key, "Attempted to concatenate tensors on different devices at key" - ): + ) if not torch._dynamo.is_compiling() else contextlib.nullcontext(): if isinstance(out, TensorDict): torch.cat( [td.get(key) for td in list_of_tensordicts], @@ -491,7 +498,7 @@ def stack_fn(key, values, is_not_init, is_tensor): return torch.stack(values, dim) with _ErrorInteceptor( key, "Attempted to stack tensors on different devices at key" - ): + ) if not torch._dynamo.is_compiling() else contextlib.nullcontext(): return _stack(values, dim, maybe_dense_stack=maybe_dense_stack) out = { diff --git 
a/tensordict/base.py b/tensordict/base.py index 035fb22a7..edf571909 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -1402,6 +1402,52 @@ def reshape( """ ... + @classmethod + def stack(cls, input, dim=0, *, out=None): + """Stacks tensordicts into a single tensordict along the given dimension. + + This call is equivalent to calling :func:`torch.stack` but is compatible with torch.compile. + + """ + from tensordict._torch_func import _stack + + if not _is_tensor_collection(type(input[0])): + return torch.stack(input, dim, out=out) + return _stack(input, dim, out=out) + + @classmethod + def cat(cls, input, dim=0, *, out=None): + """Concatenates tensordicts into a single tensordict along the given dimension. + + This call is equivalent to calling :func:`torch.cat` but is compatible with torch.compile. + + """ + from tensordict._torch_func import _cat + + if not _is_tensor_collection(type(input[0])): + return torch.cat(input, dim, out=out) + return _cat(input, dim, out=out) + + @classmethod + def lazy_stack(cls, input, dim=0, *, out=None): + """Creates a lazy stack of tensordicts. + + See :meth:`~tensordict.LazyStackTensorDict.lazy_stack` for details. + """ + from tensordict._lazy import LazyStackedTensorDict + + return LazyStackedTensorDict.lazy_stack(input, dim=dim, out=out) + + @classmethod + def maybe_dense_stack(cls, input, dim=0, *, out=None): + """Attempts to make a dense stack of tensordicts, and falls back on lazy stack when required.. + + See :meth:`~tensordict.LazyStackTensorDict.maybe_dense_stack` for details. + """ + from tensordict._lazy import LazyStackedTensorDict + + return LazyStackedTensorDict.maybe_dense_stack(input, dim=dim, out=out) + @abc.abstractmethod def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]: """Splits each tensor in the TensorDict with the specified size in the given dimension, like `torch.split`. diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 61ef5fed6..3d09eacd1 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -26,7 +26,7 @@ ) from tensordict.nn.utils import ( _auto_make_functional, - _dispatch_td_nn_modules, set_skip_existing, + _dispatch_td_nn_modules, _set_skip_existing_None, ) from tensordict.utils import implement_for, NestedKey diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index 8c8ec3323..049e76744 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -17,10 +17,10 @@ TensorDictModuleBase, WrapModule, ) +from tensordict.nn.utils import _set_skip_existing_None from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from tensordict.utils import NestedKey -from tensordict.nn.utils import _set_skip_existing_None from torch import nn _has_functorch = False diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 3525ab1a5..a2fdb356b 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -277,11 +277,13 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: global _SKIP_EXISTING _SKIP_EXISTING = self.prev + class _set_skip_existing_None(set_skip_existing): """A version of skip_existing that is constant wrt init inputs (for torch.compile compatibility). This class should only be used as a decorator, not a context manager. 
""" + def __call__(self, func: Callable): self._called = True @@ -319,14 +321,16 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any: return wrapper - in_key_attr="in_keys" - out_key_attr="out_keys" + in_key_attr = "in_keys" + out_key_attr = "out_keys" __init__ = object.__init__ + def clone(self) -> _set_skip_existing_None: # override this method if your children class takes __init__ parameters out = self.__class__() return out + def skip_existing(): """Returns whether or not existing entries in a tensordict should be re-computed by a module.""" return _SKIP_EXISTING diff --git a/tensordict/utils.py b/tensordict/utils.py index 810917549..90ebb1683 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1382,7 +1382,7 @@ def _get_leaf_tensordict( return tensordict, key[0] -def assert_allclose_td( +def assert_close( actual: T, expected: T, rtol: float | None = None, @@ -1787,8 +1787,14 @@ def _getitem_batch_size(batch_size, index): boolean = False if isinstance(idx, (range, list)): shape = len(idx) - elif isinstance(idx, (torch.Tensor, np.ndarray)): - if idx.dtype == torch.bool or idx.dtype == np.dtype("bool"): + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + shape = torch.Size([idx.sum()]) + boolean = True + else: + shape = idx.shape + elif isinstance(idx, np.ndarray): + if idx.dtype == np.dtype("bool"): shape = torch.Size([idx.sum()]) boolean = True else: @@ -1828,7 +1834,8 @@ def _getitem_batch_size(batch_size, index): continue elif isinstance(idx, slice): batch = batch_size[count] - out.append(len(range(*idx.indices(batch)))) + out.append(len(range(*_slice_indices(idx, batch)))) + # out.append(len(range(*idx.indices(batch)))) count += 1 if batch_size[count:]: out.extend(batch_size[count:]) @@ -2269,3 +2276,32 @@ def _unravel_key_to_tuple(key): if not isinstance(key, tuple): return () return tuple(subk for k in key for subk in _unravel_key_to_tuple(k)) + + +def _slice_indices(index: slice, len: int): + """A pure python implementation of slice.indices(len) since torch.compile doesn't recognise it.""" + start = index.start + if start is None: + start = 0 + stop = index.stop + if stop is None: + stop = len + step = index.step + if step is None: + step = 1 + elif step == 0: + raise ValueError("Step cannot be zero.") + + if step > 0: + # Forward direction + lower = start + upper = stop - step + 1 + else: + # Backward direction + lower = start + step - 1 + upper = stop + upper = min(upper, len) + return lower, upper, step + + +assert_allclose_td = assert_close diff --git a/test/test_compile.py b/test/test_compile.py new file mode 100644 index 000000000..12b1e14f6 --- /dev/null +++ b/test/test_compile.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import argparse + +import pytest + +import torch +from tensordict import assert_close, nn, TensorDict +from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq + + +class TestTD: + def test_tensor_output(self): + def add_one(td): + return td["a", "b"] + 1 + + add_one_c = torch.compile(add_one, fullgraph=True) + data = TensorDict({"a": {"b": 0}}) + assert add_one(data) == 1 + assert add_one_c(data) == 1 + assert add_one_c(data + 1) == 2 + + def test_td_output(self): + def add_one(td): + td["a", "c"] = td["a", "b"] + 1 + return td + + add_one_c = torch.compile(add_one, fullgraph=True) + data = TensorDict({"a": {"b": 0}}) + assert add_one(data.clone())["a", "c"] == 1 + assert add_one_c(data.clone())["a", "c"] == 1 + assert add_one_c(data) is data + + @pytest.mark.parametrize("index_type", ["slice", "tensor", "int"]) + def test_td_index(self, index_type): + if index_type == "slice": + + def add_one(td): + return td[:2] + 1 + + elif index_type == "tensor": + + def add_one(td): + return td[torch.tensor([0, 1])] + 1 + + elif index_type == "int": + + def add_one(td): + return td[0] + 1 + + add_one_c = torch.compile(add_one, fullgraph=True) + data = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + if index_type == "int": + assert (add_one(data)["a", "b"] == 1).all() + assert (add_one_c(data)["a", "b"] == 1).all() + assert add_one_c(data).shape == torch.Size([]) + else: + assert (add_one(data)["a", "b"] == torch.arange(1, 3)).all() + assert (add_one_c(data)["a", "b"] == torch.arange(1, 3)).all() + assert add_one_c(data).shape == torch.Size([2]) + + def test_stack(self): + def stack_tds(td0, td1): + return TensorDict.stack([td0, td1]) + # return torch.stack([td0, td1]) + + stack_tds_c = torch.compile(stack_tds, fullgraph=True) + data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + data1 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + assert (stack_tds(data0, data1) == stack_tds_c(data0, data1)).all() + + def test_cat(self): + def cat_tds(td0, td1): + return TensorDict.cat([td0, td1]) + + cat_tds_c = torch.compile(cat_tds, fullgraph=True) + data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + data1 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + assert (cat_tds(data0, data1) == cat_tds_c(data0, data1)).all() + + def test_reshape(self): + def reshape(td): + return td.reshape(2, 2) + + reshape_c = torch.compile(reshape, fullgraph=True) + data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) + assert (reshape(data) == reshape_c(data)).all() + + def test_unbind(self): + def unbind(td): + return td.unbind(0) + + unbind_c = torch.compile(unbind, fullgraph=True) + data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) + assert (unbind(data)[-1] == unbind_c(data)[-1]).all() + + @pytest.mark.parametrize("recurse", [True, False]) + def test_clone(self, recurse): + def clone(td: TensorDict): + return td.clone(recurse=recurse) + + clone_c = torch.compile(clone, fullgraph=True) + data = TensorDict({"a": {"b": 0, "c": 1}}) + assert_close(clone_c(data), clone(data)) + assert clone_c(data) is not data + if recurse: + assert clone_c(data)["a", "b"] is not data["a", "b"] + else: + assert clone_c(data)["a", "b"] is data["a", "b"] + + +class TestTC: + pass + + +class TestNN: + def test_func(self): + td = TensorDict({"a": 0}) + module = Mod( + lambda x: x + 1, in_keys=[(((("a",),),),)], out_keys=[(((("a",),),),)] + ) + module_compile = torch.compile(module, fullgraph=True) + + assert_close(module(td), module_compile(td)) + + def test_linear(self): + net = torch.nn.Linear(4, 5) + module 
= Mod(net, in_keys=[(((("a",),),),)], out_keys=[("c", "d")]) + module_compile = torch.compile(module, fullgraph=True) + td = TensorDict({"a": torch.randn(32, 4)}, [32]) + assert_close(module(td), module_compile(td)) + + def test_seq(self): + net0 = torch.nn.Linear(4, 5) + module0 = Mod(net0, in_keys=["a"], out_keys=["hidden"]) + net1 = torch.nn.Linear(5, 6) + module1 = Mod(net1, in_keys=["hidden"], out_keys=[("c", "d")]) + module = Seq(module0, module1) + module_compile = torch.compile(module, fullgraph=True) + td = TensorDict({"a": torch.randn(32, 4)}, [32]) + assert_close(module(td), module_compile(td)) + + assert module_compile(td) is td + + def test_seq_lmbda(self): + net0 = torch.nn.Linear(4, 5) + module0 = Mod(net0, in_keys=["a"], out_keys=["hidden"]) + net1 = torch.nn.Linear(5, 6) + module1 = Mod(net1, in_keys=["hidden"], out_keys=[("c", "d")]) + + def remove_hidden(td): + del td["hidden"] + return td + + module = Seq(lambda td: td.copy(), module0, module1, remove_hidden) + module_compile = torch.compile(module, fullgraph=True) + td = TensorDict({"a": torch.randn(32, 4)}, [32]) + assert_close(module(td), module_compile(td)) + assert module_compile(td) is not td + + +class TestFunc: + @pytest.mark.parametrize("modif_param", [False, True]) + def test_func(self, modif_param): + class MessUpParams(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.zeros(())) + + def forward(self, x): + self.param.data.add_(1) + return x + + module = torch.nn.Sequential( + torch.nn.Linear(3, 4), + torch.nn.ReLU(), + torch.nn.Linear(4, 5), + ) + if modif_param: + module.append(MessUpParams()) + + td = TensorDict.from_module(module) + td_zero = td.clone() + td_zero.zero_() + + def call(x, td): + # with needs registering + params = td.to_module(module, return_swap=True) + result = module(x) + params.to_module(module, return_swap=True, swap_dest=td) + return result + + call_compile = torch.compile(call, fullgraph=True) + x = torch.randn(2, 3) + assert (call(x, td_zero) == 0).all() + if modif_param: + assert td_zero["3", "param"] == 1 + else: + assert (td_zero == 0).all() + assert (call_compile(x, td_zero) == 0).all() + if modif_param: + assert td_zero["3", "param"] == 2 + else: + assert (td_zero == 0).all() + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From f0c3e8d500b382f28f4d6af5c2e22aca0e3ee931 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 24 May 2024 11:57:37 -0400 Subject: [PATCH 07/28] amend --- tensordict/_td.py | 28 +++++++++++++++++----------- test/test_compile.py | 12 +++++++++--- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index b6960767f..13cfc4ba1 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -8,6 +8,7 @@ import json import numbers import os +import weakref from collections import defaultdict from copy import copy from numbers import Number @@ -383,7 +384,8 @@ def _to_module( hooks = memo["hooks"] if return_swap: _swap = {} - memo[id(module)] = _swap + if not torch._dynamo.is_dynamo_compiling(): + memo[weakref.ref(module)] = _swap if use_state_dict: if inplace is not None: @@ -425,7 +427,10 @@ def convert_type(x, y): for key, value in input.items(): if isinstance(value, (Tensor, ftdim.Tensor)): - if module.__class__.__setattr__ is __base__setattr__: + if ( + module.__class__.__setattr__ is __base__setattr__ + and not 
torch._dynamo.is_dynamo_compiling() + ): # if setattr is the native nn.Module.setattr, we can rely on _set_tensor_dict local_out = _set_tensor_dict( __dict__, hooks, module, key, value, inplace @@ -448,9 +453,10 @@ def convert_type(x, y): # Otherwise, we just go to the next key continue child = __dict__["_modules"][key] - local_out = memo.get(id(child), NO_DEFAULT) + if not torch._dynamo.is_dynamo_compiling(): + local_out = memo.get(weakref.ref(child), NO_DEFAULT) - if local_out is NO_DEFAULT: + if torch._dynamo.is_dynamo_compiling() or local_out is NO_DEFAULT: # if isinstance(child, TensorDictBase): # # then child is a TensorDictParams # from tensordict.nn import TensorDictParams @@ -3929,7 +3935,7 @@ def __repr__(self): def _set_tensor_dict( # noqa: F811 - module_dict, + __dict__, hooks, module: torch.nn.Module, name: str, @@ -3938,12 +3944,12 @@ def _set_tensor_dict( # noqa: F811 ) -> None: """Simplified version of torch.nn.utils._named_member_accessor.""" was_buffer = False - out = module_dict["_parameters"].pop(name, None) # type: ignore[assignment] + out = __dict__["_parameters"].pop(name, None) # type: ignore[assignment] if out is None: - out = module_dict["_buffers"].pop(name, None) + out = __dict__["_buffers"].pop(name, None) was_buffer = out is not None if out is None: - out = module_dict.pop(name) + out = __dict__.pop(name) if inplace: # swap tensor and out after updating out out_tmp = out.clone() @@ -3956,7 +3962,7 @@ def _set_tensor_dict( # noqa: F811 output = hook(module, name, tensor) if output is not None: tensor = output - module_dict["_parameters"][name] = tensor + __dict__["_parameters"][name] = tensor if isinstance( tensor, (_BatchedUninitializedParameter, _BatchedUninitializedBuffer) @@ -3966,9 +3972,9 @@ def _set_tensor_dict( # noqa: F811 ) elif was_buffer and isinstance(tensor, torch.Tensor): - module_dict["_buffers"][name] = tensor + __dict__["_buffers"][name] = tensor else: - module_dict[name] = tensor + __dict__[name] = tensor return out diff --git a/test/test_compile.py b/test/test_compile.py index 12b1e14f6..a0a51be2e 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -7,7 +7,7 @@ import pytest import torch -from tensordict import assert_close, nn, TensorDict +from tensordict import assert_close, TensorDict from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq @@ -186,19 +186,25 @@ def forward(self, x): td_zero.zero_() def call(x, td): - # with needs registering + # TOFIX: `with` needs registering + # with td.to_module(module): + # return module(x) + params = td.to_module(module, return_swap=True) result = module(x) params.to_module(module, return_swap=True, swap_dest=td) return result - call_compile = torch.compile(call, fullgraph=True) + call_compile = torch.compile(call, fullgraph=True) # , backend="eager") x = torch.randn(2, 3) assert (call(x, td_zero) == 0).all() + assert (call(x, td_zero) == 0).all() if modif_param: assert td_zero["3", "param"] == 1 else: assert (td_zero == 0).all() + # torch.testing.assert_close(call_compile(x, td_zero), module(x)) + assert (call_compile(x, td_zero) == 0).all() assert (call_compile(x, td_zero) == 0).all() if modif_param: assert td_zero["3", "param"] == 2 From 5f55d4872e65d16c787fad3c80daef3ee7dac479 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 24 May 2024 13:41:44 -0400 Subject: [PATCH 08/28] amend --- tensordict/_td.py | 28 +++++++++++++++++----------- test/test_compile.py | 4 ++-- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/tensordict/_td.py 
b/tensordict/_td.py index 13cfc4ba1..99792b9b0 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -378,13 +378,10 @@ def _to_module( module.update(self) return - # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can - __dict__ = module.__dict__ - hooks = memo["hooks"] if return_swap: _swap = {} - if not torch._dynamo.is_dynamo_compiling(): + if not torch.compiler.is_dynamo_compiling(): memo[weakref.ref(module)] = _swap if use_state_dict: @@ -425,12 +422,18 @@ def convert_type(x, y): input = self inplace = bool(inplace) + # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can + if ( + module.__class__.__setattr__ is __base__setattr__ + and not torch.compiler.is_dynamo_compiling() + ): + __dict__ = module.__dict__ + else: + __dict__ = None + for key, value in input.items(): if isinstance(value, (Tensor, ftdim.Tensor)): - if ( - module.__class__.__setattr__ is __base__setattr__ - and not torch._dynamo.is_dynamo_compiling() - ): + if __dict__ is not None: # if setattr is the native nn.Module.setattr, we can rely on _set_tensor_dict local_out = _set_tensor_dict( __dict__, hooks, module, key, value, inplace @@ -452,11 +455,14 @@ def convert_type(x, y): # if there is at least one key, we must populate the module. # Otherwise, we just go to the next key continue - child = __dict__["_modules"][key] - if not torch._dynamo.is_dynamo_compiling(): + if __dict__ is not None: + child = __dict__["_modules"][key] + else: + child = getattr(module, key) + if not torch.compiler.is_dynamo_compiling(): local_out = memo.get(weakref.ref(child), NO_DEFAULT) - if torch._dynamo.is_dynamo_compiling() or local_out is NO_DEFAULT: + if torch.compiler.is_dynamo_compiling() or local_out is NO_DEFAULT: # if isinstance(child, TensorDictBase): # # then child is a TensorDictParams # from tensordict.nn import TensorDictParams diff --git a/test/test_compile.py b/test/test_compile.py index a0a51be2e..e0313a876 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -200,14 +200,14 @@ def call(x, td): assert (call(x, td_zero) == 0).all() assert (call(x, td_zero) == 0).all() if modif_param: - assert td_zero["3", "param"] == 1 + assert td_zero["3", "param"] == 2 else: assert (td_zero == 0).all() # torch.testing.assert_close(call_compile(x, td_zero), module(x)) assert (call_compile(x, td_zero) == 0).all() assert (call_compile(x, td_zero) == 0).all() if modif_param: - assert td_zero["3", "param"] == 2 + assert td_zero["3", "param"] == 4 else: assert (td_zero == 0).all() From efdd277c479c789dae8a70a2d03ecb1b74b5b05e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 24 May 2024 16:45:41 -0400 Subject: [PATCH 09/28] amend --- tensordict/tensorclass.py | 83 +++++++++++------------------ test/test_compile.py | 108 +++++++++++++++++++++++++++++++++++++- 2 files changed, 137 insertions(+), 54 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index a17b9d69f..0d1404981 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -210,6 +210,9 @@ def __torch_function__( ) cls.fields = classmethod(lambda cls: dataclasses.fields(cls)) + for field in cls.fields(): + if hasattr(cls, field.name): + delattr(cls, field.name) _get_type_hints(cls) cls.__init__ = _init_wrapper(cls.__init__) @@ -219,9 +222,10 @@ def __torch_function__( cls.__torch_function__ = classmethod(__torch_function__) cls.__getstate__ = _getstate cls.__setstate__ = _setstate - cls.__getattribute__ = _getattribute_wrapper(cls.__getattribute__) - cls.__setattr__ = 
_setattr_wrapper(cls.__setattr__, expected_keys) + # cls.__getattribute__ = object.__getattribute__ cls.__getattr__ = _getattr + cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys) + # cls.__getattr__ = _getattr cls.__getitem__ = _getitem cls.__getitems__ = _getitem cls.__setitem__ = _setitem @@ -669,41 +673,32 @@ def _setstate(self, state: dict[str, Any]) -> None: # noqa: D417 self._non_tensordict = state.get("non_tensordict", None) -def _getattribute_wrapper(getattribute: Callable) -> Callable: - """Retrieve the value of an object's attribute or raise AttributeError. - - Args: - item (str) : name of the attribute to retrieve - - Returns: - value of the attribute - - """ - - @functools.wraps(getattribute) - def wrapper(self, item: str) -> Any: - if not item.startswith("__"): - if ( - "_tensordict" in self.__dict__ - and item in self.__dict__["_tensordict"].keys() - ): - out = self._tensordict.get(item) - return out - elif ( - "_non_tensordict" in self.__dict__ - and item in self.__dict__["_non_tensordict"] - ): - out = self._non_tensordict[item] - if ( - isinstance(self, NonTensorData) - and item == "data" - and (self._is_shared or self._is_memmap) - ): - return _from_shared_nontensor(out) +def _getattr(self, item: str) -> Any: + # if not item.startswith("__"): + __dict__ = self.__dict__ + _tensordict = __dict__.get("_tensordict") + if _tensordict is not None: + out = _tensordict._get_str(item, default=None) + if out is not None: + return out + out = getattr(_tensordict, item, NO_DEFAULT) + if out is not NO_DEFAULT: + if not callable(out): return out - return getattribute(self, item) - - return wrapper + return _wrap_method(self, item, out) + _non_tensordict = __dict__.get("_non_tensordict") + if _non_tensordict is not None: + out = _non_tensordict.get(item, NO_DEFAULT) + if out is NO_DEFAULT: + raise AttributeError + if ( + isinstance(self, NonTensorData) + and item == "data" + and (self._is_shared or self._is_memmap) + ): + return _from_shared_nontensor(out) + return out + raise AttributeError SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names") @@ -872,22 +867,6 @@ def wrapped_func(*args, **kwargs): return wrapped_func -def _getattr(self, attr: str) -> Any: - """Retrieve the value of an object's attribute, or a method output if attr is callable. - - Args: - attr: name of the attribute to retrieve or function to compute - - Returns: - value of the attribute, or a method output applied on the instance - - """ - res = getattr(self._tensordict, attr) - if not callable(res): - return res - func = res - return _wrap_method(self, attr, func) - def _getitem(self, item: NestedKey) -> Any: """Retrieve the class object at the given index. Indexing will happen for nested tensors as well. 
diff --git a/test/test_compile.py b/test/test_compile.py index e0313a876..af9277b7a 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -7,8 +7,9 @@ import pytest import torch -from tensordict import assert_close, TensorDict +from tensordict import assert_close, TensorDict, tensorclass from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq +from typing import Any class TestTD: @@ -110,9 +111,112 @@ def clone(td: TensorDict): else: assert clone_c(data)["a", "b"] is data["a", "b"] +@tensorclass +class MyClass: + a: "MyClass" + b: Any=None + c: Any=None class TestTC: - pass + + def test_tensor_output(self): + def add_one(td): + return td.a.b + 1 + + add_one_c = torch.compile(add_one, fullgraph=True) + data = MyClass(MyClass(a=None, b=0)) + assert add_one(data) == 1 + assert add_one_c(data) == 1 + assert add_one_c(data + 1) == 2 + + def test_td_output(self): + def add_one(td): + td.a.c = td.a.b + 1 + return td + + add_one_c = torch.compile(add_one, fullgraph=True) + data = MyClass.from_tensordict(TensorDict({"a": {"b": 0}})) + assert add_one(data.clone()).a.c == 1 + assert add_one_c(data.clone()).a.c == 1 + assert add_one_c(data) is data + + @pytest.mark.parametrize("index_type", ["slice", "tensor", "int"]) + def test_td_index(self, index_type): + if index_type == "slice": + + def add_one(td): + return td[:2] + 1 + + elif index_type == "tensor": + + def add_one(td): + return td[torch.tensor([0, 1])] + 1 + + elif index_type == "int": + + def add_one(td): + return td[0] + 1 + + add_one_c = torch.compile(add_one, fullgraph=True) + data = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + if index_type == "int": + assert (add_one(data)["a", "b"] == 1).all() + assert (add_one_c(data)["a", "b"] == 1).all() + assert add_one_c(data).shape == torch.Size([]) + else: + assert (add_one(data)["a", "b"] == torch.arange(1, 3)).all() + assert (add_one_c(data)["a", "b"] == torch.arange(1, 3)).all() + assert add_one_c(data).shape == torch.Size([2]) + + def test_stack(self): + def stack_tds(td0, td1): + return TensorDict.stack([td0, td1]) + # return torch.stack([td0, td1]) + + stack_tds_c = torch.compile(stack_tds, fullgraph=True) + data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + data1 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + assert (stack_tds(data0, data1) == stack_tds_c(data0, data1)).all() + + def test_cat(self): + def cat_tds(td0, td1): + return TensorDict.cat([td0, td1]) + + cat_tds_c = torch.compile(cat_tds, fullgraph=True) + data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + data1 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + assert (cat_tds(data0, data1) == cat_tds_c(data0, data1)).all() + + def test_reshape(self): + def reshape(td): + return td.reshape(2, 2) + + reshape_c = torch.compile(reshape, fullgraph=True) + data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) + assert (reshape(data) == reshape_c(data)).all() + + def test_unbind(self): + def unbind(td): + return td.unbind(0) + + unbind_c = torch.compile(unbind, fullgraph=True) + data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) + assert (unbind(data)[-1] == unbind_c(data)[-1]).all() + + @pytest.mark.parametrize("recurse", [True, False]) + def test_clone(self, recurse): + def clone(td: TensorDict): + return td.clone(recurse=recurse) + + clone_c = torch.compile(clone, fullgraph=True) + data = TensorDict({"a": {"b": 0, "c": 1}}) + assert_close(clone_c(data), clone(data)) + assert clone_c(data) is not data + if recurse: + assert clone_c(data)["a", "b"] is not data["a", "b"] + else: + 
assert clone_c(data)["a", "b"] is data["a", "b"] + class TestNN: From efb09ed6488c04d9cae0420b2c67253f3d5957fe Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 24 May 2024 17:43:45 -0400 Subject: [PATCH 10/28] init --- tensordict/tensorclass.py | 100 +++++++++++++++++--------------------- test/test_tensorclass.py | 3 +- 2 files changed, 47 insertions(+), 56 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index a17b9d69f..d364d1e73 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -50,7 +50,7 @@ is_non_tensor, is_tensorclass, KeyDependentDefaultDict, - NestedKey, + NestedKey, _is_non_tensor, ) from torch import multiprocessing as mp, Tensor from torch.multiprocessing import Manager @@ -158,9 +158,17 @@ def __new__(cls, autocast: bool = False): def __init__(self, autocast: bool): self.autocast = autocast - def __call__(self, clz): - clz = _tensorclass(clz) + def __call__(self, cls): + clz = _tensorclass(cls) clz.autocast = self.autocast + if self.autocast and clz._type_hints: + for field, th in list(clz._type_hints.items()): + print('field', field) + print('th', th) + print('cls', cls) + print('clz', clz) + if th == cls: + clz._type_hints[field] = clz return clz @@ -210,6 +218,9 @@ def __torch_function__( ) cls.fields = classmethod(lambda cls: dataclasses.fields(cls)) + for field in cls.fields(): + if hasattr(cls, field.name): + delattr(cls, field.name) _get_type_hints(cls) cls.__init__ = _init_wrapper(cls.__init__) @@ -219,9 +230,10 @@ def __torch_function__( cls.__torch_function__ = classmethod(__torch_function__) cls.__getstate__ = _getstate cls.__setstate__ = _setstate - cls.__getattribute__ = _getattribute_wrapper(cls.__getattribute__) - cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys) + # cls.__getattribute__ = object.__getattribute__ cls.__getattr__ = _getattr + cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys) + # cls.__getattr__ = _getattr cls.__getitem__ = _getitem cls.__getitems__ = _getitem cls.__setitem__ = _setitem @@ -669,41 +681,33 @@ def _setstate(self, state: dict[str, Any]) -> None: # noqa: D417 self._non_tensordict = state.get("non_tensordict", None) -def _getattribute_wrapper(getattribute: Callable) -> Callable: - """Retrieve the value of an object's attribute or raise AttributeError. 
- - Args: - item (str) : name of the attribute to retrieve - - Returns: - value of the attribute - - """ - - @functools.wraps(getattribute) - def wrapper(self, item: str) -> Any: - if not item.startswith("__"): +def _getattr(self, item: str) -> Any: + # if not item.startswith("__"): + __dict__ = self.__dict__ + _non_tensordict = __dict__.get("_non_tensordict") + if _non_tensordict is not None: + out = _non_tensordict.get(item, NO_DEFAULT) + if out is not NO_DEFAULT: if ( - "_tensordict" in self.__dict__ - and item in self.__dict__["_tensordict"].keys() - ): - out = self._tensordict.get(item) - return out - elif ( - "_non_tensordict" in self.__dict__ - and item in self.__dict__["_non_tensordict"] + isinstance(self, NonTensorData) + and item == "data" + and (self._is_shared or self._is_memmap) ): - out = self._non_tensordict[item] - if ( - isinstance(self, NonTensorData) - and item == "data" - and (self._is_shared or self._is_memmap) - ): - return _from_shared_nontensor(out) + return _from_shared_nontensor(out) + return out + _tensordict = __dict__.get("_tensordict") + if _tensordict is not None: + out = _tensordict._get_str(item, default=None) + if out is not None: + if _is_non_tensor(type(out)): + return out.data + return out + out = getattr(_tensordict, item, NO_DEFAULT) + if out is not NO_DEFAULT: + if not callable(out): return out - return getattribute(self, item) - - return wrapper + return _wrap_method(self, item, out) + raise AttributeError SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names") @@ -773,8 +777,8 @@ def _update( input_dict_or_td = self.from_dict(input_dict_or_td) if is_tensorclass(input_dict_or_td): - self._tensordict.update(input_dict_or_td._tensordict) - self._non_tensordict.update(input_dict_or_td._non_tensordict) + self._tensordict.update(input_dict_or_td.__dict__["_tensordict"]) + self._non_tensordict.update(input_dict_or_td.__dict__["_non_tensordict"]) return self non_tensordict = {} @@ -872,22 +876,6 @@ def wrapped_func(*args, **kwargs): return wrapped_func -def _getattr(self, attr: str) -> Any: - """Retrieve the value of an object's attribute, or a method output if attr is callable. - - Args: - attr: name of the attribute to retrieve or function to compute - - Returns: - value of the attribute, or a method output applied on the instance - - """ - res = getattr(self._tensordict, attr) - if not callable(res): - return res - func = res - return _wrap_method(self, attr, func) - def _getitem(self, item: NestedKey) -> Any: """Retrieve the class object at the given index. Indexing will happen for nested tensors as well. @@ -1018,6 +1006,8 @@ def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None): ) non_tensor = {} + # Note: this won't deal with sub-tensordicts which may or may not be tensorclasses. + # We don't want to enforce them to be tensorclasses so we can't do much about it... 
for key, value in list(td.items()): if is_non_tensor(value): non_tensor[key] = value.data diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index c0c9818a7..5143271e5 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -1369,7 +1369,7 @@ def get_data(cls, shift): data0 = MyDataNested.get_data(0) data0.update(data1.to_dict()) assert (data0.X == 1).all() - assert data0.z == "test_tensorclass1" + assert data0.z == "test_tensorclass1", data0.z assert (data0.y.X == 1).all() assert data0.y.z == "test_tensorclass1" @@ -2011,6 +2011,7 @@ def test_autocast(self): }, ) + assert isinstance(obj, AutoCast), type(obj) assert isinstance(obj.tensor, torch.Tensor) assert isinstance(obj.non_tensor, str) assert isinstance(obj.td, TensorDict) From f1e9744123e6977151bcfbbe86200d2165356fe1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 24 May 2024 17:45:02 -0400 Subject: [PATCH 11/28] amend --- tensordict/tensorclass.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index d364d1e73..9cf0f80db 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -44,13 +44,14 @@ from tensordict.utils import ( _get_repr, _is_json_serializable, + _is_non_tensor, _LOCK_ERROR, DeviceType, IndexType, is_non_tensor, is_tensorclass, KeyDependentDefaultDict, - NestedKey, _is_non_tensor, + NestedKey, ) from torch import multiprocessing as mp, Tensor from torch.multiprocessing import Manager @@ -161,14 +162,6 @@ def __init__(self, autocast: bool): def __call__(self, cls): clz = _tensorclass(cls) clz.autocast = self.autocast - if self.autocast and clz._type_hints: - for field, th in list(clz._type_hints.items()): - print('field', field) - print('th', th) - print('cls', cls) - print('clz', clz) - if th == cls: - clz._type_hints[field] = clz return clz @@ -876,7 +869,6 @@ def wrapped_func(*args, **kwargs): return wrapped_func - def _getitem(self, item: NestedKey) -> Any: """Retrieve the class object at the given index. Indexing will happen for nested tensors as well. From 187c2d8d04d515f88fe7f4d58c62c6ca031cd9f2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 24 May 2024 18:34:38 -0400 Subject: [PATCH 12/28] amend --- tensordict/_torch_func.py | 8 ++++++-- tensordict/tensorclass.py | 7 +++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index e205cd102..02b79b5f6 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -91,7 +91,11 @@ def _gather( f"Cannot gather tensordict with shape {input.shape} along dim {dim_orig}." 
) - def _gather_tensor(tensor, dest=None): + def _gather_tensor(tensor, dest_container=None, dest_key=None): + if dest_container is not None: + dest = dest_container.get(dest_key) + else: + dest = None index_expand = index while index_expand.ndim < tensor.ndim: index_expand = index_expand.unsqueeze(-1) @@ -116,7 +120,7 @@ def _gather_tensor(tensor, dest=None): ) TensorDict( { - key: _gather_tensor(value, out.get(key)) + key: _gather_tensor(value, out, key) for key, value in input.items(is_leaf=_is_leaf_nontensor) }, batch_size=index.shape, diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 9cf0f80db..586508abc 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -44,7 +44,6 @@ from tensordict.utils import ( _get_repr, _is_json_serializable, - _is_non_tensor, _LOCK_ERROR, DeviceType, IndexType, @@ -692,8 +691,6 @@ def _getattr(self, item: str) -> Any: if _tensordict is not None: out = _tensordict._get_str(item, default=None) if out is not None: - if _is_non_tensor(type(out)): - return out.data return out out = getattr(_tensordict, item, NO_DEFAULT) if out is not NO_DEFAULT: @@ -1865,7 +1862,9 @@ def __post_init__(self): data = getattr(self.data, "data", None) if data is None: data = self.data.tolist() - self.data = data + del self._tensordict["data"] + self._non_tensordict["data"] = data + assert self._tensordict.is_empty(), self._tensordict def __repr__(self): data_str = str(self.data) From e4cf49fba3565391c5577c2cc04b32449111adb3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 25 May 2024 09:02:39 -0400 Subject: [PATCH 13/28] amend --- tensordict/tensorclass.py | 124 +++++++++++++++++++++++++------------- test/test_compile.py | 52 ++++++++++------ 2 files changed, 114 insertions(+), 62 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 586508abc..6bd6174c2 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -172,6 +172,9 @@ def __torch_function__( args: tuple[Any, ...] 
= (), kwargs: dict[str, Any] | None = None, ) -> Callable: + # if not torch.compiler.is_dynamo_compiling() and (func not in _TD_PASS_THROUGH or not all( + # issubclass(t, (Tensor, cls)) for t in types + # )): if func not in _TD_PASS_THROUGH or not all( issubclass(t, (Tensor, cls)) for t in types ): @@ -348,7 +351,7 @@ def _from_tensordict_with_copy(tc, tensordict): # creates a new tensorclass with the same type as tc, and a copy of the # non_tensordict data return tc._from_tensordict( - tensordict=tensordict, non_tensordict=copy(tc._non_tensordict) + tensordict=tensordict, non_tensordict=dict(tc._non_tensordict) ) @@ -361,13 +364,13 @@ def _from_tensordict_with_none(tc, tensordict): ) -def _init_wrapper(init: Callable) -> Callable: - init_sig = inspect.signature(init) +def _init_wrapper(__init__: Callable) -> Callable: + init_sig = inspect.signature(__init__) params = list(init_sig.parameters.values()) # drop first entry of params which corresponds to self and isn't passed by the user required_params = [p.name for p in params[1:] if p.default is inspect._empty] - @functools.wraps(init) + @functools.wraps(__init__) def wrapper( self, *args: Any, @@ -377,19 +380,29 @@ def wrapper( **kwargs, ): - for value, key in zip(args, self.__dataclass_fields__): - if key in kwargs: - raise ValueError(f"The key {key} is already set in kwargs") - kwargs[key] = value + if not torch.compiler.is_dynamo_compiling(): + # zip not supported by dynamo + for value, key in zip(args, self.__dataclass_fields__): + if key in kwargs: + raise ValueError(f"The key {key} is already set in kwargs") + kwargs[key] = value + else: + if args: + raise RuntimeError("dynamo doesn't support arguments when building a tensorclass, pass the keyword explicitly.") + if batch_size is None: batch_size = torch.Size([]) - for key, field in self.__dataclass_fields__.items(): - if field.default_factory is not dataclasses.MISSING: - default = field.default_factory() - else: - default = field.default - if default not in (None, dataclasses.MISSING): - kwargs.setdefault(key, default) + if not torch.compiler.is_dynamo_compiling(): + for key, field in self.__dataclass_fields__.items(): + if field.default_factory is not dataclasses.MISSING: + default = field.default_factory() + else: + default = field.default + if default not in (None, dataclasses.MISSING): + kwargs.setdefault(key, default) + else: + # TODO: Decide what to do here + pass missing_params = [p for p in required_params if p not in kwargs] if missing_params: @@ -400,16 +413,17 @@ def wrapper( f"""{", ".join(f"'{name}'" for name in missing_params)}""" ) - self._tensordict = TensorDict( + super(type(self), self).__setattr__('_tensordict', TensorDict( {}, batch_size=torch.Size(batch_size), device=device, names=names, _run_checks=False, - ) - self._non_tensordict = {} + )) + super(type(self), self).__setattr__('_non_tensordict', {}) + self._is_initialized = True - init(self, **kwargs) + __init__(self, **kwargs) new_params = [ inspect.Parameter("batch_size", inspect.Parameter.KEYWORD_ONLY), @@ -553,19 +567,22 @@ def wrapper(cls, tensordict, non_tensordict=None): # noqa: D417 raise KeyError( f"{key} is present in both tensor and non-tensor dicts" ) - # bypass initialisation. this means we don't incur any overhead creating an - # empty tensordict and writing values to it. 
we can skip this because we already - # have a tensordict to use as the underlying tensordict - tc = cls.__new__(cls) - tc.__dict__["_tensordict"] = tensordict - - tc.__dict__["_non_tensordict"] = ( - non_tensordict if non_tensordict is not None else {} - ) - # since we aren't calling the dataclass init method, we need to manually check - # whether a __post_init__ method has been defined and invoke it if so - if hasattr(tc, "__post_init__"): - tc.__post_init__() + if not torch.compiler.is_dynamo_compiling(): + # bypass initialisation. this means we don't incur any overhead creating an + # empty tensordict and writing values to it. we can skip this because we already + # have a tensordict to use as the underlying tensordict + tc = cls.__new__(cls) + tc.__dict__["_tensordict"] = tensordict + tc.__dict__["_non_tensordict"] = ( + non_tensordict if non_tensordict is not None else {} + ) + # since we aren't calling the dataclass init method, we need to manually check + # whether a __post_init__ method has been defined and invoke it if so + if hasattr(tc, "__post_init__"): + tc.__post_init__() + else: + tddict = tensordict.to_dict() + return cls(**tddict, **non_tensordict) return tc return wrapper @@ -675,8 +692,12 @@ def _setstate(self, state: dict[str, Any]) -> None: # noqa: D417 def _getattr(self, item: str) -> Any: # if not item.startswith("__"): - __dict__ = self.__dict__ - _non_tensordict = __dict__.get("_non_tensordict") + if not torch.compiler.is_dynamo_compiling(): + __dict__ = self.__dict__ + _non_tensordict = __dict__.get("_non_tensordict") + else: + _non_tensordict = self._non_tensordict + if _non_tensordict is not None: out = _non_tensordict.get(item, NO_DEFAULT) if out is not NO_DEFAULT: @@ -687,7 +708,10 @@ def _getattr(self, item: str) -> Any: ): return _from_shared_nontensor(out) return out - _tensordict = __dict__.get("_tensordict") + if not torch.compiler.is_dynamo_compiling(): + _tensordict = __dict__.get("_tensordict") + else: + _tensordict = self._tensordict if _tensordict is not None: out = _tensordict._get_str(item, default=None) if out is not None: @@ -700,7 +724,13 @@ def _getattr(self, item: str) -> Any: raise AttributeError -SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names") +SET_ATTRIBUTES = ( + "batch_size", + "device", + "_locked_tensordicts", + "names", + "_is_initialized", +) def _setattr_wrapper(setattr_: Callable, expected_keys: set[str]) -> Callable: @@ -713,13 +743,21 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 value (any): the value to set for the attribute """ - __dict__ = self.__dict__ - if ( - "_tensordict" not in __dict__ - or "_non_tensordict" not in __dict__ - or key in SET_ATTRIBUTES - ): - return setattr_(self, key, value) + if not torch.compiler.is_dynamo_compiling(): + __dict__ = self.__dict__ + if ( + "_tensordict" not in __dict__ + or "_non_tensordict" not in __dict__ + or key in SET_ATTRIBUTES + ): + return setattr_(self, key, value) + else: + # Pass? 
+ if key in SET_ATTRIBUTES: + # assert getattr(self, "_is_initialized", False) + return setattr_(self, key, value) + # if not hasattr(self, "_tensordict") or not hasattr(self, "_non_tensordict") or key in SET_ATTRIBUTES: + # return setattr_(self, key, value) out = self.set(key, value) if out is not self: diff --git a/test/test_compile.py b/test/test_compile.py index af9277b7a..4706899fd 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -3,13 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse +from typing import Any import pytest import torch -from tensordict import assert_close, TensorDict, tensorclass +from tensordict import assert_close, tensorclass, TensorDict from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq -from typing import Any class TestTD: @@ -111,20 +111,21 @@ def clone(td: TensorDict): else: assert clone_c(data)["a", "b"] is data["a", "b"] + @tensorclass class MyClass: a: "MyClass" - b: Any=None - c: Any=None + b: Any = None + c: Any = None -class TestTC: +class TestTC: def test_tensor_output(self): def add_one(td): return td.a.b + 1 add_one_c = torch.compile(add_one, fullgraph=True) - data = MyClass(MyClass(a=None, b=0)) + data = MyClass(MyClass(a=None, b=torch.zeros(()))) assert add_one(data) == 1 assert add_one_c(data) == 1 assert add_one_c(data + 1) == 2 @@ -135,7 +136,7 @@ def add_one(td): return td add_one_c = torch.compile(add_one, fullgraph=True) - data = MyClass.from_tensordict(TensorDict({"a": {"b": 0}})) + data = MyClass(a=MyClass(a=None, b=torch.zeros(()))) assert add_one(data.clone()).a.c == 1 assert add_one_c(data.clone()).a.c == 1 assert add_one_c(data) is data @@ -158,14 +159,16 @@ def add_one(td): return td[0] + 1 add_one_c = torch.compile(add_one, fullgraph=True) - data = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + data = MyClass( + a=MyClass(a=None, b=torch.arange(3), batch_size=[3]), batch_size=[3] + ) if index_type == "int": - assert (add_one(data)["a", "b"] == 1).all() - assert (add_one_c(data)["a", "b"] == 1).all() + assert (add_one(data).a.b == 1).all() + assert (add_one_c(data).a.b == 1).all() assert add_one_c(data).shape == torch.Size([]) else: - assert (add_one(data)["a", "b"] == torch.arange(1, 3)).all() - assert (add_one_c(data)["a", "b"] == torch.arange(1, 3)).all() + assert (add_one(data).a.b == torch.arange(1, 3)).all() + assert (add_one_c(data).a.b == torch.arange(1, 3)).all() assert add_one_c(data).shape == torch.Size([2]) def test_stack(self): @@ -174,8 +177,12 @@ def stack_tds(td0, td1): # return torch.stack([td0, td1]) stack_tds_c = torch.compile(stack_tds, fullgraph=True) - data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) - data1 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + data0 = MyClass( + a=MyClass(a=None, b=torch.arange(3), batch_size=[3]), batch_size=[3] + ) + data1 = MyClass( + a=MyClass(a=None, b=torch.arange(3, 6), batch_size=[3]), batch_size=[3] + ) assert (stack_tds(data0, data1) == stack_tds_c(data0, data1)).all() def test_cat(self): @@ -183,8 +190,12 @@ def cat_tds(td0, td1): return TensorDict.cat([td0, td1]) cat_tds_c = torch.compile(cat_tds, fullgraph=True) - data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) - data1 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + data0 = MyClass( + a=MyClass(a=None, b=torch.arange(3), batch_size=[3]), batch_size=[3] + ) + data1 = MyClass( + a=MyClass(a=None, b=torch.arange(3, 6), batch_size=[3]), batch_size=[3] + ) assert 
(cat_tds(data0, data1) == cat_tds_c(data0, data1)).all() def test_reshape(self): @@ -192,7 +203,9 @@ def reshape(td): return td.reshape(2, 2) reshape_c = torch.compile(reshape, fullgraph=True) - data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) + data = MyClass( + a=MyClass(a=None, b=torch.arange(4), batch_size=[4]), batch_size=[4] + ) assert (reshape(data) == reshape_c(data)).all() def test_unbind(self): @@ -200,7 +213,9 @@ def unbind(td): return td.unbind(0) unbind_c = torch.compile(unbind, fullgraph=True) - data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) + data = MyClass( + a=MyClass(a=None, b=torch.arange(4), batch_size=[4]), batch_size=[4] + ) assert (unbind(data)[-1] == unbind_c(data)[-1]).all() @pytest.mark.parametrize("recurse", [True, False]) @@ -218,7 +233,6 @@ def clone(td: TensorDict): assert clone_c(data)["a", "b"] is data["a", "b"] - class TestNN: def test_func(self): td = TensorDict({"a": 0}) From 8519835dde4f0146a20826e4bc313a2c26cf69f4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 25 May 2024 15:30:03 -0400 Subject: [PATCH 14/28] amend --- benchmarks/compile/compile_td_test.py | 228 ++++++++++++++++++ benchmarks/compile/tensordict_nn_test.py | 134 ++++++++++ .../tensorclass/test_torch_functions.py | 6 + tensordict/_td.py | 38 +-- tensordict/base.py | 5 +- tensordict/nn/common.py | 81 +++++-- tensordict/tensorclass.py | 37 ++- test/test_compile.py | 5 +- 8 files changed, 465 insertions(+), 69 deletions(-) create mode 100644 benchmarks/compile/compile_td_test.py create mode 100644 benchmarks/compile/tensordict_nn_test.py diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py new file mode 100644 index 000000000..e5afb5241 --- /dev/null +++ b/benchmarks/compile/compile_td_test.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import argparse + +import pytest +import torch +from tensordict import LazyStackedTensorDict, TensorDict +from torch.utils._pytree import tree_map + + +# Functions +def add_one(td): + return td + 1 + + +def add_one_pytree(td): + return tree_map(lambda x: x + 1, td) + + +def copy(td): + return td.copy() + + +def copy_pytree(td): + return tree_map(lambda x: x, td) + + +def assign_and_add(td, k): + for i in range(k, k + 100): + td[str(i)] = i + return td + 1 + + +def assign_and_add_pytree(td, k, device): + for i in range(k, k + 100): + td[str(i)] = torch.tensor(i, device=device) + return tree_map(lambda x: x + 1, td) + + +def assign_and_add_stack(td, k): + for i in range(k, k + 100): + td[str(i)] = torch.full((2,), i, device=td.device) + return td + 1 + + +def index(td, idx): + return td[idx] + + +def index_pytree(td, idx): + return tree_map(lambda x: x[idx], td) + + +def get_nested_td(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + d = {} + _d = d + for i in range(10): + _d["a"] = torch.ones((), device=device) + _d[str(i)] = {} + _d = _d[str(i)] + _d["a"] = torch.ones((), device=device) + return TensorDict(d, device=device) + + +def get_flat_td(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + return TensorDict( + {str(i): torch.full((), i, device=device) for i in range(50)}, device=device + ) + + +# Tests runtime of a simple arithmetic op over a highly nested tensordict +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +def test_add_one_nested(mode, dict_type, benchmark): + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(add_one, fullgraph=True) + else: + func = add_one + td = get_nested_td() + else: + if mode == "compile": + func = torch.compile(add_one_pytree, fullgraph=True) + else: + func = add_one_pytree + td = get_nested_td().to_dict() + func(td) + benchmark(func, td) + + +# Tests the speed of copying a nested tensordict +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +def test_copy_nested(mode, dict_type, benchmark): + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(copy, fullgraph=True) + else: + func = copy + td = get_nested_td() + else: + if mode == "compile": + func = torch.compile(copy_pytree, fullgraph=True) + else: + func = copy_pytree + td = get_nested_td().to_dict() + func(td) + benchmark(func, td) + + +# Tests runtime of a simple arithmetic op over a flat tensordict +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +def test_add_one_flat(mode, dict_type, benchmark): + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(add_one, fullgraph=True) + else: + func = add_one + td = get_flat_td() + else: + if mode == "compile": + func = torch.compile(add_one_pytree, fullgraph=True) + else: + func = add_one_pytree + td = get_flat_td().to_dict() + func(td) + benchmark(func, td) + + +# Tests the speed of copying a flat tensordict +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +def test_copy_flat(mode, dict_type, benchmark): + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(copy, fullgraph=True) + else: + func = copy + td = get_flat_td() + else: + if mode == "compile": + func = torch.compile(copy_pytree, fullgraph=True) + else: + func = 
copy_pytree + td = get_flat_td().to_dict() + func(td) + benchmark(func, td) + + +# Tests the speed of assigning entries to an empty tensordict +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +def test_assign_and_add(mode, dict_type, benchmark): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + td = TensorDict(device=device) + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(assign_and_add, fullgraph=True) + else: + func = assign_and_add + kwargs = {} + else: + if mode == "compile": + func = torch.compile(assign_and_add_pytree, fullgraph=True) + else: + func = assign_and_add_pytree + td = td.to_dict() + kwargs = {"device": device} + func(td, 5, **kwargs) + benchmark(func, td, 5, **kwargs) + + +# Tests the speed of assigning entries to a lazy stacked tensordict + + +@pytest.mark.parametrize("mode", ["compile", "eager"]) +def test_assign_and_add_stack(mode, benchmark): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + td = LazyStackedTensorDict(TensorDict(device=device), TensorDict(device=device)) + if mode == "compile": + func = torch.compile(assign_and_add_stack, fullgraph=True) + else: + func = assign_and_add_stack + kwargs = {} + func(td, 5, **kwargs) + benchmark(func, td, 5, **kwargs) + + +# Tests indexing speed +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +@pytest.mark.parametrize("index_type", ["tensor", "slice", "int"]) +def test_compile_indexing(mode, dict_type, index_type, benchmark): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + td = TensorDict( + {"a": torch.arange(100), "b": {"c": torch.arange(100)}}, + batch_size=[100], + device=device, + ) + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(index, fullgraph=True) + else: + func = index + else: + if mode == "compile": + func = torch.compile(index_pytree, fullgraph=True) + else: + func = index_pytree + td = td.to_dict() + if index_type == int: + idx = 5 + else: + idx = slice(None, None, 2) + if index_type == "tensor": + idx = torch.tensor(range(*idx.indices(10))) + func(td, idx) + benchmark(func, td, idx) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/benchmarks/compile/tensordict_nn_test.py b/benchmarks/compile/tensordict_nn_test.py new file mode 100644 index 000000000..f0d597078 --- /dev/null +++ b/benchmarks/compile/tensordict_nn_test.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import argparse + +import pytest +import torch +from tensordict import TensorDict, TensorDictParams + +from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq + + +def mlp(device, depth=2, num_cells=32, feature_dim=3): + return torch.nn.Sequential( + torch.nn.Linear(feature_dim, num_cells, device=device), + torch.nn.ReLU(), + *[ + torch.nn.Sequential( + torch.nn.Linear(num_cells, num_cells, device=device), torch.nn.ReLU() + ) + for _ in range(depth) + ], + torch.nn.Linear(num_cells, 4, device=device), + ) + + +@pytest.mark.parametrize("mode", ["eager", "compile"]) +def test_mod_add(mode, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + td = TensorDict({"a": 0}, device=device) + module = Mod(lambda x: x + 1, in_keys=["a"], out_keys=[("c", "d")]) + if mode == "compile": + module = torch.compile(module, fullgraph=True) + module(td) + benchmark(module, td) + + +@pytest.mark.parametrize("mode", ["eager", "compile"]) +def test_mod_wrap(mode, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + net = mlp(device) + td = TensorDict({"a": torch.zeros(32, 3, device=device)}, device=device) + module = Mod(net, in_keys=["a"], out_keys=[("c", "d")]) + if mode == "compile": + module = torch.compile(module, fullgraph=True) + module(td) + benchmark(module, td) + + +@pytest.mark.parametrize("mode", ["eager", "compile"]) +def test_seq_add(mode, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + td = TensorDict({"a": 0}, device=device) + + def delhidden(td): + del td["hidden"] + return td + + module = Seq( + lambda td: td.copy(), + Mod(lambda x: x + 1, in_keys=["a"], out_keys=["hidden"]), + Mod(lambda x: x + 1, in_keys=["hidden"], out_keys=[("c", "d")]), + delhidden, + ) + if mode == "compile": + module = torch.compile(module, fullgraph=True) + module(td) + benchmark(module, td) + + +@pytest.mark.parametrize("mode", ["eager", "compile"]) +def test_seq_wrap(mode, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + net = mlp(device) + td = TensorDict({"a": torch.zeros(32, 3, device=device)}, device=device) + + def delhidden(td): + del td["hidden"] + return td + + module = Seq( + lambda td: td.copy(), + *[ + Mod( + layer, + in_keys=["a" if i == 0 else "hidden"], + out_keys=["hidden" if i < len(net) - 1 else ("c", "d")], + ) + for i, layer in enumerate(net) + ], + delhidden, + ) + if mode == "compile": + module = torch.compile(module, fullgraph=True) + module(td) + benchmark(module, td) + + +@pytest.mark.parametrize("mode", ["eager", "compile"]) +@pytest.mark.parametrize("functional", [False, True]) +def test_func_call_runtime(mode, functional, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + module = mlp(device=device, depth=10, num_cells=16, feature_dim=16) + # module = torch.nn.Transformer(16, dim_feedforward=64, device=device) + if functional: + td = TensorDict.from_module(module) + td = TensorDictParams(td.clone()) + + def call(x, td): + # with needs registering + params = td.to_module(module, return_swap=True) + result = module(x) + params.to_module(module, return_swap=False) + return result + + else: + call = module + + if mode == "compile": + call = torch.compile(call, fullgraph=True) + + x = torch.randn(2, 2, 16) + if functional: + call(x, td) + benchmark(call, x, td) + else: + call(x) + benchmark(call, x) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + 
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/benchmarks/tensorclass/test_torch_functions.py b/benchmarks/tensorclass/test_torch_functions.py index 0dc3c560c..71d69078a 100644 --- a/benchmarks/tensorclass/test_torch_functions.py +++ b/benchmarks/tensorclass/test_torch_functions.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import argparse import pytest import torch @@ -94,3 +95,8 @@ def test_stack(benchmark, tc): def test_cat(benchmark, tc): benchmark(torch.cat, [tc] * 3, 0) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/tensordict/_td.py b/tensordict/_td.py index 99792b9b0..8f054305b 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -368,6 +368,7 @@ def _to_module( use_state_dict: bool = False, non_blocking: bool = False, ): + is_dynamo = torch.compiler.is_dynamo_compiling() if not use_state_dict and isinstance(module, TensorDictBase): if return_swap: @@ -381,7 +382,7 @@ def _to_module( hooks = memo["hooks"] if return_swap: _swap = {} - if not torch.compiler.is_dynamo_compiling(): + if not is_dynamo: memo[weakref.ref(module)] = _swap if use_state_dict: @@ -423,10 +424,7 @@ def convert_type(x, y): inplace = bool(inplace) # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can - if ( - module.__class__.__setattr__ is __base__setattr__ - and not torch.compiler.is_dynamo_compiling() - ): + if not is_dynamo and module.__class__.__setattr__ is __base__setattr__: __dict__ = module.__dict__ else: __dict__ = None @@ -459,19 +457,10 @@ def convert_type(x, y): child = __dict__["_modules"][key] else: child = getattr(module, key) - if not torch.compiler.is_dynamo_compiling(): + if not is_dynamo: local_out = memo.get(weakref.ref(child), NO_DEFAULT) - if torch.compiler.is_dynamo_compiling() or local_out is NO_DEFAULT: - # if isinstance(child, TensorDictBase): - # # then child is a TensorDictParams - # from tensordict.nn import TensorDictParams - # - # local_out = child - # if not isinstance(value, TensorDictParams): - # value = TensorDictParams(value, no_convert=True) - # __dict__["_modules"][key] = value - # else: + if is_dynamo or local_out is NO_DEFAULT: local_out = value._to_module( child, inplace=inplace, @@ -484,6 +473,7 @@ def convert_type(x, y): if return_swap: _swap[key] = local_out + if return_swap: if isinstance(swap_dest, dict): return _swap @@ -2853,10 +2843,10 @@ def items( is_leaf = _default_is_leaf if is_leaf is None else is_leaf def fast_iter(): - for k, val in self._tensordict.items(): + for key, val in self._tensordict.items(): if not is_leaf(val.__class__): yield from ( - ((k, *((_key,) if isinstance(_key, str) else _key)), _val) + ((key, *((_key,) if isinstance(_key, str) else _key)), _val) for _key, _val in val.items( include_nested=include_nested, leaves_only=leaves_only, @@ -2864,7 +2854,7 @@ def fast_iter(): ) ) else: - yield k, val + yield key, val return fast_iter() else: @@ -4138,3 +4128,13 @@ def _update_metadata(*, metadata, key, value, is_collection): metadata[key] = { "type": type(value).__name__, } + + +# @torch._dynamo.on_compile_start +# def _check_inline_inbuilt_module(): +# try: +# assert str2bool(environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"]) +# except (AssertionError, KeyError): +# raise RuntimeError( +# "set TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 to use functional 
calls with tensordict." +# ) diff --git a/tensordict/base.py b/tensordict/base.py index edf571909..79e3b6462 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -7419,9 +7419,8 @@ def _register_tensor_class(cls): def _is_tensor_collection(datatype): - try: - out = _TENSOR_COLLECTION_MEMO[datatype] - except KeyError: + out = _TENSOR_COLLECTION_MEMO.get(datatype) + if out is None: if issubclass(datatype, TensorDictBase): out = True elif _is_tensorclass(datatype): diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 3d09eacd1..ef274044d 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1248,27 +1248,44 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(\n{fields})" def __getattr__(self, name: str) -> Any: - __dict__ = self.__dict__ - _parameters = __dict__.get("_parameters") - if _parameters: - # A param can be None so we use False instead to check if the key - # is in the _parameters dict once and only once. - # We use False but any non-None, non-Parameter value would do. - # The alternative `if name in _parameters: return _parameters[name]` - # accesses the value of `name` twice when only one is required + if not torch.compiler.is_dynamo_compiling(): + __dict__ = self.__dict__ + _parameters = __dict__.get("_parameters") + if _parameters: + # A param can be None so we use False instead to check if the key + # is in the _parameters dict once and only once. + # We use False but any non-None, non-Parameter value would do. + # The alternative `if name in _parameters: return _parameters[name]` + # accesses the value of `name` twice when only one is required + result = _parameters.get(name, False) + if result is not False: + return result + _buffers = __dict__.get("_buffers") + if _buffers: + result = _buffers.get(name, False) + if result is not False: + return result + _modules = __dict__.get("_modules") + if _modules: + result = _modules.get(name, False) + if result is not False: + return result + # TODO: find a way to make this check work with dynamo + # elif hasattr(self, "_parameters"): + else: + _parameters = self._parameters result = _parameters.get(name, False) if result is not False: return result - _buffers = __dict__.get("_buffers") - if _buffers: + _buffers = self._buffers result = _buffers.get(name, False) if result is not False: return result - _modules = __dict__.get("_modules") - if _modules: + _modules = self._modules result = _modules.get(name, False) if result is not False: return result + if not name.startswith("_"): # no fallback for private attributes return getattr(super().__getattr__("module"), name) @@ -1308,24 +1325,40 @@ def __init__(self, td_module: TensorDictModule) -> None: self.register_forward_hook(self.td_module._forward_hooks[pre_hook]) def __getattr__(self, name: str) -> Any: - __dict__ = self.__dict__ - _parameters = __dict__.get("_parameters") - if _parameters: - # A param can be None so we use False instead to check if the key - # is in the _parameters dict once and only once. - # We use False but any non-None, non-Parameter value would do. - # The alternative `if name in _parameters: return _parameters[name]` - # accesses the value of `name` twice when only one is required + if not torch.compiler.is_dynamo_compiling(): + __dict__ = self.__dict__ + _parameters = __dict__.get("_parameters") + if _parameters: + # A param can be None so we use False instead to check if the key + # is in the _parameters dict once and only once. + # We use False but any non-None, non-Parameter value would do. 
+ # The alternative `if name in _parameters: return _parameters[name]` + # accesses the value of `name` twice when only one is required + result = _parameters.get(name, False) + if result is not False: + return result + _buffers = __dict__.get("_buffers") + if _buffers: + result = _buffers.get(name, False) + if result is not False: + return result + _modules = __dict__.get("_modules") + if _modules: + result = _modules.get(name, False) + if result is not False: + return result + # TODO: find a way to make this check work with dynamo + # elif hasattr(self, "_parameters"): + else: + _parameters = self._parameters result = _parameters.get(name, False) if result is not False: return result - _buffers = __dict__.get("_buffers") - if _buffers: + _buffers = self._buffers result = _buffers.get(name, False) if result is not False: return result - _modules = __dict__.get("_modules") - if _modules: + _modules = self._modules result = _modules.get(name, False) if result is not False: return result diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index aa3e9c09e..6154f83ae 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -388,7 +388,9 @@ def wrapper( kwargs[key] = value else: if args: - raise RuntimeError("dynamo doesn't support arguments when building a tensorclass, pass the keyword explicitly.") + raise RuntimeError( + "dynamo doesn't support arguments when building a tensorclass, pass the keyword explicitly." + ) if batch_size is None: batch_size = torch.Size([]) @@ -413,14 +415,17 @@ def wrapper( f"""{", ".join(f"'{name}'" for name in missing_params)}""" ) - super(type(self), self).__setattr__('_tensordict', TensorDict( - {}, - batch_size=torch.Size(batch_size), - device=device, - names=names, - _run_checks=False, - )) - super(type(self), self).__setattr__('_non_tensordict', {}) + super(type(self), self).__setattr__( + "_tensordict", + TensorDict( + {}, + batch_size=torch.Size(batch_size), + device=device, + names=names, + _run_checks=False, + ), + ) + super(type(self), self).__setattr__("_non_tensordict", {}) self._is_initialized = True __init__(self, **kwargs) @@ -709,18 +714,8 @@ def _getattr(self, item: str) -> Any: return _from_shared_nontensor(out) return out if not torch.compiler.is_dynamo_compiling(): - _non_tensordict = __dict__.get("_non_tensordict") - if _non_tensordict is not None: - out = _non_tensordict.get(item, NO_DEFAULT) - if out is not NO_DEFAULT: - if ( - isinstance(self, NonTensorData) - and item == "data" - and (self._is_shared or self._is_memmap) - ): - return _from_shared_nontensor(out) - return out - _tensordict = __dict__.get("_tensordict")else: + _tensordict = __dict__.get("_tensordict") + else: _tensordict = self._tensordict if _tensordict is not None: out = _tensordict._get_str(item, default=None) diff --git a/test/test_compile.py b/test/test_compile.py index 4706899fd..26c0156f8 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -8,7 +8,8 @@ import pytest import torch -from tensordict import assert_close, tensorclass, TensorDict + +from tensordict import assert_close, tensorclass, TensorDict, TensorDictParams from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq @@ -300,7 +301,7 @@ def forward(self, x): module.append(MessUpParams()) td = TensorDict.from_module(module) - td_zero = td.clone() + td_zero = TensorDictParams(td.data.clone()) td_zero.zero_() def call(x, td): From e1428b9475cbe0292ed1b2753ad4109068e0b395 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 25 May 2024 15:32:50 
-0400 Subject: [PATCH 15/28] amend --- benchmarks/compile/compile_td_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py index e5afb5241..4f23eeb2a 100644 --- a/benchmarks/compile/compile_td_test.py +++ b/benchmarks/compile/compile_td_test.py @@ -75,7 +75,7 @@ def get_flat_td(): # Tests runtime of a simple arithmetic op over a highly nested tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) -def test_add_one_nested(mode, dict_type, benchmark): +def test_compile_add_one_nested(mode, dict_type, benchmark): if dict_type == "tensordict": if mode == "compile": func = torch.compile(add_one, fullgraph=True) @@ -95,7 +95,7 @@ def test_add_one_nested(mode, dict_type, benchmark): # Tests the speed of copying a nested tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) -def test_copy_nested(mode, dict_type, benchmark): +def test_compile_copy_nested(mode, dict_type, benchmark): if dict_type == "tensordict": if mode == "compile": func = torch.compile(copy, fullgraph=True) @@ -115,7 +115,7 @@ def test_copy_nested(mode, dict_type, benchmark): # Tests runtime of a simple arithmetic op over a flat tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) -def test_add_one_flat(mode, dict_type, benchmark): +def test_compile_add_one_flat(mode, dict_type, benchmark): if dict_type == "tensordict": if mode == "compile": func = torch.compile(add_one, fullgraph=True) @@ -135,7 +135,7 @@ def test_add_one_flat(mode, dict_type, benchmark): # Tests the speed of copying a flat tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) -def test_copy_flat(mode, dict_type, benchmark): +def test_compile_copy_flat(mode, dict_type, benchmark): if dict_type == "tensordict": if mode == "compile": func = torch.compile(copy, fullgraph=True) @@ -155,7 +155,7 @@ def test_copy_flat(mode, dict_type, benchmark): # Tests the speed of assigning entries to an empty tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) -def test_assign_and_add(mode, dict_type, benchmark): +def test_compile_assign_and_add(mode, dict_type, benchmark): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") td = TensorDict(device=device) if dict_type == "tensordict": @@ -179,7 +179,7 @@ def test_assign_and_add(mode, dict_type, benchmark): @pytest.mark.parametrize("mode", ["compile", "eager"]) -def test_assign_and_add_stack(mode, benchmark): +def test_compile_assign_and_add_stack(mode, benchmark): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") td = LazyStackedTensorDict(TensorDict(device=device), TensorDict(device=device)) if mode == "compile": From 81eaac4894314041433908cbfea54edf7feff293 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 25 May 2024 18:37:00 -0400 Subject: [PATCH 16/28] amend --- tensordict/_td.py | 51 +++++++++++++++++---- tensordict/base.py | 33 ++++++-------- tensordict/tensorclass.py | 77 ++++++++++++++++++------------- tensordict/utils.py | 37 ++++++++------- test/test_compile.py | 96 +++++++++++++++++++++++++++++---------- 5 files changed, 196 insertions(+), 98 deletions(-) diff --git a/tensordict/_td.py 
b/tensordict/_td.py index 8f054305b..13f09145b 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2841,20 +2841,55 @@ def items( return self._tensordict.items() elif include_nested and leaves_only: is_leaf = _default_is_leaf if is_leaf is None else is_leaf + result = [] + if torch.compiler.is_dynamo_compiling(): - def fast_iter(): - for key, val in self._tensordict.items(): - if not is_leaf(val.__class__): - yield from ( - ((key, *((_key,) if isinstance(_key, str) else _key)), _val) + def fast_iter(): + for key, val in self._tensordict.items(): + if not is_leaf(val.__class__): for _key, _val in val.items( include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + ): + result.append( + ( + ( + key, + *( + (_key,) + if isinstance(_key, str) + else _key + ), + ), + _val, + ) + ) + else: + result.append((key, val)) + return result + + else: + # dynamo doesn't like generators + def fast_iter(): + for key, val in self._tensordict.items(): + if not is_leaf(val.__class__): + yield from ( + ( + ( + key, + *((_key,) if isinstance(_key, str) else _key), + ), + _val, + ) + for _key, _val in val.items( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + ) ) - ) - else: - yield key, val + else: + yield (key, val) return fast_iter() else: diff --git a/tensordict/base.py b/tensordict/base.py index 79e3b6462..94332e2a5 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -9,6 +9,7 @@ import collections import concurrent.futures import contextlib +import enum import importlib import json import numbers @@ -80,17 +81,11 @@ # NO_DEFAULT is used as a placeholder whenever the default is not provided. # Using None is not an option since `td.get(key, default=None)` is a valid usage. -class _NoDefault: - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(_NoDefault, cls).__new__(cls) - return cls.instance - - def __bool__(self): - return False +class _NoDefault(enum.IntEnum): + ZERO = 0 -NO_DEFAULT = "-TO_REPLACE-" # _NoDefault() +NO_DEFAULT = _NoDefault.ZERO class _NestedTensorsAsLists: @@ -3799,16 +3794,13 @@ def _items_list( include_nested: bool = False, leaves_only: bool = False, ) -> Tuple[List, List]: - return tuple( - list(key_or_val) - for key_or_val in zip( - *self.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=_NESTED_TENSORS_AS_LISTS, - ) - ) + items = self.items( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=_NESTED_TENSORS_AS_LISTS, ) + keys, vals = zip(*items) + return list(keys), list(vals) @cache # noqa: B019 def _grad(self): @@ -3876,7 +3868,7 @@ def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleType: self.del_(key) except KeyError as err: # if default provided, 'out' value will return, else raise error - if default == NO_DEFAULT: + if default is NO_DEFAULT: raise KeyError( f"You are trying to pop key `{key}` which is not in dict " f"without providing default value." 
@@ -5694,7 +5686,8 @@ def cosh_(self) -> T: return self def add(self, other: TensorDictBase | float, alpha: float | None = None): - keys, vals = self._items_list(True, True) + keys_vals = self._items_list(True, True) + keys, vals = keys_vals if _is_tensor_collection(type(other)): other_val = other._values_list(True, True) else: diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 6154f83ae..5efddbfc4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -44,6 +44,7 @@ from tensordict.utils import ( _get_repr, _is_json_serializable, + _is_tensorclass, _LOCK_ERROR, DeviceType, IndexType, @@ -71,22 +72,22 @@ def __subclasscheck__(self, subclass): _CLEAR_METADATA = {"all", "any"} # torch functions where we can wrap the corresponding TensorDict version _TD_PASS_THROUGH = { - torch.unbind, - torch.full_like, - torch.zeros_like, - torch.ones_like, - torch.rand_like, - torch.empty_like, - torch.randn_like, - torch.clone, - torch.squeeze, - torch.unsqueeze, - torch.split, - torch.permute, - torch.split, - torch.stack, - torch.cat, - torch.gather, + torch.unbind: True, + torch.full_like: True, + torch.zeros_like: True, + torch.ones_like: True, + torch.rand_like: True, + torch.empty_like: True, + torch.randn_like: True, + torch.clone: True, + torch.squeeze: True, + torch.unsqueeze: True, + torch.split: True, + torch.permute: True, + torch.split: True, + torch.stack: True, + torch.cat: True, + torch.gather: True, } @@ -175,9 +176,9 @@ def __torch_function__( # if not torch.compiler.is_dynamo_compiling() and (func not in _TD_PASS_THROUGH or not all( # issubclass(t, (Tensor, cls)) for t in types # )): - if func not in _TD_PASS_THROUGH or not all( + if not all( issubclass(t, (Tensor, cls)) for t in types - ): + ) or not _TD_PASS_THROUGH.get(func): return NotImplemented if kwargs is None: @@ -204,9 +205,9 @@ def __torch_function__( _is_non_tensor = getattr(cls, "_is_non_tensor", False) cls = dataclass(cls) - expected_keys = set(cls.__dataclass_fields__) + expected_keys = cls.__expected_keys__ = set(cls.__dataclass_fields__) - for attr in cls.__dataclass_fields__: + for attr in expected_keys: if attr in dir(TensorDict) and attr not in ("_is_non_tensor", "data"): raise AttributeError( f"Attribute name {attr} can't be used with @tensorclass" @@ -340,9 +341,20 @@ def __torch_function__( def _arg_to_tensordict(arg): # if arg is a tensorclass or sequence of tensorclasses, extract the underlying # tensordicts and return those instead - if is_tensorclass(arg): + + # since arg can be anything (e.g. 
callable etc) we can't ue pytree + # def convert(x): + # if _is_tensorclass(type(x)): + # return x._tensordict + # return x + # return torch.utils._pytree.tree_map(convert, arg) + if callable(arg): + return arg + if _is_tensorclass(type(arg)): return arg._tensordict - elif isinstance(arg, (tuple, list)) and all(is_tensorclass(item) for item in arg): + elif isinstance(arg, (tuple, list)) and all( + _is_tensorclass(type(item)) for item in arg + ): return arg.__class__(item._tensordict for item in arg) return arg @@ -350,7 +362,7 @@ def _arg_to_tensordict(arg): def _from_tensordict_with_copy(tc, tensordict): # creates a new tensorclass with the same type as tc, and a copy of the # non_tensordict data - return tc._from_tensordict( + return type(tc)._from_tensordict( tensordict=tensordict, non_tensordict=dict(tc._non_tensordict) ) @@ -358,7 +370,7 @@ def _from_tensordict_with_copy(tc, tensordict): def _from_tensordict_with_none(tc, tensordict): # creates a new tensorclass with the same type as tc, and all non_tensordict entries # set to None - return tc._from_tensordict( + return type(tc)._from_tensordict( tensordict=tensordict, non_tensordict={key: None for key in tc._non_tensordict}, ) @@ -426,7 +438,7 @@ def wrapper( ), ) super(type(self), self).__setattr__("_non_tensordict", {}) - self._is_initialized = True + super(type(self), self).__setattr__("_is_initialized", True) __init__(self, **kwargs) @@ -775,7 +787,6 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 def _wrap_method(self, attr, func): - @functools.wraps(func) def wrapped_func(*args, **kwargs): args = tuple(_arg_to_tensordict(arg) for arg in args) kwargs = {key: _arg_to_tensordict(value) for key, value in kwargs.items()} @@ -787,13 +798,16 @@ def wrapped_func(*args, **kwargs): elif attr in _CLEAR_METADATA: # this is an attribute where copying the metadata makes no sense, e.g. # .all or .any, so we replace all values with None - return self._from_tensordict( + return type(self)._from_tensordict( res, {k: None for k in self._non_tensordict} ) # create a new tensorclass from res and copy the metadata from self - return self._from_tensordict(res, copy(self._non_tensordict)) + return type(self)._from_tensordict(res, dict(self._non_tensordict)) return res + if not torch.compiler.is_dynamo_compiling(): + wrapped_func = functools.wraps(func)(wrapped_func) + return wrapped_func @@ -923,7 +937,8 @@ def _getitem(self, item: NestedKey) -> Any: isinstance(item, tuple) and all(isinstance(_item, str) for _item in item) ): raise ValueError(f"Invalid indexing arguments: {item}.") - tensor_res = self._tensordict[item] + # tensor_res = super(type(self), self).__getattribute__("_tensordict")[item] + tensor_res = self.__dict__["_tensordict"][item] return _from_tensordict_with_copy(self, tensor_res) # device=res.device) @@ -1147,7 +1162,7 @@ def _set( if key in ("batch_size", "names", "device"): # handled by setattr return - expected_keys = cls.__dataclass_fields__ + expected_keys = cls.__expected_keys__ if key not in expected_keys: raise AttributeError( f"Cannot set the attribute '{key}', expected attributes are {expected_keys}." 
@@ -1642,7 +1657,7 @@ def _unbind(self, dim: int): """ return tuple( - self._from_tensordict(td, non_tensordict=copy(self._non_tensordict)) + type(self)._from_tensordict(td, non_tensordict=copy(self._non_tensordict)) for td in self._tensordict.unbind(dim) ) diff --git a/tensordict/utils.py b/tensordict/utils.py index 90ebb1683..39d11359a 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2280,28 +2280,33 @@ def _unravel_key_to_tuple(key): def _slice_indices(index: slice, len: int): """A pure python implementation of slice.indices(len) since torch.compile doesn't recognise it.""" - start = index.start - if start is None: - start = 0 - stop = index.stop - if stop is None: - stop = len step = index.step if step is None: step = 1 elif step == 0: raise ValueError("Step cannot be zero.") - if step > 0: - # Forward direction - lower = start - upper = stop - step + 1 - else: - # Backward direction - lower = start + step - 1 - upper = stop - upper = min(upper, len) - return lower, upper, step + start = index.start + stop = index.stop + if start is None: + if step > 0: + start = 0 + else: + start = len - 1 + elif start < 0: + start = len + start + + if stop is None: + if step > 0: + stop = len + else: + stop = -1 + elif stop > 0: + stop = min(len, stop) + elif step < 0 or (step > 0 and start >= 0): + stop = len + stop + + return start, stop, step assert_allclose_td = assert_close diff --git a/test/test_compile.py b/test/test_compile.py index 26c0156f8..c3e49d9b3 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -98,6 +98,22 @@ def unbind(td): data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) assert (unbind(data)[-1] == unbind_c(data)[-1]).all() + def test_items(self): + def items(td): + keys, vals = zip(*td.items(True, True)) + return keys, vals + + items_c = torch.compile(items, fullgraph=True) + data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) + keys, vals = items(data) + keys_c, vals_c = items_c(data) + + def assert_eq(x, y): + assert (x == y).all() + + assert keys == keys_c + torch.utils._pytree.tree_map(assert_eq, vals, vals_c) + @pytest.mark.parametrize("recurse", [True, False]) def test_clone(self, recurse): def clone(td: TensorDict): @@ -121,7 +137,7 @@ class MyClass: class TestTC: - def test_tensor_output(self): + def test_tc_tensor_output(self): def add_one(td): return td.a.b + 1 @@ -131,7 +147,23 @@ def add_one(td): assert add_one_c(data) == 1 assert add_one_c(data + 1) == 2 - def test_td_output(self): + def test_tc_items(self): + def items(td): + keys, vals = zip(*td.items(True, True)) + return keys, vals + + items_c = torch.compile(items, fullgraph=True) + data = MyClass(MyClass(a=None, b=torch.zeros(()))) + keys, vals = items(data) + keys_c, vals_c = items_c(data) + + def assert_eq(x, y): + assert (x == y).all() + + assert keys == keys_c + torch.utils._pytree.tree_map(assert_eq, vals, vals_c) + + def test_tc_output(self): def add_one(td): td.a.c = td.a.b + 1 return td @@ -142,37 +174,53 @@ def add_one(td): assert add_one_c(data.clone()).a.c == 1 assert add_one_c(data) is data + def test_tc_arithmetic(self): + def add_one(td): + return td + 1 + + add_one_c = torch.compile(add_one, fullgraph=True) + data = MyClass(a=MyClass(a=None, b=torch.zeros(()))) + assert add_one(data.clone()).a.b == 1 + assert add_one_c(data.clone()).a.b == 1 + assert add_one_c(data) is data + @pytest.mark.parametrize("index_type", ["slice", "tensor", "int"]) - def test_td_index(self, index_type): + def test_tc_index(self, index_type): if index_type == "slice": - def add_one(td): 
- return td[:2] + 1 + def index(td): + return td[:2] elif index_type == "tensor": - def add_one(td): - return td[torch.tensor([0, 1])] + 1 + def index(td): + return td[torch.tensor([0, 1])] elif index_type == "int": - def add_one(td): - return td[0] + 1 + def index(td): + return td[0] - add_one_c = torch.compile(add_one, fullgraph=True) + index_c = torch.compile(index, fullgraph=True) data = MyClass( a=MyClass(a=None, b=torch.arange(3), batch_size=[3]), batch_size=[3] ) if index_type == "int": - assert (add_one(data).a.b == 1).all() - assert (add_one_c(data).a.b == 1).all() - assert add_one_c(data).shape == torch.Size([]) + assert (index(data).a.b == 0).all() + assert isinstance(index_c(data), MyClass) + assert isinstance(index_c(data).a, MyClass) + assert (index_c(data).a.b == 0).all() + assert index_c(data).shape == torch.Size([]) else: - assert (add_one(data).a.b == torch.arange(1, 3)).all() - assert (add_one_c(data).a.b == torch.arange(1, 3)).all() - assert add_one_c(data).shape == torch.Size([2]) - - def test_stack(self): + assert (index(data).a.b == torch.arange(0, 2)).all() + assert isinstance(index(data), MyClass) + assert isinstance(index(data).a, MyClass) + assert isinstance(index_c(data), MyClass) + assert isinstance(index_c(data).a, MyClass) + assert (index_c(data).a.b == torch.arange(0, 2)).all() + assert index_c(data).shape == torch.Size([2]) + + def test_tc_stack(self): def stack_tds(td0, td1): return TensorDict.stack([td0, td1]) # return torch.stack([td0, td1]) @@ -186,7 +234,7 @@ def stack_tds(td0, td1): ) assert (stack_tds(data0, data1) == stack_tds_c(data0, data1)).all() - def test_cat(self): + def test_tc_cat(self): def cat_tds(td0, td1): return TensorDict.cat([td0, td1]) @@ -199,7 +247,7 @@ def cat_tds(td0, td1): ) assert (cat_tds(data0, data1) == cat_tds_c(data0, data1)).all() - def test_reshape(self): + def test_tc_reshape(self): def reshape(td): return td.reshape(2, 2) @@ -209,7 +257,7 @@ def reshape(td): ) assert (reshape(data) == reshape_c(data)).all() - def test_unbind(self): + def test_tc_unbind(self): def unbind(td): return td.unbind(0) @@ -220,12 +268,14 @@ def unbind(td): assert (unbind(data)[-1] == unbind_c(data)[-1]).all() @pytest.mark.parametrize("recurse", [True, False]) - def test_clone(self, recurse): + def test_tc_clone(self, recurse): def clone(td: TensorDict): return td.clone(recurse=recurse) clone_c = torch.compile(clone, fullgraph=True) - data = TensorDict({"a": {"b": 0, "c": 1}}) + data = MyClass( + a=MyClass(a=None, b=torch.arange(4), batch_size=[4]), batch_size=[4] + ) assert_close(clone_c(data), clone(data)) assert clone_c(data) is not data if recurse: From 2d8bef5ae1987971315401beea6c884d71c064b8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 31 May 2024 18:02:03 +0100 Subject: [PATCH 17/28] amend --- tensordict/_td.py | 3 +- tensordict/_torch_func.py | 31 +++++++++++++----- tensordict/base.py | 2 +- tensordict/tensorclass.py | 68 +++++++++++++++++++++++++-------------- tensordict/utils.py | 34 +++++++++++++++++--- test/test_compile.py | 68 ++++++++++++++++++++++++++------------- 6 files changed, 145 insertions(+), 61 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 788fa73fa..f022232c0 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -263,7 +263,8 @@ def __init__( f"sub-type or a dictionary, found type(source)={type(source)}." 
) self._batch_size = self._parse_batch_size(source, batch_size) - self.names = names + # TODO: this breaks when stacking tensorclasses with dynamo + # self.names = names for key, value in source.items(): self.set(key, value, non_blocking=sub_non_blocking) diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index df3e1543a..02132d4d2 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -404,11 +404,16 @@ def _stack( ) -> T: if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") + is_tc = False + if any(is_tensorclass(td) for td in list_of_tensordicts): + is_tc = True + if all(is_non_tensor(td) for td in list_of_tensordicts): + from tensordict.tensorclass import NonTensorData - if all(is_non_tensor(td) for td in list_of_tensordicts): - from tensordict.tensorclass import NonTensorData - - return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) + return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) + else: + tc_type = type(list_of_tensordicts[0]) + list_of_tensordicts = [tc.to_tensordict() for tc in list_of_tensordicts] batch_size = list_of_tensordicts[0].batch_size if dim < 0: @@ -507,9 +512,12 @@ def _stack( tensor = _tensordict._get_str(key, default=NO_DEFAULT) if is_tensor is None: tensor_cls = type(tensor) - is_tensor = ( - not _is_tensor_collection(tensor_cls) - ) or is_tensorclass(tensor_cls) + # is_tensor = ( + # not _is_tensor_collection(tensor_cls) + # ) or is_tensorclass(tensor_cls) + # TODO: make sense of this, dynamo cannot pass through stack (and it's unsafe) + # only tensors should be tensors + is_tensor = not _is_tensor_collection(tensor_cls) if is_not_init is None: is_not_init = isinstance(tensor, UninitializedTensorMixin) if not is_not_init: @@ -548,7 +556,7 @@ def stack_fn(key, values, is_not_init, is_tensor): for key, (values, is_not_init, is_tensor) in out.items() } - return TensorDict( + result = TensorDict( out, batch_size=LazyStackedTensorDict._compute_batch_size( batch_size, dim, len(list_of_tensordicts) @@ -556,6 +564,13 @@ def stack_fn(key, values, is_not_init, is_tensor): device=device, _run_checks=False, ) + if is_tc: + return tc_type( + **dict(result.items()), + batch_size=result.batch_size, + device=result.device, + ) + return result else: out = LazyStackedTensorDict( *list_of_tensordicts, diff --git a/tensordict/base.py b/tensordict/base.py index eb6d9ac7a..656feee28 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3008,7 +3008,7 @@ def _set_non_tensor(self, key: NestedKey, value: Any): self._set_str( key, NonTensorData( - value, + data=value, batch_size=self.batch_size, device=self.device, names=self.names if self._has_names() else None, diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index a368edc87..86d8923c3 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -330,12 +330,9 @@ def __torch_function__( args: tuple[Any, ...] 
= (), kwargs: dict[str, Any] | None = None, ) -> Callable: - # if not torch.compiler.is_dynamo_compiling() and (func not in _TD_PASS_THROUGH or not all( - # issubclass(t, (Tensor, cls, TensorDictBase)) for t in types - # )): - if not all( - issubclass(t, (Tensor, cls)) for t in types - ) or not _TD_PASS_THROUGH.get(func): + if func not in _TD_PASS_THROUGH or not all( + issubclass(t, (Tensor, cls, TensorDictBase)) for t in types + ): return NotImplemented if kwargs is None: @@ -390,7 +387,8 @@ def __torch_function__( cls.__getitem__ = _getitem cls.__getitems__ = _getitem cls.__setitem__ = _setitem - cls.__repr__ = _repr + if not hasattr(cls, "__repr__"): + cls.__repr__ = _repr cls.__len__ = _len cls.__eq__ = _eq cls.__ne__ = _ne @@ -510,7 +508,7 @@ def _arg_to_tensordict(arg): # if arg is a tensorclass or sequence of tensorclasses, extract the underlying # tensordicts and return those instead - # since arg can be anything (e.g. callable etc) we can't ue pytree + # since arg can be anything (e.g. callable etc) we can't use pytree # def convert(x): # if _is_tensorclass(type(x)): # return x._tensordict @@ -608,6 +606,11 @@ def wrapper( super(type(self), self).__setattr__("_non_tensordict", {}) super(type(self), self).__setattr__("_is_initialized", True) + # convert the non tensor data in a regular data + kwargs = { + key: value.data if is_non_tensor(value) else value + for key, value in kwargs.items() + } __init__(self, **kwargs) new_params = [ @@ -765,10 +768,16 @@ def wrapper(cls, tensordict, non_tensordict=None): # noqa: D417 # whether a __post_init__ method has been defined and invoke it if so if hasattr(tc, "__post_init__"): tc.__post_init__() + return tc else: - tddict = tensordict.to_dict() - return cls(**tddict, **non_tensordict) - return tc + # TODO: things that did NOT work: **tensordict, dict(tensordict) + return cls( + **dict(tensordict.items()), + **non_tensordict, + batch_size=tensordict.batch_size, + device=tensordict.device, + names=tensordict.names, + ) return wrapper @@ -956,7 +965,11 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 def _wrap_td_method(funcname, *, copy_non_tensor=False): def wrapped_func(self, *args, **kwargs): - td = super(type(self), self).__getattribute__("_tensordict") + if not torch.compiler.is_dynamo_compiling(): + td = super(type(self), self).__getattribute__("_tensordict") + else: + td = self._tensordict + result = getattr(td, funcname)(*args, **kwargs) def check_out(kwargs, result): @@ -969,13 +982,19 @@ def check_out(kwargs, result): if isinstance(result, TensorDictBase) and not check_out(kwargs, result): if result is td: return self - nontd = super(type(self), self).__getattribute__("_non_tensordict") + if not torch.compiler.is_dynamo_compiling(): + nontd = super(type(self), self).__getattribute__("_non_tensordict") + else: + nontd = self._non_tensordict if copy_non_tensor: # use tree_map to copy nontd = tree_map(lambda x: x, nontd) - return super(type(self), self).__getattribute__("_from_tensordict")( - result, nontd - ) + if not torch.compiler.is_dynamo_compiling(): + return super(type(self), self).__getattribute__("_from_tensordict")( + result, nontd + ) + else: + return self._from_tensordict(result, nontd) return result return wrapped_func @@ -1861,8 +1880,9 @@ def _unbind(self, dim: int): Resulting tensorclass instances will share the storage of the initial tensorclass instance. 
""" + # TODO: dynamo doesn't like copy, using dict instead return tuple( - type(self)._from_tensordict(td, non_tensordict=copy(self._non_tensordict)) + type(self)._from_tensordict(td, non_tensordict=dict(self._non_tensordict)) for td in self._tensordict.unbind(dim) ) @@ -2120,6 +2140,12 @@ class NonTensorData: _is_non_tensor: bool = True + def __repr__(self): + data_str = str(self.data) + if len(data_str) > 200: + data_str = data_str[:20] + " ... " + data_str[-20:] + return f"{type(self).__name__}(data={data_str}, batch_size={self.batch_size}, device={self.device})" + def __post_init__(self): if is_non_tensor(self.data): data = getattr(self.data, "data", None) @@ -2129,14 +2155,6 @@ def __post_init__(self): self._non_tensordict["data"] = data assert self._tensordict.is_empty(), self._tensordict - def __repr__(self): - data_str = str(self.data) - if len(data_str) > 200: - data_str = data_str[:20] + " ... " + data_str[-20:] - return f"{type(self).__name__}(data={data_str}, batch_size={self.batch_size}, device={self.device})" - - self.__class__.__repr__ = __repr__ - old_eq = self.__class__.__eq__ if old_eq is _eq: global NONTENSOR_HANDLED_FUNCTIONS diff --git a/tensordict/utils.py b/tensordict/utils.py index 111045309..f46d2137f 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -49,10 +49,10 @@ from packaging.version import parse from tensordict._contextlib import _DecoratorContextManager from tensordict._tensordict import ( # noqa: F401 - # _unravel_key_to_tuple, - unravel_key, - unravel_key_list, - unravel_keys, + _unravel_key_to_tuple as _unravel_key_to_tuple_cpp, + unravel_key as unravel_key_cpp, + unravel_key_list as unravel_key_list_cpp, + unravel_keys as unravel_keys_cpp, ) from torch import Tensor @@ -2270,6 +2270,8 @@ def __missing__(self, key): def _unravel_key_to_tuple(key): + if not torch.compiler.is_dynamo_compiling(): + return _unravel_key_to_tuple_cpp(key) if isinstance(key, str): return (key,) if not isinstance(key, tuple): @@ -2277,6 +2279,30 @@ def _unravel_key_to_tuple(key): return tuple(subk for k in key for subk in _unravel_key_to_tuple(k)) +def unravel_key(key): + if not torch.compiler.is_dynamo_compiling(): + return unravel_key_cpp(key) + if isinstance(key, str): + return key + if isinstance(key, tuple): + if len(key) == 1: + return unravel_key(key[0]) + return tuple(unravel_key(_key) for _key in key) + raise ValueError("the key must be a str or a tuple of str") + + +def unravel_keys(*keys): + if not torch.compiler.is_dynamo_compiling(): + return unravel_keys_cpp(*keys) + return tuple(unravel_key(key) for key in keys) + + +def unravel_key_list(keys): + if not torch.compiler.is_dynamo_compiling(): + return unravel_key_list_cpp(keys) + return [unravel_key(key) for key in keys] + + def _slice_indices(index: slice, len: int): """A pure python implementation of slice.indices(len) since torch.compile doesn't recognise it.""" step = index.step diff --git a/test/test_compile.py b/test/test_compile.py index c3e49d9b3..cc4544b85 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -65,8 +65,8 @@ def add_one(td): def test_stack(self): def stack_tds(td0, td1): - return TensorDict.stack([td0, td1]) - # return torch.stack([td0, td1]) + # return TensorDict.stack([td0, td1]) + return torch.stack([td0, td1]) stack_tds_c = torch.compile(stack_tds, fullgraph=True) data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) @@ -178,11 +178,20 @@ def test_tc_arithmetic(self): def add_one(td): return td + 1 - add_one_c = torch.compile(add_one, fullgraph=True) data = 
MyClass(a=MyClass(a=None, b=torch.zeros(()))) - assert add_one(data.clone()).a.b == 1 - assert add_one_c(data.clone()).a.b == 1 - assert add_one_c(data) is data + + eager = add_one(data.clone()) + + add_one_c = torch.compile(add_one, fullgraph=True) + compiled = add_one_c(data.clone()) + + assert isinstance(eager.a, MyClass) + assert eager.a.b == 1 + + assert isinstance(compiled.a, MyClass) + # TODO: breaks because a is not cast to a MyClass but is a dict + assert compiled.a.b == 1 + assert add_one_c(data) is not data @pytest.mark.parametrize("index_type", ["slice", "tensor", "int"]) def test_tc_index(self, index_type): @@ -205,34 +214,49 @@ def index(td): data = MyClass( a=MyClass(a=None, b=torch.arange(3), batch_size=[3]), batch_size=[3] ) + + indexed_data_eager = index(data) + indexed_data_compile = index_c(data) if index_type == "int": - assert (index(data).a.b == 0).all() - assert isinstance(index_c(data), MyClass) - assert isinstance(index_c(data).a, MyClass) - assert (index_c(data).a.b == 0).all() - assert index_c(data).shape == torch.Size([]) + assert (indexed_data_eager.a.b == 0).all() + assert (indexed_data_compile.a.b == 0).all() + + assert isinstance(indexed_data_eager, MyClass) + assert isinstance(indexed_data_compile, MyClass) + + assert isinstance(indexed_data_eager.a, MyClass) + assert isinstance(indexed_data_compile.a, MyClass) + + assert indexed_data_eager.shape == torch.Size([]) + assert indexed_data_compile.shape == torch.Size([]) + else: - assert (index(data).a.b == torch.arange(0, 2)).all() - assert isinstance(index(data), MyClass) - assert isinstance(index(data).a, MyClass) - assert isinstance(index_c(data), MyClass) - assert isinstance(index_c(data).a, MyClass) - assert (index_c(data).a.b == torch.arange(0, 2)).all() - assert index_c(data).shape == torch.Size([2]) + assert (indexed_data_eager.a.b == torch.arange(0, 2)).all() + assert (indexed_data_compile.a.b == torch.arange(0, 2)).all() + assert isinstance(indexed_data_eager, MyClass) + assert isinstance(indexed_data_compile, MyClass) + assert isinstance(indexed_data_eager.a, MyClass) + assert isinstance(indexed_data_compile.a, MyClass) + assert indexed_data_eager.shape == torch.Size([2]) + assert indexed_data_compile.shape == torch.Size([2]) def test_tc_stack(self): def stack_tds(td0, td1): return TensorDict.stack([td0, td1]) # return torch.stack([td0, td1]) - stack_tds_c = torch.compile(stack_tds, fullgraph=True) data0 = MyClass( a=MyClass(a=None, b=torch.arange(3), batch_size=[3]), batch_size=[3] ) data1 = MyClass( a=MyClass(a=None, b=torch.arange(3, 6), batch_size=[3]), batch_size=[3] ) - assert (stack_tds(data0, data1) == stack_tds_c(data0, data1)).all() + stack_eager = stack_tds(data0, data1) + + stack_tds_c = torch.compile(stack_tds, fullgraph=True) + stack_compile = stack_tds_c(data0, data1) + + assert (stack_eager == stack_compile).all() def test_tc_cat(self): def cat_tds(td0, td1): @@ -279,9 +303,9 @@ def clone(td: TensorDict): assert_close(clone_c(data), clone(data)) assert clone_c(data) is not data if recurse: - assert clone_c(data)["a", "b"] is not data["a", "b"] + assert clone_c(data).a.b is not data.a.b else: - assert clone_c(data)["a", "b"] is data["a", "b"] + assert clone_c(data).a.b is data.a.b class TestNN: From 19f95458c011412b1496abd0a0a051736637fb22 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 31 May 2024 18:02:59 +0100 Subject: [PATCH 18/28] amend --- tensordict/_td.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 
f022232c0..e29aa4725 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -264,7 +264,8 @@ def __init__( ) self._batch_size = self._parse_batch_size(source, batch_size) # TODO: this breaks when stacking tensorclasses with dynamo - # self.names = names + if not torch.compiler.is_dynamo_compiling(): + self.names = names for key, value in source.items(): self.set(key, value, non_blocking=sub_non_blocking) From c17c7b1c948c9c85287ca15de0733545c227da9a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 31 May 2024 18:15:03 +0100 Subject: [PATCH 19/28] amend --- tensordict/tensorclass.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 86d8923c3..4e833bb91 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -387,7 +387,7 @@ def __torch_function__( cls.__getitem__ = _getitem cls.__getitems__ = _getitem cls.__setitem__ = _setitem - if not hasattr(cls, "__repr__"): + if not _is_non_tensor: cls.__repr__ = _repr cls.__len__ = _len cls.__eq__ = _eq @@ -2155,6 +2155,8 @@ def __post_init__(self): self._non_tensordict["data"] = data assert self._tensordict.is_empty(), self._tensordict + # TODO: this will probably fail with dynamo at some point, + it's terrible. + # Make sure it's patched properly at init time old_eq = self.__class__.__eq__ if old_eq is _eq: global NONTENSOR_HANDLED_FUNCTIONS From fcc9131784254bdf46bf293656fe07883fb538e9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 4 Jun 2024 12:24:53 +0100 Subject: [PATCH 20/28] amend --- benchmarks/compile/compile_td_test.py | 74 +++++++++- tensordict/tensorclass.py | 189 ++++++++++++++------------ tensordict/utils.py | 13 ++ test/test_compile.py | 19 +++ 4 files changed, 203 insertions(+), 92 deletions(-) diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py index 4f23eeb2a..f3c9eff9d 100644 --- a/benchmarks/compile/compile_td_test.py +++ b/benchmarks/compile/compile_td_test.py @@ -6,10 +6,20 @@ import pytest import torch -from tensordict import LazyStackedTensorDict, TensorDict +from tensordict import LazyStackedTensorDict, tensorclass, TensorDict from torch.utils._pytree import tree_map +@tensorclass +class MyTensorClass: + a: torch.Tensor + b: torch.Tensor + c: torch.Tensor + d: torch.Tensor + e: torch.Tensor + f: torch.Tensor + + # Functions def add_one(td): return td + 1 @@ -19,6 +29,14 @@ def add_one_pytree(td): return tree_map(lambda x: x + 1, td) +def add_self(td): + return td + td + + +def add_self_pytree(td): + return tree_map(lambda x: x + x, td) + + def copy(td): return td.copy() @@ -72,6 +90,19 @@ def get_flat_td(): ) +def get_flat_tc(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + return MyTensorClass( + a=torch.ones((15,), device=device), + b=torch.ones((15,), device=device), + c=torch.ones((15,), device=device), + d=torch.ones((15,), device=device), + e=torch.ones((15,), device=device), + f=torch.ones((15,), device=device), + device=device, + ) + + # Tests runtime of a simple arithmetic op over a highly nested tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) @@ -114,7 +145,7 @@ def test_compile_copy_nested(mode, dict_type, benchmark): # Tests runtime of a simple arithmetic op over a flat tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) -@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +@pytest.mark.parametrize("dict_type", ["tensordict", 
"tensorclass", "dict"]) def test_compile_add_one_flat(mode, dict_type, benchmark): if dict_type == "tensordict": if mode == "compile": @@ -122,6 +153,12 @@ def test_compile_add_one_flat(mode, dict_type, benchmark): else: func = add_one td = get_flat_td() + elif dict_type == "tensorclass": + if mode == "compile": + func = torch.compile(add_one, fullgraph=True) + else: + func = add_one + td = get_flat_tc() else: if mode == "compile": func = torch.compile(add_one_pytree, fullgraph=True) @@ -132,6 +169,31 @@ def test_compile_add_one_flat(mode, dict_type, benchmark): benchmark(func, td) +@pytest.mark.parametrize("mode", ["eager", "compile"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "dict"]) +def test_compile_add_self_flat(mode, dict_type, benchmark): + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(add_self, fullgraph=True) + else: + func = add_self + td = get_flat_td() + elif dict_type == "tensorclass": + if mode == "compile": + func = torch.compile(add_self, fullgraph=True) + else: + func = add_self + td = get_flat_tc() + else: + if mode == "compile": + func = torch.compile(add_self_pytree, fullgraph=True) + else: + func = add_self_pytree + td = get_flat_td().to_dict() + func(td) + benchmark(func, td) + + # Tests the speed of copying a flat tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) @@ -142,6 +204,12 @@ def test_compile_copy_flat(mode, dict_type, benchmark): else: func = copy td = get_flat_td() + elif dict_type == "tensorclass": + if mode == "compile": + func = torch.compile(copy, fullgraph=True) + else: + func = copy + td = get_flat_tc() else: if mode == "compile": func = torch.compile(copy_pytree, fullgraph=True) @@ -193,7 +261,7 @@ def test_compile_assign_and_add_stack(mode, benchmark): # Tests indexing speed @pytest.mark.parametrize("mode", ["compile", "eager"]) -@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "dict"]) @pytest.mark.parametrize("index_type", ["tensor", "slice", "int"]) def test_compile_indexing(mode, dict_type, index_type, benchmark): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 4e833bb91..d1ae72d37 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -92,8 +92,15 @@ def __subclasscheck__(self, subclass): } # Methods to be executed from tensordict, any ref to self means 'tensorclass' _METHOD_FROM_TD = [ - "gather", "replace", + "gather", +] +# Methods to apply to the nested TD without wrapping the output +_NOWRAP_OUTPUT_METHOD_FROM_TD = [ + "_values_list", + "keys", + "items", + "values", ] # Methods to be executed from tensordict, any ref to self means 'self._tensordict' _FALLBACK_METHOD_FROM_TD = [ @@ -109,68 +116,109 @@ def __subclasscheck__(self, subclass): "__sub__", "__truediv__", "_add_batch_dim", - "apply", "_apply_nest", - "_fast_apply", - "apply_", - "named_apply", "_check_unlock", - "unsqueeze", - "squeeze", "_erase_names", # TODO: must be specialized "_exclude", # TODO: must be specialized + "_fast_apply", "_get_str", "_get_tuple", - "_set_at_tuple", "_has_names", "_propagate_lock", "_propagate_unlock", "_remove_batch_dim", - "is_memmap", - "is_shared", "_select", # TODO: must be specialized + "_set_at_tuple", "_set_str", "_set_tuple", + "abs", + "abs_", + "acos", + "acos_", + "add", + "add_", + "addcdiv", + "addcdiv_", 
+ "addcmul", + "addcmul_", "all", "any", + "apply", + "apply_", + "asin", + "asin_", + "atan", + "atan_", + "ceil", + "ceil_", + "clamp_max", + "clamp_max_", + "clamp_min", + "clamp_min_", + "copy_", + "cos", + "cos_", + "cosh", + "cosh_", + "div", + "div_", "empty", + "erf", + "erf_", + "erfc", + "erfc_", "exclude", + "exp", + "exp_", "expand", "expand_as", + "expm1", + "expm1_", + "flatten", + "floor", + "floor_", + "frac", + "frac_", "is_empty", + "is_memmap", "is_shared", - "items", - "keys", + "is_shared", + "lerp", + "lerp_", + "lgamma", + "lgamma_", "lock_", + "log", + "log10", + "log10_", + "log1p", + "log1p_", + "log2", + "log2_", + "log_", "masked_fill", "masked_fill_", - "permute", - "flatten", - "unflatten", - "ndimension", - "rename_", # TODO: must be specialized - "reshape", - "select", - "to", - "transpose", - "unlock_", - "values", - "view", - "zero_", - "add", - "add_", + "maximum", + "maximum_", + "minimum", + "minimum_", "mul", "mul_", - "abs", - "abs_", - "acos", - "acos_", - "exp", - "exp_", + "named_apply", + "ndimension", "neg", "neg_", + "norm", + "permute", + "pow", + "pow_", "reciprocal", "reciprocal_", + "rename_", # TODO: must be specialized + "reshape", + "round", + "round_", + "select", "sigmoid", "sigmoid_", "sign", @@ -179,67 +227,24 @@ def __subclasscheck__(self, subclass): "sin_", "sinh", "sinh_", + "sqrt", + "sqrt_", + "squeeze", + "sub", + "sub_", "tan", "tan_", "tanh", "tanh_", + "to", + "transpose", "trunc", "trunc_", - "norm", - "lgamma", - "lgamma_", - "frac", - "frac_", - "expm1", - "expm1_", - "log", - "log_", - "log10", - "log10_", - "log1p", - "log1p_", - "log2", - "log2_", - "ceil", - "ceil_", - "floor", - "floor_", - "round", - "round_", - "erf", - "erf_", - "erfc", - "erfc_", - "asin", - "asin_", - "atan", - "atan_", - "cos", - "cos_", - "cosh", - "cosh_", - "lerp", - "lerp_", - "addcdiv", - "addcdiv_", - "addcmul", - "addcmul_", - "sub", - "sub_", - "maximum_", - "maximum", - "minimum_", - "minimum", - "clamp_max_", - "clamp_max", - "clamp_min_", - "clamp_min", - "pow", - "pow_", - "div", - "div_", - "sqrt", - "sqrt_", + "unflatten", + "unlock_", + "unsqueeze", + "view", + "zero_", ] _FALLBACK_METHOD_FROM_TD_COPY = [ "_clone", # TODO: must be specialized @@ -435,6 +440,9 @@ def __torch_function__( for method_name in _FALLBACK_METHOD_FROM_TD: if not hasattr(cls, method_name): setattr(cls, method_name, _wrap_td_method(method_name)) + for method_name in _NOWRAP_OUTPUT_METHOD_FROM_TD: + if not hasattr(cls, method_name): + setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True)) for method_name in _FALLBACK_METHOD_FROM_TD_COPY: if not hasattr(cls, method_name): setattr( @@ -963,7 +971,7 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 return wrapper -def _wrap_td_method(funcname, *, copy_non_tensor=False): +def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False): def wrapped_func(self, *args, **kwargs): if not torch.compiler.is_dynamo_compiling(): td = super(type(self), self).__getattribute__("_tensordict") @@ -972,6 +980,9 @@ def wrapped_func(self, *args, **kwargs): result = getattr(td, funcname)(*args, **kwargs) + if no_wrap: + return result + def check_out(kwargs, result): out = kwargs.get("out") if out is result: diff --git a/tensordict/utils.py b/tensordict/utils.py index f46d2137f..a7a0526ab 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2280,6 +2280,17 @@ def _unravel_key_to_tuple(key): def unravel_key(key): + """Unravel a nested key. 
+ + Examples: + >>> unravel_key("a") + "a" + >>> unravel_key(("a",)) + "a" + >>> unravel_key((("a", ("b",)))) + ("a", "b") + + """ if not torch.compiler.is_dynamo_compiling(): return unravel_key_cpp(key) if isinstance(key, str): @@ -2292,12 +2303,14 @@ def unravel_key(key): def unravel_keys(*keys): + """Unravels a sequence of keys.""" if not torch.compiler.is_dynamo_compiling(): return unravel_keys_cpp(*keys) return tuple(unravel_key(key) for key in keys) def unravel_key_list(keys): + """Unravels a list of keys.""" if not torch.compiler.is_dynamo_compiling(): return unravel_key_list_cpp(keys) return [unravel_key(key) for key in keys] diff --git a/test/test_compile.py b/test/test_compile.py index cc4544b85..25d26d703 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -193,6 +193,25 @@ def add_one(td): assert compiled.a.b == 1 assert add_one_c(data) is not data + def test_tc_arithmetic_other_tc(self): + def add_self(td): + return td + td + + data = MyClass(a=MyClass(a=None, b=torch.ones(()))) + + eager = add_self(data.clone()) + + add_self_c = torch.compile(add_self, fullgraph=True) + compiled = add_self_c(data.clone()) + + assert isinstance(eager.a, MyClass) + assert eager.a.b == 2 + + assert isinstance(compiled.a, MyClass) + # TODO: breaks because a is not cast to a MyClass but is a dict + assert compiled.a.b == 2 + assert add_self_c(data) is not data + @pytest.mark.parametrize("index_type", ["slice", "tensor", "int"]) def test_tc_index(self, index_type): if index_type == "slice": From ba5c247e77d18fb3640184e9449d93e5b3052ffe Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 4 Jun 2024 14:07:02 +0100 Subject: [PATCH 21/28] amend --- .github/workflows/benchmarks.yml | 4 +- .github/workflows/benchmarks_pr.yml | 4 +- tensordict/_td.py | 3 +- tensordict/_torch_func.py | 18 ++++---- tensordict/functional.py | 3 +- tensordict/tensorclass.py | 18 ++++++-- tensordict/utils.py | 69 +++++++++++++++-------------- test/test_compile.py | 4 +- test/test_tensordict.py | 6 ++- 9 files changed, 71 insertions(+), 58 deletions(-) diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index d7a8ab9c9..fa7257807 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -34,7 +34,7 @@ jobs: - name: Run benchmarks run: | cd benchmarks/ - python -m pytest -vvv --rank 0 --benchmark-json output.json + TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python -m pytest -vvv --rank 0 --benchmark-json output.json - name: Store benchmark results uses: benchmark-action/github-action-benchmark@v1 if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }} @@ -114,7 +114,7 @@ jobs: - name: Run benchmarks run: | cd benchmarks/ - python -m pytest -vvv --rank 0 --benchmark-json output.json + TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python -m pytest -vvv --rank 0 --benchmark-json output.json - name: Store benchmark results uses: benchmark-action/github-action-benchmark@v1 if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }} diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml index 5f5cf7e5b..c7eebbea2 100644 --- a/.github/workflows/benchmarks_pr.yml +++ b/.github/workflows/benchmarks_pr.yml @@ -41,7 +41,7 @@ jobs: - name: Run benchmarks run: | cd benchmarks/ - RUN_BENCHMARK="pytest -vvv --rank 0 --benchmark-json " + RUN_BENCHMARK="TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 pytest -vvv --rank 0 --benchmark-json " git checkout ${{ github.event.pull_request.base.sha }} 
$RUN_BENCHMARK ${{ env.BASELINE_JSON }} git checkout ${{ github.event.pull_request.head.sha }} @@ -128,7 +128,7 @@ jobs: - name: Run benchmarks run: | cd benchmarks/ - RUN_BENCHMARK="pytest -vvv --rank 0 --benchmark-json " + RUN_BENCHMARK="TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 pytest -vvv --rank 0 --benchmark-json " git checkout ${{ github.event.pull_request.base.sha }} $RUN_BENCHMARK ${{ env.BASELINE_JSON }} git checkout ${{ github.event.pull_request.head.sha }} diff --git a/tensordict/_td.py b/tensordict/_td.py index e6f40c59c..9bbdd2ccf 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -68,6 +68,7 @@ _set_max_batch_size, _shape, _STRDTYPE2DTYPE, + _StringKeys, _StringOnlyDict, _sub_index, _unravel_key_to_tuple, @@ -2795,7 +2796,7 @@ def keys( is_leaf: Callable[[Type], bool] | None = None, ) -> _TensorDictKeysView: if not include_nested and not leaves_only: - return self._tensordict.keys() + return _StringKeys(self._tensordict.keys()) else: return self._nested_keys( include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 02132d4d2..d63b33303 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -404,16 +404,14 @@ def _stack( ) -> T: if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") - is_tc = False - if any(is_tensorclass(td) for td in list_of_tensordicts): - is_tc = True - if all(is_non_tensor(td) for td in list_of_tensordicts): - from tensordict.tensorclass import NonTensorData - - return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) - else: - tc_type = type(list_of_tensordicts[0]) - list_of_tensordicts = [tc.to_tensordict() for tc in list_of_tensordicts] + is_tc = any(is_tensorclass(td) for td in list_of_tensordicts) + if all(is_non_tensor(td) for td in list_of_tensordicts): + from tensordict.tensorclass import NonTensorData + + return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) + elif is_tc: + tc_type = type(list_of_tensordicts[0]) + list_of_tensordicts = [tc.to_tensordict() for tc in list_of_tensordicts] batch_size = list_of_tensordicts[0].batch_size if dim < 0: diff --git a/tensordict/functional.py b/tensordict/functional.py index dc6dfd4eb..5b7055ed8 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -331,8 +331,7 @@ def dense_stack_tds( shape = list(td_list[0].shape) shape.insert(dim, len(td_list)) - out = td_list[0].unsqueeze(dim).expand(shape).clone() - return torch.stack(td_list, dim=dim, out=out) + return TensorDict.maybe_dense_stack(td_list, dim=dim) def make_tensordict( diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index d1ae72d37..b1f1593db 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -97,9 +97,13 @@ def __subclasscheck__(self, subclass): ] # Methods to apply to the nested TD without wrapping the output _NOWRAP_OUTPUT_METHOD_FROM_TD = [ + "_get_at_str", + "_get_at_tuple", + "_get_str", + "_get_tuple", "_values_list", - "keys", "items", + "keys", "values", ] # Methods to be executed from tensordict, any ref to self means 'self._tensordict' @@ -121,8 +125,7 @@ def __subclasscheck__(self, subclass): "_erase_names", # TODO: must be specialized "_exclude", # TODO: must be specialized "_fast_apply", - "_get_str", - "_get_tuple", + "_get_sub_tensordict", "_has_names", "_propagate_lock", "_propagate_unlock", @@ -155,6 +158,7 @@ def __subclasscheck__(self, subclass): "clamp_max_", "clamp_min", "clamp_min_", + "contiguous", 
"copy_", "cos", "cos_", @@ -244,6 +248,7 @@ def __subclasscheck__(self, subclass): "unlock_", "unsqueeze", "view", + "where", "zero_", ] _FALLBACK_METHOD_FROM_TD_COPY = [ @@ -748,7 +753,12 @@ def wrapper(cls, tensordict, non_tensordict=None): # noqa: D417 ) # Validating non-tensor keys and for key clash - tensor_keys = set(tensordict.keys()) + + # TODO: compile doesn't like set() over an arbitrary object + if torch.compiler.is_dynamo_compiling(): + tensor_keys = {k for k in tensordict.keys()} # C416 + else: + tensor_keys = set(tensordict.keys()) if non_tensordict is not None: for key in list(non_tensordict.keys()): if key not in expected_keys: diff --git a/tensordict/utils.py b/tensordict/utils.py index a7a0526ab..5ec32e5eb 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1113,7 +1113,26 @@ def new_fun(self, *args, **kwargs): class _StringKeys(KeysView): - """A key view where contains is restricted to strings.""" + """A key view where contains is restricted to strings. + + Saving the keys as an attribute is 25% faster than just subclassing KeysView. + + """ + + def __init__(self, keys): + self.keys = keys + + def __getitem__(self, key): + return self.keys.__getitem__(key) + + def __iter__(self): + yield from self.keys + + def __repr__(self): + return f"{type(self)}({self.keys})" + + def __len__(self): + return len(self.keys) def __contains__(self, item): if not isinstance(item, str): @@ -1127,35 +1146,10 @@ def __contains__(self, item): raise TypeError(_NON_STR_KEY_TUPLE_ERR) else: item = unravel_item[0] - return super().__contains__(item) + return self.keys.__contains__(item) _StringOnlyDict = dict -# class _StringOnlyDict(dict): -# """A dict class where contains is restricted to strings.""" -# -# # kept here for debugging -# # def __setitem__(self, key, value): -# # if not isinstance(key, str): -# # raise RuntimeError -# # return super().__setitem__(key, value) -# -# def __contains__(self, item): -# if not isinstance(item, str): -# try: -# unravel_item = _unravel_key_to_tuple(item) -# if not unravel_item: # catch errors during unravel -# raise TypeError -# except Exception: -# raise TypeError(_NON_STR_KEY_ERR) -# if len(unravel_item) > 1: -# raise TypeError(_NON_STR_KEY_TUPLE_ERR) -# else: -# item = unravel_item[0] -# return super().__contains__(item) -# -# def keys(self): -# return _StringKeys(self) def lock_blocked(func): @@ -1531,13 +1525,16 @@ def _check_keys( if not len(list_of_tensordicts): return set() - keys: set[str] = set( - list_of_tensordicts[0].keys( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=_is_leaf_nontensor, - ) + keys = list_of_tensordicts[0].keys( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=_is_leaf_nontensor, ) + # TODO: compile doesn't like set() over an arbitrary object + if torch.compiler.is_dynamo_compiling(): + keys = {k for k in keys} # C416 + else: + keys: set[str] = set(keys) for td in list_of_tensordicts[1:]: k = td.keys( include_nested=include_nested, @@ -1547,7 +1544,11 @@ def _check_keys( if not strict: keys = keys.intersection(k) else: - if set(k) != keys: + if torch.compiler.is_dynamo_compiling(): + k = {v for v in k} # C416 + else: + k = set(k) + if k != keys: raise KeyError( f"got keys {keys} and {set(td.keys())} which are incompatible" ) diff --git a/test/test_compile.py b/test/test_compile.py index 25d26d703..61462aca8 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -65,8 +65,8 @@ def add_one(td): def test_stack(self): def stack_tds(td0, td1): - # return 
TensorDict.stack([td0, td1]) - return torch.stack([td0, td1]) + return TensorDict.stack([td0, td1]) + # return torch.stack([td0, td1]) stack_tds_c = torch.compile(stack_tds, fullgraph=True) data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index da03791f8..46f2b0167 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -8977,7 +8977,11 @@ def test_memmap_stack(self, tmpdir, json_serializable, device): def test_memmap_stack_updates(self, tmpdir): data = torch.stack([NonTensorData(data=0), NonTensorData(data=1)], 0) - data = torch.stack([data] * 3).clone() + assert is_non_tensor(data) + data = torch.stack([data] * 3) + assert is_non_tensor(data) + data = data.clone() + assert is_non_tensor(data) data.memmap_(tmpdir) data_recon = TensorDict.load_memmap(tmpdir) assert data.tolist() == data_recon.tolist() From 181da2a542e08e1da93b34d36e82b54cbb03321d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 4 Jun 2024 14:34:50 +0100 Subject: [PATCH 22/28] amend --- .github/unittest/linux/scripts/run_test.sh | 1 + .../unittest/linux_torchrec/scripts/run_test.sh | 1 + .../unittest/rl_linux_optdeps/scripts/run_test.sh | 1 + benchmarks/compile/compile_td_test.py | 15 ++++++++------- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/unittest/linux/scripts/run_test.sh b/.github/unittest/linux/scripts/run_test.sh index 9fc821efb..178775afd 100755 --- a/.github/unittest/linux/scripts/run_test.sh +++ b/.github/unittest/linux/scripts/run_test.sh @@ -17,6 +17,7 @@ lib_dir="${env_dir}/lib" # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU +export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 coverage run -m pytest test/smoke_test.py -v --durations 20 coverage run -m pytest --instafail -v --durations 20 diff --git a/.github/unittest/linux_torchrec/scripts/run_test.sh b/.github/unittest/linux_torchrec/scripts/run_test.sh index 9fc821efb..178775afd 100755 --- a/.github/unittest/linux_torchrec/scripts/run_test.sh +++ b/.github/unittest/linux_torchrec/scripts/run_test.sh @@ -17,6 +17,7 @@ lib_dir="${env_dir}/lib" # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU +export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 coverage run -m pytest test/smoke_test.py -v --durations 20 coverage run -m pytest --instafail -v --durations 20 diff --git a/.github/unittest/rl_linux_optdeps/scripts/run_test.sh b/.github/unittest/rl_linux_optdeps/scripts/run_test.sh index 1f2fc9ae6..04d81bb06 100755 --- a/.github/unittest/rl_linux_optdeps/scripts/run_test.sh +++ b/.github/unittest/rl_linux_optdeps/scripts/run_test.sh @@ -16,6 +16,7 @@ git config --global --add safe.directory '*' root_dir="$(git rev-parse --show-toplevel)" export MKL_THREADING_LAYER=GNU export CKPT_BACKEND=torch +export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 #MUJOCO_GL=glfw pytest --cov=torchrl --junitxml=test-results/junit.xml -v --durations 20 MUJOCO_GL=egl python -m pytest rl/test --instafail -v --durations 20 --ignore rl/test/test_distributed.py diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py index f3c9eff9d..e9359c859 100644 --- a/benchmarks/compile/compile_td_test.py +++ b/benchmarks/compile/compile_td_test.py @@ -105,7 +105,7 @@ def get_flat_tc(): # Tests runtime of a simple arithmetic op over a highly nested 
tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) -@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) def test_compile_add_one_nested(mode, dict_type, benchmark): if dict_type == "tensordict": if mode == "compile": @@ -125,7 +125,7 @@ def test_compile_add_one_nested(mode, dict_type, benchmark): # Tests the speed of copying a nested tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) -@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) def test_compile_copy_nested(mode, dict_type, benchmark): if dict_type == "tensordict": if mode == "compile": @@ -145,7 +145,7 @@ def test_compile_copy_nested(mode, dict_type, benchmark): # Tests runtime of a simple arithmetic op over a flat tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) -@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "dict"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"]) def test_compile_add_one_flat(mode, dict_type, benchmark): if dict_type == "tensordict": if mode == "compile": @@ -170,7 +170,7 @@ def test_compile_add_one_flat(mode, dict_type, benchmark): @pytest.mark.parametrize("mode", ["eager", "compile"]) -@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "dict"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"]) def test_compile_add_self_flat(mode, dict_type, benchmark): if dict_type == "tensordict": if mode == "compile": @@ -196,7 +196,7 @@ def test_compile_add_self_flat(mode, dict_type, benchmark): # Tests the speed of copying a flat tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) -@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) def test_compile_copy_flat(mode, dict_type, benchmark): if dict_type == "tensordict": if mode == "compile": @@ -222,7 +222,7 @@ def test_compile_copy_flat(mode, dict_type, benchmark): # Tests the speed of assigning entries to an empty tensordict @pytest.mark.parametrize("mode", ["compile", "eager"]) -@pytest.mark.parametrize("dict_type", ["tensordict", "dict"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) def test_compile_assign_and_add(mode, dict_type, benchmark): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") td = TensorDict(device=device) @@ -261,7 +261,7 @@ def test_compile_assign_and_add_stack(mode, benchmark): # Tests indexing speed @pytest.mark.parametrize("mode", ["compile", "eager"]) -@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "dict"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"]) @pytest.mark.parametrize("index_type", ["tensor", "slice", "int"]) def test_compile_indexing(mode, dict_type, index_type, benchmark): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -287,6 +287,7 @@ def test_compile_indexing(mode, dict_type, index_type, benchmark): idx = slice(None, None, 2) if index_type == "tensor": idx = torch.tensor(range(*idx.indices(10))) + func(td, idx) benchmark(func, td, idx) From 3b74b7ec49ef6c1ff7a3b21253cd7dbf9bfa3d4a Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 4 Jun 2024 10:55:48 -0700 Subject: [PATCH 23/28] amend --- tensordict/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensordict/base.py b/tensordict/base.py index 
656feee28..dd2f6eb83 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -7465,7 +7465,8 @@ def to(self, *args, **kwargs) -> T: def _sync_all(self): if _has_cuda: - if torch.cuda.is_initialized(): + # TODO: dynamo doesn't like torch.cuda.is_initialized + if torch.compiler.is_dynamo_compiling() or torch.cuda.is_initialized(): torch.cuda.synchronize() elif _has_mps: torch.mps.synchronize() From 4cda3a932f6db7a97c60328d75e33803f6e877e6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 19 Jun 2024 15:18:44 +0100 Subject: [PATCH 24/28] amend --- tensordict/nn/utils.py | 9 ++++++- tensordict/tensorclass.py | 57 +++++++++++++++++++-------------------- tensordict/utils.py | 4 +-- 3 files changed, 37 insertions(+), 33 deletions(-) diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index ea9df8f25..cb485ff97 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -263,6 +263,8 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any: return super().__call__(wrapper) def __enter__(self) -> None: + if self.mode and torch.compiler.is_dynamo_compiling(): + raise RuntimeError("skip_existing is not compatible with TorchDynamo.") global _SKIP_EXISTING self.prev = _SKIP_EXISTING if self.mode is not None: @@ -285,7 +287,6 @@ class _set_skip_existing_None(set_skip_existing): """ def __call__(self, func: Callable): - self._called = True # sanity check @@ -302,6 +303,10 @@ def __call__(self, func: Callable): @functools.wraps(func) def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any: + if skip_existing() and torch.compiler.is_dynamo_compiling(): + raise RuntimeError( + "skip_existing is not compatible with torch.compile." + ) in_keys = getattr(_self, self.in_key_attr) out_keys = getattr(_self, self.out_key_attr) # we use skip_existing to allow users to override the mode internally @@ -311,6 +316,8 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any: and not any(key in out_keys for key in in_keys) ): return tensordict + if torch.compiler.is_dynamo_compiling(): + return func(_self, tensordict, *args, **kwargs) global _SKIP_EXISTING self.prev = _SKIP_EXISTING try: diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 048508076..543fc49b1 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -100,14 +100,20 @@ def __subclasscheck__(self, subclass): "_get_at_str", "_get_at_tuple", "_get_str", - "_get_sub_tensordict", + # "_get_sub_tensordict", + "ndimension", "_get_tuple", "_values_list", "is_memmap", + "_propagate_lock", + "_propagate_unlock", "is_shared", + "is_empty", "items", "keys", "ndimension", + "_has_names", + "_check_unlock", "numel", "values", ] @@ -127,23 +133,13 @@ def __subclasscheck__(self, subclass): "__truediv__", "_add_batch_dim", "_apply_nest", - "_check_unlock", "_erase_names", # TODO: must be specialized "_exclude", # TODO: must be specialized "_fast_apply", - "_fast_apply", "_get_sub_tensordict", - "_has_names", - "_propagate_lock", - "_propagate_unlock", "_remove_batch_dim", "_select", # TODO: must be specialized - # "_set_at_str", - "_set_at_tuple", "_set_at_tuple", - # _set_str needs a special treatment to catch keys that are already in - # non tensor data - # "_set_at_tuple", "_set_str", "_set_tuple", "abs", @@ -197,15 +193,6 @@ def __subclasscheck__(self, subclass): "floor_", "frac", "frac_", - "is_empty", - "is_memmap", - "is_shared", - "is_shared", - "is_shared", - "lerp", - "lerp_", - "lgamma", - "lgamma_", "lerp", "lerp_", "lgamma", @@ -228,7 +215,6 @@ def 
__subclasscheck__(self, subclass): "mul", "mul_", "named_apply", - "ndimension", "neg", "neg_", "norm", @@ -266,11 +252,19 @@ def __subclasscheck__(self, subclass): "unflatten", "unlock_", "unsqueeze", - "values", "view", "where", "zero_", + # "_set_at_str", + # "_set_at_tuple", + # _set_str needs a special treatment to catch keys that are already in + # non tensor data ] +assert not any( + v in _NOWRAP_OUTPUT_METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD +), set(_NOWRAP_OUTPUT_METHOD_FROM_TD).intersection(_FALLBACK_METHOD_FROM_TD) +assert len(set(_FALLBACK_METHOD_FROM_TD)) == len(_FALLBACK_METHOD_FROM_TD) + _FALLBACK_METHOD_FROM_TD_COPY = [ "_clone", # TODO: must be specialized "clone", # TODO: must be specialized @@ -780,7 +774,7 @@ def wrapper(cls, tensordict, non_tensordict=None): # noqa: D417 # TODO: compile doesn't like set() over an arbitrary object if torch.compiler.is_dynamo_compiling(): - tensor_keys = {k for k in tensordict.keys()} # C416 + tensor_keys = {k for k in tensordict.keys()} # noqa: C416 else: tensor_keys = set(tensordict.keys()) if non_tensordict is not None: @@ -985,16 +979,18 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 "_tensordict" not in __dict__ or "_non_tensordict" not in __dict__ or key in SET_ATTRIBUTES - or key in self.__class__.__dict__ - # if we ever decide to allow anything to be written in a tc - # or key not in self.__dataclass_fields__): + or key in self.__class__.__dict__ + ): + # if we ever decide to allow anything to be written in a tc + # or key not in self.__dataclass_fields__): return setattr_(self, key, value) else: # Pass? if key in SET_ATTRIBUTES: # assert getattr(self, "_is_initialized", False) return setattr_(self, key, value) - # if not hasattr(self, "_tensordict") or not hasattr(self, "_non_tensordict") or key in SET_ATTRIBUTES: + # TODO: compile doesn't support property checks + # if type(self).__dict__.get(key) is not None: # return setattr_(self, key, value) out = self.set(key, value) @@ -1026,9 +1022,10 @@ def check_out(kwargs, result): return True return False - if isinstance(result, TensorDictBase) and not check_out(kwargs, result): - if result is td: - return self + if result is td: + return self + # TODO: this subclass check is just to make dynamo happy + if issubclass(type(result), TensorDictBase) and not check_out(kwargs, result): if not torch.compiler.is_dynamo_compiling(): nontd = super(type(self), self).__getattribute__("_non_tensordict") else: diff --git a/tensordict/utils.py b/tensordict/utils.py index 03951ded5..f90de7a6c 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1547,7 +1547,7 @@ def _check_keys( ) # TODO: compile doesn't like set() over an arbitrary object if torch.compiler.is_dynamo_compiling(): - keys = {k for k in keys} # C416 + keys = {k for k in keys} # noqa: C416 else: keys: set[str] = set(keys) for td in list_of_tensordicts[1:]: @@ -1560,7 +1560,7 @@ def _check_keys( keys = keys.intersection(k) else: if torch.compiler.is_dynamo_compiling(): - k = {v for v in k} # C416 + k = {v for v in k} # noqa: C416 else: k = set(k) if k != keys: From b5cb40aaeb824ac632c0a2f178430105513525d1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 27 Jun 2024 15:06:22 +0100 Subject: [PATCH 25/28] amend --- tensordict/_td.py | 37 +++++++++---------------------------- tensordict/base.py | 3 +-- tensordict/tensorclass.py | 34 +++++++++++++++------------------- test/test_compile.py | 22 ++++++++++++++++++++-- 4 files changed, 45 insertions(+), 51 deletions(-) diff --git a/tensordict/_td.py 
b/tensordict/_td.py index ea66f6d48..e5df6d0b1 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -428,13 +428,16 @@ def convert_type(x, y): inplace = bool(inplace) # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can - if not is_dynamo and module.__class__.__setattr__ is __base__setattr__: + if module.__class__.__setattr__ is __base__setattr__: __dict__ = module.__dict__ else: __dict__ = None for key, value in input.items(): if isinstance(value, (Tensor, ftdim.Tensor)): + # For Dynamo, we use regular set/delattr as we're not + # much afraid by overhead (and dynamo doesn't like those + # hacks we're doing). if __dict__ is not None: # if setattr is the native nn.Module.setattr, we can rely on _set_tensor_dict local_out = _set_tensor_dict( @@ -460,7 +463,8 @@ def convert_type(x, y): if __dict__ is not None: child = __dict__["_modules"][key] else: - child = getattr(module, key) + child = module._modules.get(key) + if not is_dynamo: local_out = memo.get(weakref.ref(child), NO_DEFAULT) @@ -2662,6 +2666,7 @@ def _clone(self, recurse: bool = True) -> T: source={key: _clone_value(value, recurse) for key, value in self.items()}, batch_size=self.batch_size, device=self.device, + # TODO: Dynamo doesn't like `copy` names=list(self._td_dim_names) if self._has_names() else None, _run_checks=False, ) @@ -2926,7 +2931,8 @@ def values( if not include_nested and not leaves_only: return self._tensordict.values() else: - return super().values( + return TensorDictBase.values( + self, include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, @@ -3845,33 +3851,18 @@ def __init__( is_leaf = _default_is_leaf self.is_leaf = is_leaf - # self.items = [] - # self._iter() - # - # def __getitem__(self, i): - # return self.items[i] - def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]: - # yield from self.items - # - # def _iter(self): if not self.include_nested: if self.leaves_only: for key in self._keys(): target_class = self.tensordict.entry_class(key) if _is_tensor_collection(target_class): continue - # self.items.append(key) yield key else: yield from self._keys() - # self.items += self._keys() else: yield from ( - # key if len(key) > 1 else key[0] - # for key in self._iter_helper(self.tensordict) - # ) - # self.items += ( key if len(key) > 1 else key[0] for key in self._iter_helper(self.tensordict) ) @@ -4187,13 +4178,3 @@ def _update_metadata(*, metadata, key, value, is_collection): metadata[key] = { "type": type(value).__name__, } - - -# @torch._dynamo.on_compile_start -# def _check_inline_inbuilt_module(): -# try: -# assert str2bool(environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"]) -# except (AssertionError, KeyError): -# raise RuntimeError( -# "set TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 to use functional calls with tensordict." 
-# ) diff --git a/tensordict/base.py b/tensordict/base.py index 189f6df1d..b9076ea3b 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6036,8 +6036,7 @@ def cosh_(self) -> T: return self def add(self, other: TensorDictBase | float, alpha: float | None = None): - keys_vals = self._items_list(True, True) - keys, vals = keys_vals + keys, vals = self._items_list(True, True) if _is_tensor_collection(type(other)): other_val = other._values_list(True, True) else: diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index df47c073a..4f41ad857 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -92,26 +92,28 @@ def __subclasscheck__(self, subclass): } # Methods to be executed from tensordict, any ref to self means 'tensorclass' _METHOD_FROM_TD = [ + "_check_unlock", "_default_get", "_get_at_str", "_get_at_tuple", "_get_str", - # "_get_sub_tensordict", - "ndimension", "_get_tuple", - "_values_list", - "is_memmap", + "_has_names", "_propagate_lock", "_propagate_unlock", - "is_shared", + "_values_list", + "gather", "is_empty", + "is_memmap", + "is_shared", "items", "keys", "ndimension", - "_has_names", - "_check_unlock", + "ndimension", "numel", + "replace", "values", + # "_get_sub_tensordict", ] # Methods to be executed from tensordict, any ref to self means 'self._tensordict' @@ -189,15 +191,9 @@ def __subclasscheck__(self, subclass): "floor_", "frac", "frac_", - "is_empty", - "is_memmap", - "is_shared", - "is_shared", "isfinite", "isnan", "isreal", - "items", - "keys", "lerp", "lerp_", "lgamma", @@ -265,9 +261,9 @@ def __subclasscheck__(self, subclass): # _set_str needs a special treatment to catch keys that are already in # non tensor data ] -assert not any( - v in _NOWRAP_OUTPUT_METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD -), set(_NOWRAP_OUTPUT_METHOD_FROM_TD).intersection(_FALLBACK_METHOD_FROM_TD) +assert not any(v in _METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD), set( + _METHOD_FROM_TD +).intersection(_FALLBACK_METHOD_FROM_TD) assert len(set(_FALLBACK_METHOD_FROM_TD)) == len(_FALLBACK_METHOD_FROM_TD) _FALLBACK_METHOD_FROM_TD_COPY = [ @@ -468,9 +464,9 @@ def __torch_function__( for method_name in _FALLBACK_METHOD_FROM_TD: if not hasattr(cls, method_name): setattr(cls, method_name, _wrap_td_method(method_name)) - for method_name in _NOWRAP_OUTPUT_METHOD_FROM_TD: - if not hasattr(cls, method_name): - setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True)) + # for method_name in _NOWRAP_OUTPUT_METHOD_FROM_TD: + # if not hasattr(cls, method_name): + # setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True)) for method_name in _FALLBACK_METHOD_FROM_TD_COPY: if not hasattr(cls, method_name): setattr( diff --git a/test/test_compile.py b/test/test_compile.py index 61462aca8..8ed2197b5 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -373,9 +373,9 @@ def remove_hidden(td): assert module_compile(td) is not td -class TestFunc: +class TestFunctional: @pytest.mark.parametrize("modif_param", [False, True]) - def test_func(self, modif_param): + def test_functional(self, modif_param): class MessUpParams(torch.nn.Module): def __init__(self): super().__init__() @@ -393,6 +393,7 @@ def forward(self, x): if modif_param: module.append(MessUpParams()) + orig_params = list(module.parameters()) td = TensorDict.from_module(module) td_zero = TensorDictParams(td.data.clone()) td_zero.zero_() @@ -402,6 +403,10 @@ def call(x, td): # with td.to_module(module): # return module(x) + # sd = module.state_dict() + # 
module.load_state_dict(td.flatten_keys().to_dict()) + # result = module(x) + # module.load_state_dict(sd) params = td.to_module(module, return_swap=True) result = module(x) params.to_module(module, return_swap=True, swap_dest=td) @@ -410,14 +415,27 @@ def call(x, td): call_compile = torch.compile(call, fullgraph=True) # , backend="eager") x = torch.randn(2, 3) assert (call(x, td_zero) == 0).all() + assert ( + p_new is p_orig for p_new, p_orig in zip(module.parameters(), orig_params) + ) assert (call(x, td_zero) == 0).all() + assert ( + p_new is p_orig for p_new, p_orig in zip(module.parameters(), orig_params) + ) if modif_param: assert td_zero["3", "param"] == 2 else: assert (td_zero == 0).all() # torch.testing.assert_close(call_compile(x, td_zero), module(x)) + call_compile(x, td_zero) assert (call_compile(x, td_zero) == 0).all() + assert ( + p_new is p_orig for p_new, p_orig in zip(module.parameters(), orig_params) + ) assert (call_compile(x, td_zero) == 0).all() + assert ( + p_new is p_orig for p_new, p_orig in zip(module.parameters(), orig_params) + ) if modif_param: assert td_zero["3", "param"] == 4 else: From 7252fcfa4719c1109187187b2290f71418b57ff1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 27 Jun 2024 16:29:04 +0100 Subject: [PATCH 26/28] amend --- tensordict/_td.py | 3 +-- tensordict/tensorclass.py | 10 +++------- test/test_compile.py | 30 +++++++++++++++++------------- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index e5df6d0b1..4cb89dc29 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2666,8 +2666,7 @@ def _clone(self, recurse: bool = True) -> T: source={key: _clone_value(value, recurse) for key, value in self.items()}, batch_size=self.batch_size, device=self.device, - # TODO: Dynamo doesn't like `copy` - names=list(self._td_dim_names) if self._has_names() else None, + names=copy(self._td_dim_names) if self._has_names() else None, _run_checks=False, ) # If this is uncommented, a shallow copy of a shared/memmap will be shared and locked too diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 4f41ad857..55842e2ca 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -937,11 +937,8 @@ def _setstate(self, state: dict[str, Any]) -> None: # noqa: D417 def _getattr(self, item: str) -> Any: # if not item.startswith("__"): - if not torch.compiler.is_dynamo_compiling(): - __dict__ = self.__dict__ - _non_tensordict = __dict__.get("_non_tensordict") - else: - _non_tensordict = self._non_tensordict + __dict__ = self.__dict__ + _non_tensordict = __dict__.get("_non_tensordict") if _non_tensordict is not None: out = _non_tensordict.get(item, NO_DEFAULT) @@ -1039,8 +1036,7 @@ def check_out(kwargs, result): if result is td: return self - # TODO: this subclass check is just to make dynamo happy - if issubclass(type(result), TensorDictBase) and not check_out(kwargs, result): + if isinstance(result, TensorDictBase) and not check_out(kwargs, result): if not torch.compiler.is_dynamo_compiling(): nontd = super(type(self), self).__getattribute__("_non_tensordict") else: diff --git a/test/test_compile.py b/test/test_compile.py index 8ed2197b5..8a4bac77e 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -376,6 +376,8 @@ def remove_hidden(td): class TestFunctional: @pytest.mark.parametrize("modif_param", [False, True]) def test_functional(self, modif_param): + + # TODO: UNTESTED class MessUpParams(torch.nn.Module): def __init__(self): super().__init__() @@ -383,7 +385,7 
@@ def __init__(self): def forward(self, x): self.param.data.add_(1) - return x + return x * 1 module = torch.nn.Sequential( torch.nn.Linear(3, 4), @@ -403,10 +405,6 @@ def call(x, td): # with td.to_module(module): # return module(x) - # sd = module.state_dict() - # module.load_state_dict(td.flatten_keys().to_dict()) - # result = module(x) - # module.load_state_dict(sd) params = td.to_module(module, return_swap=True) result = module(x) params.to_module(module, return_swap=True, swap_dest=td) @@ -415,26 +413,32 @@ def call(x, td): call_compile = torch.compile(call, fullgraph=True) # , backend="eager") x = torch.randn(2, 3) assert (call(x, td_zero) == 0).all() - assert ( - p_new is p_orig for p_new, p_orig in zip(module.parameters(), orig_params) + assert all( + p_new is p_orig + for p_new, p_orig in zip(module.parameters(), orig_params, strict=True) ) assert (call(x, td_zero) == 0).all() - assert ( - p_new is p_orig for p_new, p_orig in zip(module.parameters(), orig_params) + assert all( + p_new is p_orig + for p_new, p_orig in zip(module.parameters(), orig_params, strict=True) ) if modif_param: assert td_zero["3", "param"] == 2 else: assert (td_zero == 0).all() # torch.testing.assert_close(call_compile(x, td_zero), module(x)) + + td.to_module(module) call_compile(x, td_zero) assert (call_compile(x, td_zero) == 0).all() - assert ( - p_new is p_orig for p_new, p_orig in zip(module.parameters(), orig_params) + assert all( + p_new is p_orig + for p_new, p_orig in zip(module.parameters(), orig_params, strict=True) ) assert (call_compile(x, td_zero) == 0).all() - assert ( - p_new is p_orig for p_new, p_orig in zip(module.parameters(), orig_params) + assert all( + p_new is p_orig + for p_new, p_orig in zip(module.parameters(), orig_params, strict=True) ) if modif_param: assert td_zero["3", "param"] == 4 From dd3560dd9c09f5e77ea6849f3ab8ad3578dd9087 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 29 Jun 2024 12:51:37 +0100 Subject: [PATCH 27/28] amend --- tensordict/tensorclass.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 55842e2ca..ef127f86e 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -92,12 +92,15 @@ def __subclasscheck__(self, subclass): } # Methods to be executed from tensordict, any ref to self means 'tensorclass' _METHOD_FROM_TD = [ - "_check_unlock", +] +# Methods to be executed from tensordict, any ref to self means 'self._tensordict' +_FALLBACK_METHOD_FROM_TD_NOWRAP = [ "_default_get", "_get_at_str", "_get_at_tuple", "_get_str", "_get_tuple", + "_check_unlock", "_has_names", "_propagate_lock", "_propagate_unlock", @@ -464,9 +467,9 @@ def __torch_function__( for method_name in _FALLBACK_METHOD_FROM_TD: if not hasattr(cls, method_name): setattr(cls, method_name, _wrap_td_method(method_name)) - # for method_name in _NOWRAP_OUTPUT_METHOD_FROM_TD: - # if not hasattr(cls, method_name): - # setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True)) + for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP: + if not hasattr(cls, method_name): + setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True)) for method_name in _FALLBACK_METHOD_FROM_TD_COPY: if not hasattr(cls, method_name): setattr( From 8fad24d32fff13efd6343deb9bccd99634b7919f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 2 Jul 2024 11:52:17 +0100 Subject: [PATCH 28/28] amend --- tensordict/_td.py | 6 ++++-- tensordict/tensorclass.py | 14 +++++++------- 2 files changed, 11 
insertions(+), 9 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 453acf705..4ed10212c 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -252,8 +252,8 @@ def __init__( ) self._batch_size = self._parse_batch_size(source, batch_size) # TODO: this breaks when stacking tensorclasses with dynamo - if not torch.compiler.is_dynamo_compiling(): - self.names = names + if not torch.compiler.is_dynamo_compiling(): + self.names = names for key, value in source.items(): self.set(key, value, non_blocking=sub_non_blocking) @@ -273,6 +273,8 @@ def _new_unsafe( lock: bool = False, nested: bool = True, ) -> TensorDict: + if torch.compiler.is_dynamo_compiling(): + return TensorDict(source, batch_size=batch_size, device=device, names=names, non_blocking=non_blocking, lock=lock) self = cls.__new__(cls) sub_non_blocking = False if device is not None: diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 854613974..beafc5a8e 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -91,8 +91,7 @@ def __subclasscheck__(self, subclass): torch.gather: True, } # Methods to be executed from tensordict, any ref to self means 'tensorclass' -_METHOD_FROM_TD = [ -] +_METHOD_FROM_TD = [] # Methods to be executed from tensordict, any ref to self means 'self._tensordict' _FALLBACK_METHOD_FROM_TD_NOWRAP = [ "_check_unlock", @@ -633,11 +632,12 @@ def wrapper( super(type(self), self).__setattr__( "_tensordict", TensorDict._new_unsafe( - {}, - batch_size=torch.Size(batch_size), - device=device, - names=names, - )) + {}, + batch_size=torch.Size(batch_size), + device=device, + names=names, + ), + ) super(type(self), self).__setattr__("_non_tensordict", {}) super(type(self), self).__setattr__("_is_initialized", True)
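The recurring pattern across these patches is to keep the fast compiled helpers on the eager path and fall back to plain Python whenever torch.compiler.is_dynamo_compiling() returns True, since Dynamo cannot trace through the C++ extension. Below is a minimal, self-contained sketch of what that Python fallback computes, mirroring the functions added to tensordict/utils.py; the Dynamo gate and the dispatch to the C++ helpers are omitted here so the sketch runs standalone and should not be read as the library code itself.

# Sketch of the pure-Python key-unravelling fallback (assumption: standalone
# illustration; the real functions first defer to the compiled C++ helpers
# when not tracing under Dynamo).
def _unravel_key_to_tuple(key):
    # Flatten an arbitrarily nested key into a flat tuple of strings.
    if isinstance(key, str):
        return (key,)
    if not isinstance(key, tuple):
        return ()
    return tuple(subk for k in key for subk in _unravel_key_to_tuple(k))

def unravel_key(key):
    # Strings pass through, one-element tuples collapse to their member,
    # longer tuples are unravelled element-wise.
    if isinstance(key, str):
        return key
    if isinstance(key, tuple):
        if len(key) == 1:
            return unravel_key(key[0])
        return tuple(unravel_key(_key) for _key in key)
    raise ValueError("the key must be a str or a tuple of str")

# Quick check of the behaviour documented in the new docstrings:
assert _unravel_key_to_tuple(("a", ("b", ("c",)))) == ("a", "b", "c")
assert unravel_key(("a",)) == "a"
assert unravel_key((("a", ("b",)))) == ("a", "b")

The same gating pattern appears elsewhere in the series, e.g. in _check_keys (avoiding set() over a keys view while tracing), in _sync_all (avoiding torch.cuda.is_initialized), and in the tensorclass method wrappers, which bypass the super().__getattribute__ indirection when Dynamo is compiling.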