From b46f79b86fd0d154afbcc4e2908eeee30cbaea5f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 15 Feb 2024 17:16:39 +0000 Subject: [PATCH 01/32] [Feature] `from_modules` method for MOE / ensemble learning (#677) --- tensordict/base.py | 136 +++++++++++++++++++++++++++++++++++++++- test/test_tensordict.py | 37 +++++++++++ 2 files changed, 172 insertions(+), 1 deletion(-) diff --git a/tensordict/base.py b/tensordict/base.py index 51c734dd1..a0fcf9637 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -52,6 +52,7 @@ _td_fields, _unravel_key_to_tuple, as_decorator, + Buffer, cache, convert_ellipsis_to_idx, DeviceType, @@ -64,6 +65,7 @@ lock_blocked, NestedKey, prod, + set_lazy_legacy, TensorDictFuture, unravel_key, unravel_key_list, @@ -376,6 +378,7 @@ def from_module( """Copies the params and buffers of a module in a tensordict. Args: + module (nn.Module): the module to get the parameters from. as_module (bool, optional): if ``True``, a :class:`~tensordict.nn.TensorDictParams` instance will be returned which can be used to store parameters within a :class:`torch.nn.Module`. Defaults to ``False``. @@ -385,7 +388,7 @@ def from_module( module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to ``False``. .. note:: - This is particularily useful when state-dict hooks have to be + This is particularly useful when state-dict hooks have to be used. Examples: @@ -405,6 +408,137 @@ def from_module( """ ... + @classmethod + def from_modules( + cls, + *modules, + as_module: bool = False, + lock: bool = True, + use_state_dict: bool = False, + lazy_stack: bool = False, + ): + """Retrieves the parameters of several modules for ensebmle learning/feature of expects applications through vmap. + + Args: + modules (sequence of nn.Module): the modules to get the parameters from. + If the modules differ in their structure, a lazy stack is needed + (see the ``lazy_stack`` argument below). + + Keyword Args: + as_module (bool, optional): if ``True``, a :class:`~tensordict.nn.TensorDictParams` + instance will be returned which can be used to store parameters + within a :class:`torch.nn.Module`. Defaults to ``False``. + lock (bool, optional): if ``True``, the resulting tensordict will be locked. + Defaults to ``True``. + use_state_dict (bool, optional): if ``True``, the state-dict from the + module will be used and unflattened into a TensorDict with + the tree structure of the model. Defaults to ``False``. + .. note:: + This is particularly useful when state-dict hooks have to be + used. + lazy_stack (bool, optional): whether parameters should be densly or + lazily stacked. Defaults to ``False`` (dense stack). + + .. note:: ``lazy_stack`` and ``as_module`` are exclusive features. + + .. warning:: + There is a crucial difference between lazy and non-lazy outputs + in that non-lazy output will reinstantiate parameters with the + desired batch-size, while ``lazy_stack`` will just represent + the parameters as lazily stacked. This means that whilst the + original parameters can safely be passed to an optimizer + when ``lazy_stack=True``, the new parameters need to be passed + when it is set to ``True``. + + .. warning:: + Whilst it can be tempting to use a lazy stack to keep the + orignal parameter references, remember that lazy stack + perform a stack each time :meth:`~.get` is called. This will + require memory (N times the size of the parameters, more if a + graph is built) and time to be computed. 
+ It also means that the optimizer(s) will contain more + parameters, and operations like :meth:`~torch.optim.Optimizer.step` + or :meth:`~torch.optim.Optimizer.zero_grad` will take longer + to be executed. In general, ``lazy_stack`` should be reserved + to very few use cases. + + Examples: + >>> from torch import nn + >>> from tensordict import TensorDict + >>> torch.manual_seed(0) + >>> empty_module = nn.Linear(3, 4, device="meta") + >>> n_models = 2 + >>> modules = [nn.Linear(3, 4) for _ in range(n_models)] + >>> params = TensorDict.from_modules(*modules) + >>> print(params) + TensorDict( + fields={ + bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), + weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False) + >>> # example of batch execution + >>> def exec_module(params, x): + ... with params.to_module(empty_module): + ... return empty_module(x) + >>> x = torch.randn(3) + >>> y = torch.vmap(exec_module, (0, None))(params, x) + >>> assert y.shape == (n_models, 4) + >>> # since lazy_stack = False, backprop leaves the original params untouched + >>> y.sum().backward() + >>> assert params["weight"].grad.norm() > 0 + >>> assert modules[0].weight.grad is None + + With ``lazy_stack=True``, things are slightly different: + + >>> params = TensorDict.from_modules(*modules, lazy_stack=True) + >>> print(params) + LazyStackedTensorDict( + fields={ + bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), + weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([2]), + device=None, + is_shared=False, + stack_dim=0) + >>> # example of batch execution + >>> y = torch.vmap(exec_module, (0, None))(params, x) + >>> assert y.shape == (n_models, 4) + >>> y.sum().backward() + >>> assert modules[0].weight.grad is not None + + + """ + param_list = [ + cls.from_module(module, use_state_dict=use_state_dict) for module in modules + ] + if lazy_stack: + from tensordict._lazy import LazyStackedTensorDict + + params = LazyStackedTensorDict.lazy_stack(param_list) + else: + with set_lazy_legacy(False), torch.no_grad(): + params = torch.stack(param_list) + + # Make sure params are params, buffers are buffers + def make_param(param, orig_param): + if isinstance(orig_param, nn.Parameter): + return nn.Parameter(param.detach(), orig_param.requires_grad) + return Buffer(param) + + params = params.apply(make_param, param_list[0]) + + if as_module: + from tensordict.nn import TensorDictParams + + params = TensorDictParams(params, no_convert=True) + if lock: + params.lock_() + return params + @as_decorator() def to_module( self, diff --git a/test/test_tensordict.py b/test/test_tensordict.py index ee404b2df..a9d3e1ea7 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -714,6 +714,43 @@ def remover(module, *args, **kwargs): sd = net.state_dict() assert_allclose_td(params_sd.flatten_keys("."), TensorDict(sd, [])) + @pytest.mark.parametrize("as_module", [False, True]) + @pytest.mark.parametrize("lazy_stack", [False, True]) + def test_from_modules(self, as_module, lazy_stack): + empty_module = nn.Linear(3, 4, device="meta") + modules = [nn.Linear(3, 4) for _ in range(3)] + if as_module and lazy_stack: + with pytest.raises(RuntimeError, match="within a TensorDictParams"): + params = TensorDict.from_modules( + *modules, as_module=as_module, 
lazy_stack=lazy_stack + ) + return + + params = TensorDict.from_modules( + *modules, as_module=as_module, lazy_stack=lazy_stack + ) + + def vmap_module(params, x): + with params.to_module(empty_module): + return empty_module(x) + + x = torch.zeros(3) + y = torch.vmap(vmap_module, (0, None))(params, x) + y.sum().backward() + if lazy_stack: + leaves = [] + + def get_leaf(leaf): + leaves.append(leaf) + + params.apply(get_leaf) + assert all(param.grad is not None for param in leaves) + assert all(param.grad is None for param in params.values(True, True)) + else: + for p in modules[0].parameters(): + assert p.grad is None + assert all(param.grad is not None for param in params.values(True, True)) + @pytest.mark.parametrize( "idx", [ From 22db333dcc15f48482d16889d865ca22cabf0ff6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 17 Feb 2024 20:57:07 +0000 Subject: [PATCH 02/32] [BugFix, Refactor] More reliable Sequential.get_dist (#678) --- tensordict/nn/probabilistic.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index bb422cc02..510610d9c 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -546,8 +546,22 @@ def log_prob( return self.module[-1].log_prob(tensordict_out) def build_dist_from_params(self, tensordict: TensorDictBase) -> D.Distribution: - """Construct a distribution from the input parameters. Other modules in the sequence are not evaluated.""" - return self.module[-1].get_dist(tensordict) + """Construct a distribution from the input parameters. Other modules in the sequence are not evaluated. + + This method will look for the last ProbabilisticTensorDictModule contained in the + sequence and use it to build the distribution. + + """ + dest_module = None + for module in reversed(list(self.modules())): + if isinstance(module, ProbabilisticTensorDictModule): + dest_module = module + break + if dest_module is None: + raise RuntimeError( + "Could not find any ProbabilisticTensorDictModule in the sequence." + ) + return dest_module.get_dist(tensordict) @dispatch(auto_batch_size=False) @set_skip_existing(None) From b43baa8689e417d7315456bc6de3eb6bcc6107fc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 19 Feb 2024 10:17:05 +0000 Subject: [PATCH 03/32] [Refactor] Set lazy_legacy to False by default (#680) --- tensordict/_torch_func.py | 2 +- tensordict/base.py | 10 +++++----- tensordict/utils.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 087b49534..041a6a491 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -376,7 +376,7 @@ def _stack( From v0.4 onward, a dense stack will be returned and to build a lazy stack, an explicit call to LazyStackedTensorDict.lazy_stack will be required. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. 
- set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around diff --git a/tensordict/base.py b/tensordict/base.py index a0fcf9637..8954c0d10 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -840,7 +840,7 @@ def unsqueeze(self, *args, **kwargs): version of the unsqueezed tensordict. Up until v0.3 included, a lazy unsqueezed tensordict was returned. From v0.4 onward, a dense unsqueeze will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -929,7 +929,7 @@ def squeeze(self, *args, **kwargs): version of the squeezed tensordict. Up until v0.3 included, a lazy squeezed tensordict was returned. From v0.4 onward, a dense squeeze will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -1141,7 +1141,7 @@ def view( version of the viewed tensordict. Up until v0.3 included, a lazy view of the tensordict was returned. From v0.4 onward, a proper view will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -1210,7 +1210,7 @@ def transpose(self, dim0, dim1): version of the transposed tensordict. Up until v0.3 included, a lazy transpose of the tensordict was returned. From v0.4 onward, a proper transpose will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -1325,7 +1325,7 @@ def permute(self, *args, **kwargs): version of the permuted tensordict. Up until v0.3 included, a lazy permute of the tensordict was returned. From v0.4 onward, a proper permute will be returned. 
To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around diff --git a/tensordict/utils.py b/tensordict/utils.py index 1384b62e4..37321a357 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1743,7 +1743,7 @@ def _getitem_batch_size(batch_size, index): # Lazy classes control (legacy feature) -_DEFAULT_LAZY_OP = True +_DEFAULT_LAZY_OP = False _LAZY_OP = os.environ.get("LAZY_LEGACY_OP", None) From 84857557a6e17f0f7b696dd7d7888471dd13e926 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 19 Feb 2024 14:17:46 +0000 Subject: [PATCH 04/32] [BugFix, Performance] Fewer imports at root (#682) --- tensordict/memmap.py | 14 +------------- tensordict/persistent.py | 25 ++++++++++++++++--------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/tensordict/memmap.py b/tensordict/memmap.py index ff948e1c7..c5b2489a9 100644 --- a/tensordict/memmap.py +++ b/tensordict/memmap.py @@ -22,20 +22,8 @@ from tensordict.utils import implement_for -from torch import distributed as dist - from torch.multiprocessing.reductions import ForkingPickler -try: - if dist.is_available(): - from torch.distributed._tensor.api import DTensor - else: - raise ImportError -except ImportError: - - class DTensor(torch.Tensor): # noqa: D101 - ... - class MemoryMappedTensor(torch.Tensor): """A Memory-mapped Tensor. @@ -204,7 +192,7 @@ def from_tensor( out.index = None out.parent_shape = input.shape if copy_data: - if isinstance(input, DTensor): + if hasattr(input, "full_tensor"): input = input.full_tensor() out.copy_(input) return out diff --git a/tensordict/persistent.py b/tensordict/persistent.py index 0c605074a..09c0fb176 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -46,14 +46,7 @@ ) from torch import multiprocessing as mp -H5_ERR = None -try: - import h5py - - _has_h5 = True -except ModuleNotFoundError as err: - H5_ERR = err - _has_h5 = False +_has_h5 = importlib.util.find_spec("h5py", None) is not None class _Visitor: @@ -151,7 +144,9 @@ def __init__( self._locked_tensordicts = [] self._lock_id = set() if not _has_h5: - raise ModuleNotFoundError("Could not load h5py.") from H5_ERR + raise ModuleNotFoundError("Could not load h5py.") + import h5py + super().__init__() self.filename = filename self.mode = mode @@ -213,6 +208,8 @@ def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs) A :class:`PersitentTensorDict` instance linked to the newly created file. 
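        Examples:
            A minimal usage sketch (assumes ``h5py`` is installed; ``"data.h5"`` is a
            hypothetical, writable path used only for illustration):

            >>> import torch
            >>> from tensordict import PersistentTensorDict
            >>> data = {"a": torch.zeros(3, 4), "b": torch.ones(3, 4)}
            >>> td = PersistentTensorDict.from_dict(data, filename="data.h5", batch_size=[3])
            >>> assert td["a"].shape == torch.Size([3, 4])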
""" + import h5py + file = h5py.File(filename, "w", locking=cls.LOCKING) _has_batch_size = True if batch_size is None: @@ -260,6 +257,8 @@ def _get_array(self, key, default=NO_DEFAULT): raise KeyError(f"key {key} not found in PersistentTensorDict {self}") def _process_array(self, key, array): + import h5py + if isinstance(array, (h5py.Dataset,)): if self.device is not None: device = self.device @@ -294,6 +293,8 @@ def get(self, key, default=NO_DEFAULT): def get_at( self, key: NestedKey, idx: IndexType, default: CompatibleType = NO_DEFAULT ) -> CompatibleType: + import h5py + array = self._get_array(key, default) if isinstance(array, (h5py.Dataset,)): if self.device is not None: @@ -339,6 +340,8 @@ def _get_metadata(self, key): This method avoids creating a tensor from scratch, and just reads the metadata of the array. """ + import h5py + array = self._get_array(key) if ( isinstance(array, (h5py.Dataset,)) @@ -1048,6 +1051,8 @@ def _set_metadata(self, orig_metadata_container: PersistentTensorDict): self._nested_tensordicts[key]._set_metadata(td) def _clone(self, recurse: bool = True, newfile=None) -> PersistentTensorDict: + import h5py + if recurse: # this should clone the h5 to a new location indicated by newfile if newfile is None: @@ -1106,6 +1111,8 @@ def __getstate__(self): return state def __setstate__(self, state): + import h5py + state["file"] = h5py.File( state["filename"], mode=state["mode"], locking=self.LOCKING ) From 3594dd0bfbebd5bc6d8dba84d9b3c92db03c1aef Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 20 Feb 2024 17:17:05 +0000 Subject: [PATCH 05/32] [BugFix] Fix lazy params init (#681) --- tensordict/_td.py | 38 +++++++++++++++----- tensordict/_torch_func.py | 74 ++++++++++++++++++++++++++++++++++++--- tensordict/base.py | 14 ++++++-- tensordict/utils.py | 58 ++++++++++++++++++++++++++++-- test/test_tensordict.py | 52 +++++++++++++++++++++++++-- 5 files changed, 217 insertions(+), 19 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index df0660191..24e909dab 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -38,6 +38,9 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( + _add_batch_dim_pre_hook, + _BatchedUninitializedBuffer, + _BatchedUninitializedParameter, _clone_value, _expand_to_match_shape, _get_item, @@ -750,15 +753,26 @@ def make_result(): @cache # noqa: B019 def _add_batch_dim(self, *, in_dim, vmap_level): td = self + + def _add_batch_dim_wrapper(key, value): + if is_tensor_collection(value): + return value._add_batch_dim(in_dim=in_dim, vmap_level=vmap_level) + + if isinstance( + value, (_BatchedUninitializedParameter, _BatchedUninitializedBuffer) + ): + value.in_dim = in_dim + value.vmap_level = vmap_level + return value + return _add_batch_dim(value, in_dim, vmap_level) + out = TensorDict( - { - key: value._add_batch_dim(in_dim=in_dim, vmap_level=vmap_level) - if is_tensor_collection(value) - else _add_batch_dim(value, in_dim, vmap_level) - for key, value in td.items() - }, - batch_size=[b for i, b in enumerate(td.batch_size) if i != in_dim], + {key: _add_batch_dim_wrapper(key, value) for key, value in td.items()}, + batch_size=torch.Size( + [b for i, b in enumerate(td.batch_size) if i != in_dim] + ), names=[name for i, name in enumerate(td.names) if i != in_dim], + _run_checks=False, ) return out @@ -3040,7 +3054,7 @@ def __repr__(self): def _set_tensor_dict( # noqa: F811 module_dict, hooks, - module, + module: torch.nn.Module, name: 
str, tensor: torch.Tensor, inplace: bool, @@ -3066,6 +3080,14 @@ def _set_tensor_dict( # noqa: F811 if output is not None: tensor = output module_dict["_parameters"][name] = tensor + + if isinstance( + tensor, (_BatchedUninitializedParameter, _BatchedUninitializedBuffer) + ): + module.register_forward_pre_hook( + _add_batch_dim_pre_hook(), with_kwargs=True + ) + elif was_buffer and isinstance(tensor, torch.Tensor): module_dict["_buffers"][name] = tensor else: diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 041a6a491..60eb9ec6a 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -11,7 +11,6 @@ from typing import Any, Callable, Sequence, TypeVar import torch - from tensordict._lazy import LazyStackedTensorDict from tensordict._td import TensorDict from tensordict.base import _is_leaf_nontensor, NO_DEFAULT, TensorDictBase @@ -23,10 +22,16 @@ lazy_legacy, set_lazy_legacy, ) -from torch import Tensor +from torch import _C, Tensor +from torch.nn.parameter import ( + UninitializedBuffer, + UninitializedParameter, + UninitializedTensorMixin, +) TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {} LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {} +UNINIT_TENSOR_FUNCTIONS: dict[Callable, Callable] = {} T = TypeVar("T", bound="TensorDictBase") @@ -52,6 +57,19 @@ def decorator(func: Callable) -> Callable: return decorator +def implements_for_uninit_param( + torch_function: Callable, +) -> Callable[[Callable], Callable]: + """Register a torch function override for UninitializedTensorMixin.""" + + @functools.wraps(torch_function) + def decorator(func: Callable) -> Callable: + UNINIT_TENSOR_FUNCTIONS[torch_function] = func + return func + + return decorator + + @implements_for_td(torch.unbind) def _unbind(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]: return td.unbind(*args, **kwargs) @@ -351,7 +369,7 @@ def _stack( from tensordict.tensorclass import NonTensorData - if all(isinstance(tensordict, NonTensorData) for tensordict in list_of_tensordicts): + if all(isinstance(td, NonTensorData) for td in list_of_tensordicts): return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) batch_size = list_of_tensordicts[0].batch_size @@ -409,7 +427,9 @@ def _stack( tensor_shape = None for _tensordict in list_of_tensordicts: tensor = _tensordict._get_str(key, default=NO_DEFAULT) - if tensor_shape is None: + if isinstance(tensor, UninitializedTensorMixin): + pass + elif tensor_shape is None: tensor_shape = tensor.shape elif tensor.shape != tensor_shape: with set_lazy_legacy(True): @@ -498,3 +518,49 @@ def where(condition, input, other, *, out=None): "Cannot use a persistent tensordict as output of torch.where." ) return input.where(condition, other, out=out) + + +def __torch_function__( + cls, + func: Callable, + types: tuple[type, ...], + args: tuple[Any, ...] 
= (), + kwargs: dict[str, Any] | None = None, +) -> Callable: + if kwargs is None: + kwargs = {} + fnc_uninit = UNINIT_TENSOR_FUNCTIONS.get(func, None) + if fnc_uninit is not None: + return fnc_uninit(*args, **kwargs) + with _C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + +UninitializedTensorMixin.__torch_function__ = classmethod(__torch_function__) + + +@implements_for_uninit_param(torch.stack) +def _stack_uninit_params(list_of_params, dim=0, out=None): + if out is not None: + raise NotImplementedError + if dim > 0: + raise NotImplementedError + from tensordict.utils import ( + _BatchedUninitializedBuffer, + _BatchedUninitializedParameter, + ) + + if isinstance(list_of_params[0], UninitializedParameter): + out = _BatchedUninitializedParameter( + requires_grad=list_of_params[0].requires_grad, + device=list_of_params[0].device, + dtype=list_of_params[0].dtype, + ) + elif isinstance(list_of_params[0], UninitializedBuffer): + out = _BatchedUninitializedBuffer( + requires_grad=list_of_params[0].requires_grad, + device=list_of_params[0].device, + dtype=list_of_params[0].dtype, + ) + out.batch_size = torch.Size([len(list_of_params)]) + return out diff --git a/tensordict/base.py b/tensordict/base.py index 8954c0d10..9203b6ba3 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -71,6 +71,7 @@ unravel_key_list, ) from torch import distributed as dist, multiprocessing as mp, nn, Tensor +from torch.nn.parameter import UninitializedTensorMixin from torch.utils._pytree import tree_map # NO_DEFAULT is used as a placeholder whenever the default is not provided. @@ -518,6 +519,14 @@ def from_modules( if lazy_stack: from tensordict._lazy import LazyStackedTensorDict + for param in param_list: + if any( + isinstance(tensor, UninitializedTensorMixin) + for tensor in param.values(True, True) + ): + raise RuntimeError( + "lasy_stack=True is not compatible with lazy modules." 
+ ) params = LazyStackedTensorDict.lazy_stack(param_list) else: with set_lazy_legacy(False), torch.no_grad(): @@ -525,12 +534,13 @@ def from_modules( # Make sure params are params, buffers are buffers def make_param(param, orig_param): + if isinstance(param, UninitializedTensorMixin): + return param if isinstance(orig_param, nn.Parameter): return nn.Parameter(param.detach(), orig_param.requires_grad) return Buffer(param) - params = params.apply(make_param, param_list[0]) - + params = params._fast_apply(make_param, param_list[0]) if as_module: from tensordict.nn import TensorDictParams diff --git a/tensordict/utils.py b/tensordict/utils.py index 37321a357..c744a7377 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -51,9 +51,15 @@ ) from torch import Tensor from torch._C import _disabled_torch_function_impl -from torch.nn.parameter import _ParameterMeta +from torch.nn.parameter import ( + _ParameterMeta, + UninitializedBuffer, + UninitializedParameter, + UninitializedTensorMixin, +) from torch.utils.data._utils.worker import _generate_state + if TYPE_CHECKING: from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.tensordict import TensorDictBase @@ -559,7 +565,9 @@ def _ndimension(tensor: Tensor) -> int: def _shape(tensor: Tensor) -> torch.Size: - if not isinstance(tensor, Tensor): + if isinstance(tensor, UninitializedTensorMixin): + return torch.Size([*getattr(tensor, "batch_size", ()), -1]) + elif not isinstance(tensor, Tensor): if type(tensor) is KeyedJaggedTensor: return torch.Size([len(tensor.lengths()) // len(tensor.keys())]) return tensor.shape @@ -1424,12 +1432,17 @@ def _td_fields(td: T, keys=None) -> str: if isinstance(tensor, TensorDictBase): substr = _td_fields(tensor) else: + is_shared = ( + tensor.is_shared() + if not isinstance(tensor, UninitializedTensorMixin) + else None + ) substr = _get_repr_custom( tensor.__class__, shape=shape, device=tensor.device, dtype=tensor.dtype, - is_shared=tensor.is_shared(), + is_shared=is_shared, ) strs.append(f"{key}: {substr}") @@ -2101,3 +2114,42 @@ def __setstate__(self, ob: bytes): def __call__(self, *args, **kwargs): return self.fn(*args, **kwargs) + + +class _BatchedUninitializedParameter(UninitializedParameter): + batch_size: torch.Size + in_dim: int | None = None + vmap_level: int | None = None + + def materialize(self, shape, device=None, dtype=None): + UninitializedParameter.materialize( + self, (*self.batch_size, *shape), device=device, dtype=dtype + ) + + +class _BatchedUninitializedBuffer(UninitializedBuffer): + batch_size: torch.Size + in_dim: int | None = None + vmap_level: int | None = None + + def materialize(self, shape, device=None, dtype=None): + UninitializedBuffer.materialize( + self, (*self.batch_size, *shape), device=device, dtype=dtype + ) + + +class _add_batch_dim_pre_hook: + def __call__(self, mod: torch.nn.Module, args, kwargs): + for name, param in list(mod.named_parameters(recurse=False)): + if hasattr(param, "in_dim") and hasattr(param, "vmap_level"): + from torch._C._functorch import _add_batch_dim + + param = _add_batch_dim(param, param.in_dim, param.vmap_level) + delattr(mod, name) + setattr(mod, name, param) + for key, val in list(mod._forward_pre_hooks.items()): + if val is self: + del mod._forward_pre_hooks[key] + return + else: + raise RuntimeError("did not find pre-hook") diff --git a/test/test_tensordict.py b/test/test_tensordict.py index a9d3e1ea7..7c40eb5ca 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -50,6 +50,7 @@ set_lazy_legacy, ) 
from torch import multiprocessing as mp, nn +from torch.nn.parameter import UninitializedTensorMixin try: import torchsnapshot @@ -730,12 +731,12 @@ def test_from_modules(self, as_module, lazy_stack): *modules, as_module=as_module, lazy_stack=lazy_stack ) - def vmap_module(params, x): + def exec_module(params, x): with params.to_module(empty_module): return empty_module(x) x = torch.zeros(3) - y = torch.vmap(vmap_module, (0, None))(params, x) + y = torch.vmap(exec_module, (0, None))(params, x) y.sum().backward() if lazy_stack: leaves = [] @@ -751,6 +752,53 @@ def get_leaf(leaf): assert p.grad is None assert all(param.grad is not None for param in params.values(True, True)) + @pytest.mark.parametrize("as_module", [False, True]) + @pytest.mark.parametrize("lazy_stack", [False, True]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_from_modules_lazy(self, as_module, lazy_stack, device): + empty_module = nn.LazyLinear(4, device="meta") + modules = [nn.LazyLinear(4, device=device) for _ in range(3)] + if lazy_stack: + with pytest.raises( + RuntimeError, + match="lasy_stack=True is not compatible with lazy modules.", + ): + params = TensorDict.from_modules( + *modules, as_module=as_module, lazy_stack=lazy_stack + ) + return + + params = TensorDict.from_modules( + *modules, as_module=as_module, lazy_stack=lazy_stack + ) + optim = torch.optim.Adam(list(params.values(True, True))) + + def exec_module(params, x): + with params.to_module(empty_module): + return empty_module(x) + + x = torch.zeros(3, device=device) + assert isinstance(params["weight"], UninitializedTensorMixin) + y = torch.vmap(exec_module, (0, None), randomness="same")(params, x) + assert params["weight"].shape == torch.Size((3, 4, 3)) + + y.sum().backward() + optim.step() + + if lazy_stack: + leaves = [] + + def get_leaf(leaf): + leaves.append(leaf) + + params.apply(get_leaf) + assert all(param.grad is not None for param in leaves) + assert all(param.grad is None for param in params.values(True, True)) + else: + for p in modules[0].parameters(): + assert p.grad is None + assert all(param.grad is not None for param in params.values(True, True)) + @pytest.mark.parametrize( "idx", [ From 8ea4e871036d656e9e2a0ecf9fd2e81aa3c1706d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 20 Feb 2024 20:46:44 +0000 Subject: [PATCH 06/32] [BugFix] Fix torch_function for uninit param (#683) --- tensordict/_torch_func.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 60eb9ec6a..424de42bb 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -520,6 +520,10 @@ def where(condition, input, other, *, out=None): return input.where(condition, other, out=out) +# monkey patch +__prev_torch_function__ = UninitializedTensorMixin.__torch_function__ + + def __torch_function__( cls, func: Callable, @@ -532,8 +536,20 @@ def __torch_function__( fnc_uninit = UNINIT_TENSOR_FUNCTIONS.get(func, None) if fnc_uninit is not None: return fnc_uninit(*args, **kwargs) - with _C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) + if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper": + if kwargs is None: + kwargs = {} + with _C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + # Ideally we'd like to use this from the original __torch_function__ + # return super().__torch_function__(func, types, args, kwargs) + raise ValueError( + f"Attempted to use an uninitialized parameter 
in {func}. " + "This error happens when you are using a `LazyModule` or " + f"explicitly manipulating `torch.nn.parameter.{cls.__name__}` " + "objects. When using LazyModules Call `forward` with a dummy batch " + "to initialize the parameters before calling torch functions" + ) UninitializedTensorMixin.__torch_function__ = classmethod(__torch_function__) From 50f457744fa706c945ea4d110d4671b6745f626d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 21 Feb 2024 00:14:33 +0000 Subject: [PATCH 07/32] [BugFix] Remove monkey patching of uninit params (#684) --- tensordict/_torch_func.py | 66 +++++++-------------------------------- 1 file changed, 11 insertions(+), 55 deletions(-) diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 424de42bb..3584b1d33 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -22,7 +22,7 @@ lazy_legacy, set_lazy_legacy, ) -from torch import _C, Tensor +from torch import Tensor from torch.nn.parameter import ( UninitializedBuffer, UninitializedParameter, @@ -31,7 +31,6 @@ TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {} LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {} -UNINIT_TENSOR_FUNCTIONS: dict[Callable, Callable] = {} T = TypeVar("T", bound="TensorDictBase") @@ -57,19 +56,6 @@ def decorator(func: Callable) -> Callable: return decorator -def implements_for_uninit_param( - torch_function: Callable, -) -> Callable[[Callable], Callable]: - """Register a torch function override for UninitializedTensorMixin.""" - - @functools.wraps(torch_function) - def decorator(func: Callable) -> Callable: - UNINIT_TENSOR_FUNCTIONS[torch_function] = func - return func - - return decorator - - @implements_for_td(torch.unbind) def _unbind(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]: return td.unbind(*args, **kwargs) @@ -424,26 +410,32 @@ def _stack( out = {} for key in keys: out[key] = [] + is_lazy = False tensor_shape = None for _tensordict in list_of_tensordicts: tensor = _tensordict._get_str(key, default=NO_DEFAULT) if isinstance(tensor, UninitializedTensorMixin): - pass + is_lazy = True elif tensor_shape is None: tensor_shape = tensor.shape elif tensor.shape != tensor_shape: with set_lazy_legacy(True): return _stack(list_of_tensordicts, dim=dim) out[key].append(tensor) + out[key] = (out[key], is_lazy) - def stack_fn(key_values): - key, values = key_values + def stack_fn(key, values, is_lazy): + if is_lazy: + return _stack_uninit_params(values, dim) with _ErrorInteceptor( key, "Attempted to stack tensors on different devices at key" ): return torch.stack(values, dim) - out = {key: stack_fn((key, value)) for key, value in out.items()} + out = { + key: stack_fn(key, values, is_lazy) + for key, (values, is_lazy) in out.items() + } return TensorDict( out, @@ -520,42 +512,6 @@ def where(condition, input, other, *, out=None): return input.where(condition, other, out=out) -# monkey patch -__prev_torch_function__ = UninitializedTensorMixin.__torch_function__ - - -def __torch_function__( - cls, - func: Callable, - types: tuple[type, ...], - args: tuple[Any, ...] 
= (), - kwargs: dict[str, Any] | None = None, -) -> Callable: - if kwargs is None: - kwargs = {} - fnc_uninit = UNINIT_TENSOR_FUNCTIONS.get(func, None) - if fnc_uninit is not None: - return fnc_uninit(*args, **kwargs) - if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper": - if kwargs is None: - kwargs = {} - with _C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - # Ideally we'd like to use this from the original __torch_function__ - # return super().__torch_function__(func, types, args, kwargs) - raise ValueError( - f"Attempted to use an uninitialized parameter in {func}. " - "This error happens when you are using a `LazyModule` or " - f"explicitly manipulating `torch.nn.parameter.{cls.__name__}` " - "objects. When using LazyModules Call `forward` with a dummy batch " - "to initialize the parameters before calling torch functions" - ) - - -UninitializedTensorMixin.__torch_function__ = classmethod(__torch_function__) - - -@implements_for_uninit_param(torch.stack) def _stack_uninit_params(list_of_params, dim=0, out=None): if out is not None: raise NotImplementedError From ff2c5b0d7baa4a77b83962928eafae3d9357b4af Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 20 Feb 2024 19:48:51 -0800 Subject: [PATCH 08/32] init --- tensordict/_lazy.py | 4 ++++ tensordict/base.py | 3 ++- tensordict/tensorclass.py | 4 ++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index b0c6ee6cc..659b04cfe 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -2441,6 +2441,10 @@ def _unsqueeze(self, dim): _to_module = TensorDict._to_module +class StackNonTensor(LazyStackedTensorDict): + """A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable.""" + pass + class _CustomOpTensorDict(TensorDictBase): """Encodes lazy operations on tensors contained in a TensorDict.""" diff --git a/tensordict/base.py b/tensordict/base.py index 9203b6ba3..52b83e089 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -5377,9 +5377,10 @@ def _default_is_leaf(cls: Type) -> bool: def _is_leaf_nontensor(cls: Type) -> bool: from tensordict.tensorclass import NonTensorData + from tensordict._lazy import StackNonTensor if issubclass(cls, KeyedJaggedTensor): return False if _is_tensor_collection(cls): - return issubclass(cls, NonTensorData) + return issubclass(cls, (NonTensorData, StackNonTensor)) return issubclass(cls, torch.Tensor) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index d2b249cb5..81f16c598 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1339,9 +1339,9 @@ def _check_equal(a, b): device=first.device, ) - from tensordict._lazy import LazyStackedTensorDict + from tensordict._lazy import StackNonTensor - return LazyStackedTensorDict(*list_of_non_tensor, stack_dim=dim) + return StackNonTensor(*list_of_non_tensor, stack_dim=dim) @classmethod def __torch_function__( From 40dd77349bd1e81010d41cbd1c3cce619805111e Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 21 Feb 2024 08:21:14 -0800 Subject: [PATCH 09/32] amend --- tensordict/_lazy.py | 16 +++++++++++++--- tensordict/_td.py | 9 ++++++--- tensordict/_torch_func.py | 31 ++++++++++++++++++++++++++++++- tensordict/tensorclass.py | 12 +++++++++++- tensordict/utils.py | 25 +++++++++++++++++++++++++ 5 files changed, 85 insertions(+), 8 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 659b04cfe..1ff73948e 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ 
-1473,7 +1473,7 @@ def __setitem__(self, index: IndexType, value: T) -> T: elif isinstance(index, (list, range)): index = torch.tensor(index, device=self.device) - if isinstance(value, (TensorDictBase, dict)): + if is_tensor_collection(value) or isinstance(value, dict): indexed_bs = _getitem_batch_size(self.batch_size, index) if isinstance(value, dict): value = TensorDict( @@ -1500,7 +1500,11 @@ def __setitem__(self, index: IndexType, value: T) -> T: if isinteger: # this will break if the index along the stack dim is [0] or :1 or smth for i, _idx in converted_idx.items(): - self.tensordicts[i][_idx] = value + if _idx != (): + self.tensordicts[i][_idx] = value + else: + self.tensordicts[i] = value + return self if is_nd_tensor: raise RuntimeError( @@ -2443,7 +2447,13 @@ def _unsqueeze(self, dim): class StackNonTensor(LazyStackedTensorDict): """A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable.""" - pass + + def tolist(self): + if self.stack_dim == 0: + return [td.tolist() for td in self.tensordicts] + else: + return [td.tolist() for td in self.unbind(0)] + class _CustomOpTensorDict(TensorDictBase): """Encodes lazy operations on tensors contained in a TensorDict.""" diff --git a/tensordict/_td.py b/tensordict/_td.py index 24e909dab..c52a3d594 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1538,7 +1538,9 @@ def _set_at_str(self, key, value, idx, *, validated): tensor_in = _sub_index(tensor_in, idx) tensor_in.copy_(value) else: - _set_item(tensor_in, idx, value, validated=validated) + tensor_out = _set_item(tensor_in, idx, value, validated=validated) + if tensor_in is not tensor_out: + self._set_str(key, tensor_out, validated=True, inplace=False) return self @@ -2366,10 +2368,11 @@ def _set_at_str(self, key, value, idx, *, validated): ) tensor_in = _sub_index(tensor_in, idx) tensor_in.copy_(value) + tensor_out = tensor_in else: - _set_item(tensor_in, idx, value, validated=validated) + tensor_out = _set_item(tensor_in, idx, value, validated=validated) # make sure that the value is updated - self._source._set_at_str(key, tensor_in, self.idx, validated=validated) + self._source._set_at_str(key, tensor_out, self.idx, validated=validated) return self def _set_at_tuple(self, key, value, idx, *, validated): diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 3584b1d33..47ad18077 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -161,6 +161,32 @@ def _ones_like(td: T, **kwargs: Any) -> T: return td_clone +@implements_for_td(torch.rand_like) +def _rand_like(td: T, **kwargs: Any) -> T: + td_clone = td._fast_apply(lambda x: torch.rand_like(x)) + if "device" in kwargs: + td_clone = td_clone.to(kwargs.pop("device")) + if len(kwargs): + raise RuntimeError( + f"keyword arguments {list(kwargs.keys())} are not " + f"supported with full_like with TensorDict" + ) + return td_clone + + +@implements_for_td(torch.randn_like) +def _randn_like(td: T, **kwargs: Any) -> T: + td_clone = td._fast_apply(lambda x: torch.randn_like(x)) + if "device" in kwargs: + td_clone = td_clone.to(kwargs.pop("device")) + if len(kwargs): + raise RuntimeError( + f"keyword arguments {list(kwargs.keys())} are not " + f"supported with full_like with TensorDict" + ) + return td_clone + + @implements_for_td(torch.empty_like) def _empty_like(td: T, *args, **kwargs) -> T: try: @@ -353,9 +379,12 @@ def _stack( if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") + from tensordict._lazy import StackNonTensor from 
tensordict.tensorclass import NonTensorData - if all(isinstance(td, NonTensorData) for td in list_of_tensordicts): + if all( + isinstance(td, (NonTensorData, StackNonTensor)) for td in list_of_tensordicts + ): return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) batch_size = list_of_tensordicts[0].batch_size diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 81f16c598..38003e1bc 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -56,6 +56,9 @@ torch.full_like, torch.zeros_like, torch.ones_like, + torch.rand_like, + torch.empty_like, + torch.randn_like, torch.clone, torch.squeeze, torch.unsqueeze, @@ -1329,7 +1332,9 @@ def _check_equal(a, b): iseq = False return iseq - if all(_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]): + if all(isinstance(data, NonTensorData) for data in list_of_non_tensor) and all( + _check_equal(data.data, first.data) for data in list_of_non_tensor[1:] + ): batch_size = list(first.batch_size) batch_size.insert(dim, len(list_of_non_tensor)) return NonTensorData( @@ -1395,3 +1400,8 @@ def _fast_apply(self, *args, **kwargs): return _wrap_method(self, "_fast_apply", self._tensordict._fast_apply)( *args, **kwargs ) + + def tolist(self): + if not self.batch_size: + return self.data + return [ntd.tolist() for ntd in self.unbind(0)] diff --git a/tensordict/utils.py b/tensordict/utils.py index c744a7377..f572c3ff4 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -656,6 +656,31 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> elif isinstance(tensor, KeyedJaggedTensor): tensor = setitem_keyedjaggedtensor(tensor, index, value) return tensor + from tensordict._lazy import StackNonTensor + from tensordict.tensorclass import NonTensorData + + if isinstance(tensor, (NonTensorData, StackNonTensor)): + if ( + isinstance(value, NonTensorData) + and isinstance(tensor, NonTensorData) + and tensor.data == value.data + ): + return tensor + if isinstance(index, tuple): + if len(index) == 1: + index = index[0] + else: + idx = index[0] + tensor_idx = tensor[idx] + tensor_idx = _set_item(tensor_idx, index[1:], value, validated=True) + tensor = _set_item(tensor, idx, tensor_idx, validated=True) + return tensor + if isinstance(tensor, NonTensorData): + tensor = StackNonTensor(*[tensor[0]] * tensor.shape[0], stack_dim=0) + elif tensor.stack_dim != 0: + tensor = StackNonTensor(*tensor.unbind(0), stack_dim=0) + tensor[index] = value + return tensor else: tensor[index] = value return tensor From 20dc288d6dda3822383faf6c6fe0aed431a05ffa Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 21 Feb 2024 09:31:54 -0800 Subject: [PATCH 10/32] amend --- tensordict/_lazy.py | 5 ++++- tensordict/_td.py | 10 +++++++--- tensordict/base.py | 13 +++++++++++-- tensordict/tensorclass.py | 5 +++++ test/test_tensordict.py | 37 +++++++++++++++++++++++++++++++++++++ 5 files changed, 64 insertions(+), 6 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 1ff73948e..6ff283228 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1518,7 +1518,10 @@ def __setitem__(self, index: IndexType, value: T) -> T: converted_idx.items(), value_unbind, ): - self.tensordicts[i][_idx] = _value + if _idx != (): + self.tensordicts[i][_idx] = _value + else: + self.tensordicts[i] = _value else: # we must split, not unbind mask_unbind = split_index["individual_masks"] diff --git a/tensordict/_td.py b/tensordict/_td.py index c52a3d594..74f099faa 100644 --- a/tensordict/_td.py +++ 
b/tensordict/_td.py @@ -308,9 +308,10 @@ def is_empty(self): if _is_tensor_collection(type(item)): if not item.is_empty(): return False + from tensordict._lazy import StackNonTensor from tensordict.tensorclass import NonTensorData - if isinstance(item, NonTensorData): + if isinstance(item, (NonTensorData, StackNonTensor)): return False else: return False @@ -2433,10 +2434,13 @@ def get( def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): out = super()._get_non_tensor(key, default=default) + from tensordict._lazy import StackNonTensor from tensordict.tensorclass import NonTensorData - if isinstance(out, _SubTensorDict) and isinstance(out._source, NonTensorData): - return out._source.data + if isinstance(out, _SubTensorDict) and isinstance( + out._source, (NonTensorData, StackNonTensor) + ): + return out._source return out def _get_str(self, key, default): diff --git a/tensordict/base.py b/tensordict/base.py index 52b83e089..253836712 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -246,6 +246,10 @@ def __getitem__(self, index: IndexType) -> T: if isinstance(result, NonTensorData): return result.data + from ._lazy import StackNonTensor + + if isinstance(result, StackNonTensor): + return result.tolist() return result if (istuple and not index) or (not istuple and index is Ellipsis): @@ -2308,10 +2312,15 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): return subtd return subtd._get_non_tensor(key[1:], default=default) value = self._get_str(key, default=default) - from tensordict.tensorclass import NonTensorData + + from .tensorclass import NonTensorData if isinstance(value, NonTensorData): return value.data + from ._lazy import StackNonTensor + + if isinstance(value, StackNonTensor): + return value.tolist() return value def filter_non_tensor_data(self) -> T: @@ -5376,8 +5385,8 @@ def _default_is_leaf(cls: Type) -> bool: def _is_leaf_nontensor(cls: Type) -> bool: - from tensordict.tensorclass import NonTensorData from tensordict._lazy import StackNonTensor + from tensordict.tensorclass import NonTensorData if issubclass(cls, KeyedJaggedTensor): return False diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 38003e1bc..b7016bcff 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1402,6 +1402,11 @@ def _fast_apply(self, *args, **kwargs): ) def tolist(self): + """Converts the data in a list if the batch-size is non-empty. + + If the batch-size is empty, returns the data. + + """ if not self.batch_size: return self.data return [ntd.tolist() for ntd in self.unbind(0)] diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 7c40eb5ca..adb8fae65 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -7785,6 +7785,43 @@ def test_stack(self, non_tensor_data): LazyStackedTensorDict, ) + def test_assign_non_tensor(self): + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + assert data["b"] == "a string!" 
+ assert data.get("b").tolist() == [["a string!"] * 10] + data[0, 1] = TensorDict({"a": 0, "b": "another string!"}, []) + assert data.get("b").tolist() == [ + ["a string!"] + ["another string!"] + ["a string!"] * 8 + ] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0, 5:] = TensorDict({"a": torch.zeros(5), "b": "another string!"}, [5]) + assert data.get("b").tolist() == [["a string!"] * 5 + ["another string!"] * 5] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0, 0::2] = TensorDict( + {"a": torch.zeros(5, dtype=torch.long), "b": "another string!"}, [5] + ) + assert data.get("b").tolist() == [["another string!", "a string!"] * 5] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0] = TensorDict( + {"a": torch.zeros(10, dtype=torch.long), "b": "another string!"}, [10] + ) + assert data.get("b").tolist() == [["another string!"] * 10] + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() From 87aaa6eb3dbb84116437b5d4bcb7479841777c56 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 21 Feb 2024 18:10:33 +0000 Subject: [PATCH 11/32] [BugFix] Unlock td during update in to_module (#686) --- tensordict/_td.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 24e909dab..0a825b88e 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -330,7 +330,8 @@ def _to_module( if not use_state_dict and isinstance(module, TensorDictBase): if return_swap: swap = module.copy() - module.update(self) + with module.unlock_(): + module.update(self) return swap else: module.update(self) From 0f44acc2d3a37135ccd73ef93c38107264fdb96d Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 21 Feb 2024 14:08:14 -0800 Subject: [PATCH 12/32] amend --- docs/source/reference/tensorclass.rst | 1 + tensordict/__init__.py | 3 +- tensordict/_lazy.py | 38 +++++++----------- tensordict/_td.py | 17 ++++---- tensordict/_torch_func.py | 5 +-- tensordict/base.py | 28 ++++++++++---- tensordict/tensorclass.py | 56 ++++++++++++++++++++++----- tensordict/utils.py | 9 ++--- 8 files changed, 100 insertions(+), 57 deletions(-) diff --git a/docs/source/reference/tensorclass.rst b/docs/source/reference/tensorclass.rst index 17518df1d..8e6e4e907 100644 --- a/docs/source/reference/tensorclass.rst +++ b/docs/source/reference/tensorclass.rst @@ -273,3 +273,4 @@ Here is an example: tensorclass NonTensorData + NonTensorStack diff --git a/tensordict/__init__.py b/tensordict/__init__.py index c71661ceb..5e6dc8761 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -16,7 +16,7 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.memmap_deprec import is_memmap, MemmapTensor, set_transfer_ownership from tensordict.persistent import PersistentTensorDict -from tensordict.tensorclass import NonTensorData, tensorclass +from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass from tensordict.utils import ( assert_allclose_td, is_batchedtensor, @@ -43,6 +43,7 @@ "TensorDict", "TensorDictBase", "merge_tensordicts", + "NonTensorStack", "set_transfer_ownership", "pad_sequence", "is_memmap", diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 6ff283228..8059b6819 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -718,7 +718,7 @@ def _legacy_unsqueeze(self, dim: int) -> T: else: dim = dim - 1 stack_dim = self.stack_dim - return 
LazyStackedTensorDict( + return type(self)( *(tensordict.unsqueeze(dim) for tensordict in self.tensordicts), stack_dim=stack_dim, ) @@ -756,7 +756,7 @@ def _legacy_squeeze(self, dim: int | None = None) -> T: else: dim = dim - 1 stack_dim = self.stack_dim - return LazyStackedTensorDict( + return type(self)( *(tensordict.squeeze(dim) for tensordict in self.tensordicts), stack_dim=stack_dim, ) @@ -1236,7 +1236,7 @@ def contiguous(self) -> T: return out def empty(self, recurse=False) -> T: - return LazyStackedTensorDict( + return type(self)( *[td.empty(recurse=recurse) for td in self.tensordicts], stack_dim=self.stack_dim, ) @@ -1245,12 +1245,12 @@ def _clone(self, recurse: bool = True) -> T: if recurse: # This could be optimized using copy but we must be careful with # metadata (_is_shared etc) - result = LazyStackedTensorDict( + result = type(self)( *[td._clone() for td in self.tensordicts], stack_dim=self.stack_dim, ) else: - result = LazyStackedTensorDict( + result = type(self)( *[td._clone(recurse=False) for td in self.tensordicts], stack_dim=self.stack_dim, ) @@ -1274,7 +1274,7 @@ def to(self, *args, **kwargs) -> T: if device is not None and dtype is None and device == self.device: return result - return LazyStackedTensorDict( + return type(self)( *[td.to(*args, **kwargs) for td in self.tensordicts], stack_dim=self.stack_dim, hook_out=self.hook_out, @@ -1403,7 +1403,7 @@ def _apply_nest( if filter_empty and all(r is None for r in results): return if not inplace: - out = LazyStackedTensorDict( + out = type(self)( *results, stack_dim=self.stack_dim, ) @@ -1429,7 +1429,7 @@ def _select( ] if inplace: return self - result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim) + result = type(self)(*tensordicts, stack_dim=self.stack_dim) return result def _exclude( @@ -1442,7 +1442,7 @@ def _exclude( if inplace: self.tensordicts = tensordicts return self - result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim) + result = type(self)(*tensordicts, stack_dim=self.stack_dim) return result def __setitem__(self, index: IndexType, value: T) -> T: @@ -2336,9 +2336,9 @@ def _transpose(self, dim0, dim1): # example: shape = [5, 4, 3, 2, 1], stack_dim=1, dim0=1, dim1=4 # resulting shape: [5, 1, 3, 2, 4] if dim1 == dim0 + 1: - result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim1) + result = type(self)(*self.tensordicts, stack_dim=dim1) else: - result = LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0, dim1 - 1) for td in self.tensordicts), stack_dim=dim1, ) @@ -2346,16 +2346,16 @@ def _transpose(self, dim0, dim1): # example: shape = [5, 4, 3, 2, 1], stack_dim=3, dim0=1, dim1=3 # resulting shape: [5, 2, 3, 4, 1] if dim0 + 1 == dim1: - result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim0) + result = type(self)(*self.tensordicts, stack_dim=dim0) else: - result = LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0 + 1, dim1) for td in self.tensordicts), stack_dim=dim0, ) else: dim0 = dim0 if dim0 < self.stack_dim else dim0 - 1 dim1 = dim1 if dim1 < self.stack_dim else dim1 - 1 - result = LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0, dim1) for td in self.tensordicts), stack_dim=self.stack_dim, ) @@ -2448,16 +2448,6 @@ def _unsqueeze(self, dim): _to_module = TensorDict._to_module -class StackNonTensor(LazyStackedTensorDict): - """A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable.""" - - def tolist(self): - if self.stack_dim == 0: - return [td.tolist() for td in 
self.tensordicts] - else: - return [td.tolist() for td in self.unbind(0)] - - class _CustomOpTensorDict(TensorDictBase): """Encodes lazy operations on tensors contained in a TensorDict.""" diff --git a/tensordict/_td.py b/tensordict/_td.py index 74f099faa..43562aa89 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -29,6 +29,7 @@ _register_tensor_class, BEST_ATTEMPT_INPLACE, CompatibleType, + is_non_tensor, is_tensor_collection, NO_DEFAULT, T, @@ -308,10 +309,9 @@ def is_empty(self): if _is_tensor_collection(type(item)): if not item.is_empty(): return False - from tensordict._lazy import StackNonTensor - from tensordict.tensorclass import NonTensorData + from tensordict.tensorclass import NonTensorData, NonTensorStack - if isinstance(item, (NonTensorData, StackNonTensor)): + if isinstance(item, (NonTensorData, NonTensorStack)): return False else: return False @@ -680,7 +680,11 @@ def make_result(): any_set = False for key, item in self.items(): - if not call_on_nested and _is_tensor_collection(item.__class__): + if ( + not call_on_nested + and _is_tensor_collection(item.__class__) + and not is_non_tensor(item) + ): if default is not NO_DEFAULT: _others = [_other._get_str(key, default=None) for _other in others] _others = [ @@ -2434,11 +2438,10 @@ def get( def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): out = super()._get_non_tensor(key, default=default) - from tensordict._lazy import StackNonTensor - from tensordict.tensorclass import NonTensorData + from tensordict.tensorclass import NonTensorData, NonTensorStack if isinstance(out, _SubTensorDict) and isinstance( - out._source, (NonTensorData, StackNonTensor) + out._source, (NonTensorData, NonTensorStack) ): return out._source return out diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 47ad18077..d27cbb1d7 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -379,11 +379,10 @@ def _stack( if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") - from tensordict._lazy import StackNonTensor - from tensordict.tensorclass import NonTensorData + from tensordict.tensorclass import NonTensorData, NonTensorStack if all( - isinstance(td, (NonTensorData, StackNonTensor)) for td in list_of_tensordicts + isinstance(td, (NonTensorData, NonTensorStack)) for td in list_of_tensordicts ): return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) diff --git a/tensordict/base.py b/tensordict/base.py index 253836712..3fe1ac63b 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -246,9 +246,9 @@ def __getitem__(self, index: IndexType) -> T: if isinstance(result, NonTensorData): return result.data - from ._lazy import StackNonTensor + from tensordict.tensorclass import NonTensorStack - if isinstance(result, StackNonTensor): + if isinstance(result, NonTensorStack): return result.tolist() return result @@ -1659,6 +1659,14 @@ def cuda(self, device: int = None) -> T: return self.to(torch.device("cuda")) return self.to(f"cuda:{device}") + @property + def is_cuda(self): + return self.device is not None and self.device.type == "cuda" + + @property + def is_cpu(self): + return self.device is not None and self.device.type == "cpu" + # Serialization functionality def state_dict( self, @@ -2317,9 +2325,9 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): if isinstance(value, NonTensorData): return value.data - from ._lazy import StackNonTensor + from tensordict.tensorclass import NonTensorStack - if isinstance(value, StackNonTensor): + if 
isinstance(value, NonTensorStack): return value.tolist() return value @@ -5380,16 +5388,22 @@ def is_tensor_collection(datatype: type | Any) -> bool: return _is_tensor_collection(datatype) +def is_non_tensor(data): + """Checks if an item is a non-tensor.""" + from tensordict.tensorclass import NonTensorData, NonTensorStack + + return isinstance(data, (NonTensorData, NonTensorStack)) + + def _default_is_leaf(cls: Type) -> bool: return not _is_tensor_collection(cls) def _is_leaf_nontensor(cls: Type) -> bool: - from tensordict._lazy import StackNonTensor - from tensordict.tensorclass import NonTensorData + from tensordict.tensorclass import NonTensorData, NonTensorStack if issubclass(cls, KeyedJaggedTensor): return False if _is_tensor_collection(cls): - return issubclass(cls, (NonTensorData, StackNonTensor)) + return issubclass(cls, (NonTensorData, NonTensorStack)) return issubclass(cls, torch.Tensor) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index b7016bcff..90b00ecb8 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -15,7 +15,7 @@ import re import sys import warnings -from copy import copy +from copy import copy, deepcopy from dataclasses import dataclass from pathlib import Path from textwrap import indent @@ -24,6 +24,7 @@ import tensordict as tensordict_lib import torch +from tensordict import LazyStackedTensorDict from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS @@ -475,7 +476,7 @@ def wrapper(self, item: str) -> Any: return wrapper -SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts") +SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names") def _setattr_wrapper(setattr_: Callable, expected_keys: set[str]) -> Callable: @@ -489,12 +490,10 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 """ __dict__ = self.__dict__ - if ( - "_tensordict" not in __dict__ - or "_non_tensordict" not in __dict__ - or key in SET_ATTRIBUTES - ): + if "_tensordict" not in __dict__ or "_non_tensordict" not in __dict__: return setattr_(self, key, value) + if key in SET_ATTRIBUTES: + return setattr(self._tensordict, key, value) out = self.set(key, value) if out is not self: @@ -714,6 +713,9 @@ def _set(self, key: NestedKey, value: Any, inplace: bool = False): __dict__ = self.__dict__ if __dict__["_tensordict"].is_locked: raise RuntimeError(_LOCK_ERROR) + if key in ("batch_size", "names", "device"): + # handled by setattr + return expected_keys = self.__dataclass_fields__ if key not in expected_keys: raise AttributeError( @@ -1344,9 +1346,7 @@ def _check_equal(a, b): device=first.device, ) - from tensordict._lazy import StackNonTensor - - return StackNonTensor(*list_of_non_tensor, stack_dim=dim) + return NonTensorStack(*list_of_non_tensor, stack_dim=dim) @classmethod def __torch_function__( @@ -1410,3 +1410,39 @@ def tolist(self): if not self.batch_size: return self.data return [ntd.tolist() for ntd in self.unbind(0)] + + def copy_(self, src: NonTensorData | NonTensorStack, non_blocking: bool = False): + if isinstance(src, NonTensorStack): + raise RuntimeError( + "Cannot update a NonTensorData with a NonTensorStack object." + ) + if not isinstance(src, NonTensorData): + raise RuntimeError( + "NonTensorData.copy_ requires the source to be a NonTensorData object." 
+ ) + self._non_tensordict["data"] = src.data + + def clone(self, recurse: bool = True): + if recurse: + return type(self)( + data=deepcopy(self.data), + batch_size=self.batch_size, + device=self.device, + names=self.names, + ) + return type(self)( + data=self.data, + batch_size=self.batch_size, + device=self.device, + names=self.names, + ) + + +class NonTensorStack(LazyStackedTensorDict): + """A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable.""" + + def tolist(self): + if self.stack_dim == 0: + return [td.tolist() for td in self.tensordicts] + else: + return [td.tolist() for td in self.unbind(0)] diff --git a/tensordict/utils.py b/tensordict/utils.py index f572c3ff4..b02e367c8 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -656,10 +656,9 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> elif isinstance(tensor, KeyedJaggedTensor): tensor = setitem_keyedjaggedtensor(tensor, index, value) return tensor - from tensordict._lazy import StackNonTensor - from tensordict.tensorclass import NonTensorData + from tensordict.tensorclass import NonTensorData, NonTensorStack - if isinstance(tensor, (NonTensorData, StackNonTensor)): + if isinstance(tensor, (NonTensorData, NonTensorStack)): if ( isinstance(value, NonTensorData) and isinstance(tensor, NonTensorData) @@ -676,9 +675,9 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> tensor = _set_item(tensor, idx, tensor_idx, validated=True) return tensor if isinstance(tensor, NonTensorData): - tensor = StackNonTensor(*[tensor[0]] * tensor.shape[0], stack_dim=0) + tensor = NonTensorStack(*[tensor[0]] * tensor.shape[0], stack_dim=0) elif tensor.stack_dim != 0: - tensor = StackNonTensor(*tensor.unbind(0), stack_dim=0) + tensor = NonTensorStack(*tensor.unbind(0), stack_dim=0) tensor[index] = value return tensor else: From 7104a0b466d1b23be4b16025090d129df9575a61 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 21 Feb 2024 16:32:57 -0800 Subject: [PATCH 13/32] amend --- tensordict/nn/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 3e1d698bd..105be6538 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -14,7 +14,7 @@ import torch from torch import nn -AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "True")) +AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "False")) DISPATCH_TDNN_MODULES = strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True")) From 79b0c018c8fce85c416dd957580de22c35bf96d8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 22 Feb 2024 19:57:35 +0000 Subject: [PATCH 14/32] [BugFix] Fix singleton dims in expand_as_right (#687) --- tensordict/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensordict/utils.py b/tensordict/utils.py index c744a7377..e72163b1b 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -305,7 +305,10 @@ def expand_as_right( f" tensor.ndimension()={tensor.ndimension()} and " f"dest.ndimension()={dest.ndimension()}" ) - if not (tensor.shape == dest.shape[: tensor.ndimension()]): + if any( + tensor.shape[i] != dest.shape[i] and tensor.shape[i] != 1 + for i in range(tensor.ndimension()) + ): raise RuntimeError( f"tensor shape is incompatible with dest shape, " f"got: tensor.shape={tensor.shape}, dest={dest.shape}" From 7fae5aeb4367635171b1bb65c7233b7b30f91883 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: 
Fri, 23 Feb 2024 19:44:15 +0000 Subject: [PATCH 15/32] [BugFix] Fix to_module batch-size mismatch (#688) --- tensordict/_td.py | 13 +++++++++++-- test/test_nn.py | 18 ++++++++++++++---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 0a825b88e..331ed4234 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -330,8 +330,7 @@ def _to_module( if not use_state_dict and isinstance(module, TensorDictBase): if return_swap: swap = module.copy() - with module.unlock_(): - module.update(self) + module._param_td = getattr(self, "_param_td", self) return swap else: module.update(self) @@ -407,7 +406,17 @@ def convert_type(x, y): continue child = __dict__["_modules"][key] local_out = memo.get(id(child), NO_DEFAULT) + if local_out is NO_DEFAULT: + # if isinstance(child, TensorDictBase): + # # then child is a TensorDictParams + # from tensordict.nn import TensorDictParams + # + # local_out = child + # if not isinstance(value, TensorDictParams): + # value = TensorDictParams(value, no_convert=True) + # __dict__["_modules"][key] = value + # else: local_out = value._to_module( child, inplace=inplace, diff --git a/test/test_nn.py b/test/test_nn.py index fd5aefeff..69c7b2fcf 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2972,7 +2972,8 @@ def test_tdparams_clone_tying(self): td_clone = td.clone() assert td_clone["c"] is td_clone["a", "b", "c"] - def test_func_on_tdparams(self): + @pytest.mark.parametrize("with_batch", [False, True]) + def test_func_on_tdparams(self, with_batch): # tdparams isn't represented in a nested way, so we must check that calling to_module on it works ok net = nn.Sequential( nn.Linear(2, 2), @@ -2987,9 +2988,13 @@ def test_func_on_tdparams(self): ), ) - params = TensorDict.from_module(net, as_module=True) + if with_batch: + params = TensorDict.from_modules(net, net, as_module=True) + params0 = params[0].expand(3).clone().apply(lambda x: x.data * 0) + else: + params = TensorDict.from_module(net, as_module=True) + params0 = params.apply(lambda x: x.data * 0) - params0 = params.apply(lambda x: x.data * 0) assert (params0 == 0).all() with params0.to_module(params): assert (params == 0).all() @@ -3002,7 +3007,12 @@ class MyModule(nn.Module): m = MyModule() m.params = params params_m = TensorDict.from_module(m, as_module=True) - params_m0 = params_m.apply(lambda x: x.data * 0) + if with_batch: + params_m0 = params_m.clone() + params_m0["params"] = params_m0["params"][0].expand(3).clone() + else: + params_m0 = params_m + params_m0 = params_m0.apply(lambda x: x.data * 0) assert (params_m0 == 0).all() with params_m0.to_module(m): assert (params_m == 0).all() From 003be12f7fae60c883c1d00e2aef2db7a4f8a138 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 23 Feb 2024 23:47:32 +0000 Subject: [PATCH 16/32] [BugFix] Fix load_state_dict for TensorDictParams (#689) --- tensordict/nn/params.py | 23 +++++++++++------------ test/test_nn.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index f51f1a1ca..81a124c81 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -990,18 +990,17 @@ def _load_from_state_dict( unexpected_keys, error_msgs, ): - data = ( - TensorDict( - { - key: val - for key, val in state_dict.items() - if key.startswith(prefix) and val is not None - }, - [], - ) - .unflatten_keys(".") - .get(prefix[:-1]) - ) + data = TensorDict( + { + key: val + for key, val in state_dict.items() + if 
key.startswith(prefix) and val is not None + }, + [], + ).unflatten_keys(".") + prefix = tuple(key for key in prefix.split(".") if key) + if prefix: + data = data.get(prefix) self.data.load_state_dict(data) def items( diff --git a/test/test_nn.py b/test/test_nn.py index 69c7b2fcf..37d5975b0 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8,6 +8,7 @@ import pickle import unittest import warnings +import weakref import pytest import torch @@ -3018,6 +3019,41 @@ class MyModule(nn.Module): assert (params_m == 0).all() assert not (params_m == 0).all() + def test_load_state_dict(self): + net = nn.Sequential( + nn.Linear(2, 2), + nn.Sequential( + nn.Linear(2, 2), + nn.Dropout(), + nn.BatchNorm1d(2), + nn.Sequential( + nn.Tanh(), + nn.Linear(2, 2), + ), + ), + ) + + params = TensorDict.from_module(net, as_module=True) + assert any(isinstance(p, nn.Parameter) for p in params.values(True, True)) + weakrefs = {weakref.ref(t) for t in params.values(True, True)} + + # Now with a module around it + class MyModule(nn.Module): + pass + + module = MyModule() + module.model = MyModule() + module.model.params = params + sd = module.state_dict() + sd = { + key: val * 0 if isinstance(val, torch.Tensor) else val + for key, val in sd.items() + } + module.load_state_dict(sd) + assert (params == 0).all() + assert any(isinstance(p, nn.Parameter) for p in params.values(True, True)) + assert weakrefs == {weakref.ref(t) for t in params.values(True, True)} + def test_inplace_ops(self): td = TensorDict( { From f09f20dd586ee1749d244e5ae42ee5cdc992bd77 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 24 Feb 2024 02:04:26 +0000 Subject: [PATCH 17/32] [BugFix] Fix name gathering with tensor indices (#690) --- tensordict/_td.py | 17 +++++++++++++---- tensordict/_torch_func.py | 6 ++++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 331ed4234..c3b5d85dd 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -829,12 +829,14 @@ def _index_tensordict( raise RuntimeError( f"indexing a tensordict with td.batch_dims==0 is not permitted. Got index {index}." 
) - if names is None: - names = self._get_names_idx(index) if new_batch_size is not None: batch_size = new_batch_size else: batch_size = _getitem_batch_size(batch_size, index) + + if names is None: + names = self._get_names_idx(index) + source = {} for key, item in self.items(): if isinstance(item, TensorDict): @@ -1371,14 +1373,20 @@ def is_boolean(idx): # this will convert a [None, :, :, 0, None, 0] in [None, 0, 1, None, 3] count = 0 idx_to_take = [] + no_more_tensors = False for _idx in idx_names: if _idx is None: idx_to_take.append(None) elif _is_number(_idx): count += 1 elif isinstance(_idx, (torch.Tensor, np.ndarray)): - idx_to_take.extend([count] * _idx.ndim) - count += 1 + if not no_more_tensors: + idx_to_take.extend([count] * _idx.ndim) + count += 1 + no_more_tensors = True + else: + # skip this one + count += 1 else: idx_to_take.append(count) count += 1 @@ -1392,6 +1400,7 @@ def names(self, value): self._rename_subtds(value) self._erase_names() return + value = list(value) num_none = sum(v is None for v in value) if num_none: num_none -= 1 diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 3584b1d33..b960ea843 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -96,8 +96,10 @@ def _gather_tensor(tensor, dest=None): return out if out is None: - names = input.names if input._has_names() else None - + if len(index.shape) == input.ndim and input._has_names(): + names = input.names + else: + names = None return TensorDict( { key: _gather_tensor(value) From dbb33631f6c5b74f93d3bef0ae5f7c726e8451b6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 24 Feb 2024 04:26:34 +0000 Subject: [PATCH 18/32] [BugFix] tensorclass as a decorator (#691) --- tensordict/tensorclass.py | 11 +++++++++++ test/test_tensorclass.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index d2b249cb5..2c0397422 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -196,6 +196,9 @@ def __torch_function__( cls.load_state_dict = _load_state_dict cls._memmap_ = _memmap_ + cls.__enter__ = __enter__ + cls.__exit__ = __exit__ + # Memmap cls.memmap_like = TensorDictBase.memmap_like cls.memmap_ = TensorDictBase.memmap_ @@ -421,6 +424,14 @@ def _load_memmap(cls, prefix: Path, metadata: dict): return cls._from_tensordict(td, non_tensordict) +def __enter__(self, *args, **kwargs): + return self._tensordict.__enter__(*args, **kwargs) + + +def __exit__(self, *args, **kwargs): + return self._tensordict.__exit__(*args, **kwargs) + + def _getstate(self) -> dict[str, Any]: """Returns a state dict which consists of tensor and non_tensor dicts for serialization. 
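The `__enter__`/`__exit__` forwarding added above means a `@tensorclass`-decorated instance can be used directly as a context manager, for instance around `lock_()`/`unlock_()`. A minimal sketch of the intended usage, mirroring the `test_decorator` test added below (this assumes `tensorclass` is importable from the package root):

from typing import Any

import torch
from tensordict import tensorclass


@tensorclass
class MyData:
    X: torch.Tensor
    y: Any


obj = MyData(X=torch.zeros(2), y="a string!", batch_size=[])
assert not obj.is_locked
with obj.lock_():          # __enter__ is forwarded to the underlying TensorDict
    assert obj.is_locked
assert not obj.is_locked   # __exit__ undoes the lock taken on entry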
diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index b430674b8..d1195c40e 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -1785,6 +1785,22 @@ def test_to(self): assert isinstance(td.get("c")[0], self.TensorClass) +def test_decorator(): + @tensorclass + class MyClass: + X: torch.Tensor + y: Any + + obj = MyClass(X=torch.zeros(2), y="a string!", batch_size=[]) + assert not obj.is_locked + with obj.lock_(): + assert obj.is_locked + with obj.unlock_(): + assert not obj.is_locked + assert obj.is_locked + assert not obj.is_locked + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From b5f6c1784b3209d234a3c5a0fdafcbd63a1bbe81 Mon Sep 17 00:00:00 2001 From: Dimitrios Tsaras Date: Sat, 24 Feb 2024 19:51:28 -0500 Subject: [PATCH 19/32] [BugFix, Feature] `pad_sequence` refactoring (#652) Co-authored-by: Vincent Moens --- tensordict/functional.py | 136 ++++++++++++++++++++++++++------------- test/test_tensordict.py | 112 ++++++++++++++++++++++++++------ 2 files changed, 181 insertions(+), 67 deletions(-) diff --git a/tensordict/functional.py b/tensordict/functional.py index 1cbf168b7..cfb7bff4e 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + from typing import Sequence import torch @@ -77,7 +79,8 @@ def pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0) -> T: def pad_sequence( list_of_tensordicts: Sequence[T], - batch_first: bool = True, + batch_first: bool | None = None, + pad_dim: int = 0, padding_value: float = 0.0, out: T | None = None, device: DeviceType | None = None, @@ -87,76 +90,117 @@ def pad_sequence( Args: list_of_tensordicts (List[TensorDictBase]): the list of instances to pad and stack. - batch_first (bool, optional): the ``batch_first`` correspondant of :func:`torch.nn.utils.rnn.pad_sequence`. - Defaults to ``True``. + pad_dim (int, optional): the ``pad_dim`` indicates the dimension to pad all the keys in the tensordict. + Defaults to ``0``. padding_value (number, optional): the padding value. Defaults to ``0.0``. out (TensorDictBase, optional): if provided, the destination where the data will be written. device (device compatible type, optional): if provded, the device where the TensorDict output will be created. - return_mask (bool, optional): if ``True``, a "mask" entry will be returned. - It contains the mask of valid values in the stacked tensordict. + return_mask (bool, optional): if ``True``, a "masks" entry will be returned. + It contains a tensordict with the same structure as the stacked tensordict where every entry contains the mask of valid values with size ``torch.Size([stack_len, *new_shape])``, + where `new_shape[pad_dim] = max_seq_length` and the rest of the `new_shape` matches the previous shape of the contained tensors. Examples: >>> list_td = [ - ... TensorDict({"a": torch.zeros((3,))}, []), - ... TensorDict({"a": torch.zeros((4,))}, []), + ... TensorDict({"a": torch.zeros((3, 8)), "b": torch.zeros((6, 8))}, batch_size=[]), + ... TensorDict({"a": torch.zeros((5, 8)), "b": torch.zeros((6, 8))}, batch_size=[]), ... 
] - >>> padded_td = pad_sequence(list_td) + >>> padded_td = pad_sequence(list_td, return_mask=True) >>> print(padded_td) TensorDict( fields={ - a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([]), + a: Tensor(shape=torch.Size([2, 4, 8]), device=cpu, dtype=torch.float32, is_shared=False), + b: Tensor(shape=torch.Size([2, 5, 8]), device=cpu, dtype=torch.float32, is_shared=False), + masks: TensorDict( + fields={ + a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.bool, is_shared=False), + b: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False)}, + batch_size=torch.Size([2]), device=None, is_shared=False) """ + if batch_first is not None: + warnings.warn( + "The batch_first argument is deprecated and will be removed in a future release. The output will always be batch_first.", + category=DeprecationWarning, + ) + if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") + # check that all tensordict match - if return_mask: - list_of_tensordicts = [ - td.clone(False).set("mask", torch.ones(td.shape, dtype=torch.bool)) - for td in list_of_tensordicts - ] + update_batch_size = True + max_seq_length = float("-inf") keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True) - shape = max(len(td) for td in list_of_tensordicts) - if shape == 0: - shape = [ - len(list_of_tensordicts), - ] - elif batch_first: - shape = [len(list_of_tensordicts), shape] - else: - shape = [shape, len(list_of_tensordicts)] - if out is None: - out = TensorDict( - {}, batch_size=torch.Size(shape), device=device, _run_checks=False - ) + tmp_list_of_tensordicts = [] + for td in list_of_tensordicts: + + if return_mask: + tmp_list_of_tensordicts.append(td.clone(False)) + for key in keys: - try: - out.set( - key, - torch.nn.utils.rnn.pad_sequence( - [td.get(key) for td in list_of_tensordicts], - batch_first=batch_first, - padding_value=padding_value, - ), + tensor_shape = td.get(key).shape + pos_pad_dim = pad_dim if pad_dim >= 0 else len(tensor_shape) + pad_dim + + # track the maximum sequence length to update batch_size accordingly + if tensor_shape[pos_pad_dim] > max_seq_length: + max_seq_length = tensor_shape[pos_pad_dim] + + # The mask should always contain the batch_size of the TensorDict + mask_shape = td.shape + + # if the pad_dim is past the batch_size of the TensorDict, we need to add the new dimension to the mask + if pos_pad_dim >= td.ndim: + mask_shape += torch.Size([tensor_shape[pos_pad_dim]]) + update_batch_size = False + + if return_mask: + tmp_list_of_tensordicts[-1].set( + ("masks", key), + torch.ones(mask_shape, dtype=torch.bool), ) - except Exception as err: - raise RuntimeError(f"pad_sequence failed for key {key}") from err - return out - else: - for key in keys: - out.set_( + if return_mask: + list_of_tensordicts = tmp_list_of_tensordicts + + keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True) + + old_batch_size = list(list_of_tensordicts[0].batch_size) + if update_batch_size and len(old_batch_size) > 0: + old_batch_size[pad_dim] = max_seq_length + shape = [ + len(list_of_tensordicts), + ] + old_batch_size + + if out is None: + out = list_of_tensordicts[0].empty(recurse=True).reshape(torch.Size(shape)) + + for key in keys: + try: + tensor_shape = list_of_tensordicts[0].get(key).shape + pos_pad_dim = ( + (pad_dim if pad_dim >= 0 else len(tensor_shape) + pad_dim) + if 
len(tensor_shape) > 1 + else 0 # handles the case when the masks are 1-dimensional + ) + out.set( key, torch.nn.utils.rnn.pad_sequence( - [td.get(key) for td in list_of_tensordicts], - batch_first=batch_first, + [ + td.get(key).transpose(0, pos_pad_dim) + for td in list_of_tensordicts + ], + batch_first=True, padding_value=padding_value, - ), + ).transpose(1, pos_pad_dim + 1), + inplace=True, ) - return out + except Exception as err: + raise RuntimeError(f"pad_sequence failed for key {key}") from err + return out def merge_tensordicts(*tensordicts: T) -> T: diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 7c40eb5ca..819556a6f 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -996,33 +996,103 @@ def test_pad(self): assert torch.equal(padded_td["a"], expected_a) padded_td._check_batch_size() - @pytest.mark.parametrize("batch_first", [True, False]) @pytest.mark.parametrize("make_mask", [True, False]) - def test_pad_sequence(self, batch_first, make_mask): + def test_pad_sequence_pad_dim0(self, make_mask): + pad_dim = 0 list_td = [ - TensorDict({"a": torch.ones((2,)), ("b", "c"): torch.ones((2, 3))}, [2]), - TensorDict({"a": torch.ones((4,)), ("b", "c"): torch.ones((4, 3))}, [4]), + TensorDict( + {"a": torch.ones((2, 8, 8)), ("b", "c"): torch.ones((2, 3))}, [2] + ), + TensorDict( + {"a": torch.full((4, 8, 8), 2), ("b", "c"): torch.full((4, 3), 2)}, + [4], + ), ] - padded_td = pad_sequence( - list_td, batch_first=batch_first, return_mask=make_mask - ) - if batch_first: - assert padded_td.shape == torch.Size([2, 4]) - assert padded_td["a"].shape == torch.Size([2, 4]) - assert padded_td["a"][0, -1] == 0 - assert padded_td["b", "c"].shape == torch.Size([2, 4, 3]) - assert padded_td["b", "c"][0, -1, 0] == 0 + padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask) + assert padded_td.shape == torch.Size( + [2, 4] + ) # check the shape of the padded tensordict + assert torch.all( + padded_td["a"][0, :2, :, :] == 1 + ) # check the values of the first tensor + assert torch.all( + padded_td["a"][1, :, :, :] == 2 + ) # check the values of the second tensor + assert padded_td["a"].shape == torch.Size( + [2, 4, 8, 8] + ) # check the shape of the padded tensor + assert torch.all(padded_td["a"][0, 2:, :, :] == 0) # check the padding + assert padded_td["b", "c"].shape == torch.Size( + [2, 4, 3] + ) # check the shape of the padded tensor + assert torch.all(padded_td["b", "c"][0, 2:, :] == 0) # check the padding + if make_mask: + padded_td_without_masks = pad_sequence( + list_td, pad_dim=pad_dim, return_mask=False + ) + assert "masks" in padded_td.keys() + assert set( + padded_td_without_masks.keys(include_nested=True, leaves_only=True) + ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True)) + assert not padded_td["masks", "a"].all() + assert padded_td["masks", "a"].ndim == pad_dim + 2 + assert (padded_td["a"][padded_td["masks", "a"]] != 0).all() + assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all() + assert not padded_td["masks", "b", "c"].all() + assert padded_td["masks", "b", "c"].ndim == pad_dim + 2 + assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all() + assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all() else: - assert padded_td.shape == torch.Size([4, 2]) - assert padded_td["a"].shape == torch.Size([4, 2]) - assert padded_td["a"][-1, 0] == 0 - assert padded_td["b", "c"].shape == torch.Size([4, 2, 3]) - assert padded_td["b", "c"][-1, 0, 0] == 0 + assert "masks" not in padded_td.keys() + 
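The reworked `pad_sequence` above deprecates `batch_first` in favour of a `pad_dim` argument and, when `return_mask=True`, stores one boolean mask per padded key under a nested "masks" entry. A minimal sketch of the new call, assuming `pad_sequence` is exposed at the package root as in the tests:

import torch
from tensordict import TensorDict, pad_sequence

list_td = [
    TensorDict({"a": torch.ones((2, 8))}, [2]),
    TensorDict({"a": torch.full((4, 8), 2.0)}, [4]),
]
padded = pad_sequence(list_td, pad_dim=0, return_mask=True)
assert padded.batch_size == torch.Size([2, 4])       # stack dim + longest sequence
assert padded["a"].shape == torch.Size([2, 4, 8])    # shorter entry is zero-padded
assert padded["masks", "a"].dtype == torch.bool      # mask of valid (non-padded) values
assert (padded["a"][padded["masks", "a"]] != 0).all()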
+ @pytest.mark.parametrize("make_mask", [True, False]) + def test_pad_sequence_pad_dim1(self, make_mask): + pad_dim = 1 + list_td = [ + TensorDict( + {"a": torch.ones((6, 3, 8)), ("b", "c"): torch.ones((6, 3))}, [6] + ), + TensorDict( + {"a": torch.full((6, 5, 8), 2), ("b", "c"): torch.full((6, 7), 2)}, + [6], + ), + ] + padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask) + assert padded_td.shape == torch.Size( + [2, 6] + ) # check the shape of the padded tensordict + assert padded_td["a"].shape == torch.Size( + [2, 6, 5, 8] + ) # check the shape of the padded tensor + assert torch.all( + padded_td["a"][0, :, :3, :] == 1 + ) # check the values of the first tensor + assert torch.all(padded_td["a"][0, :, 3:, :] == 0) # check the padding + assert torch.all( + padded_td["a"][1, :, :, :] == 2 + ) # check the values of the second tensor + assert padded_td["b", "c"].shape == torch.Size( + [2, 6, 7] + ) # check the shape of the padded tensor + assert torch.all(padded_td["b", "c"][0, :, 3:] == 0) # check the padding if make_mask: - assert "mask" in padded_td.keys() - assert not padded_td["mask"].all() + padded_td_without_masks = pad_sequence( + list_td, pad_dim=pad_dim, return_mask=False + ) + assert "masks" in padded_td.keys() + assert set( + padded_td_without_masks.keys(include_nested=True, leaves_only=True) + ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True)) + assert not padded_td["masks", "a"].all() + assert padded_td["masks", "a"].ndim == pad_dim + 2 + assert (padded_td["a"][padded_td["masks", "a"]] != 0).all() + assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all() + assert not padded_td["masks", "b", "c"].all() + assert padded_td["masks", "b", "c"].ndim == pad_dim + 2 + assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all() + assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all() else: - assert "mask" not in padded_td.keys() + assert "masks" not in padded_td.keys() @pytest.mark.parametrize("device", get_available_devices()) def test_permute(self, device): From d1866394de41bbeb8b29c2dd450d3e4a91487933 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 24 Feb 2024 17:44:07 -0800 Subject: [PATCH 20/32] amend --- tensordict/_td.py | 10 ++++++---- tensordict/base.py | 5 +---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index c01256e24..aeca5001a 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -309,9 +309,8 @@ def is_empty(self): if _is_tensor_collection(type(item)): if not item.is_empty(): return False - from tensordict.tensorclass import NonTensorData, NonTensorStack - if isinstance(item, (NonTensorData, NonTensorStack)): + if is_non_tensor(item): return False else: return False @@ -693,7 +692,7 @@ def make_result(): if ( not call_on_nested and _is_tensor_collection(item.__class__) - and not is_non_tensor(item) + # and not is_non_tensor(item) ): if default is not NO_DEFAULT: _others = [_other._get_str(key, default=None) for _other in others] @@ -2467,7 +2466,10 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): def _get_str(self, key, default): if key in self.keys() and _is_tensor_collection(self.entry_class(key)): - return _SubTensorDict(self._source._get_str(key, NO_DEFAULT), self.idx) + data = self._source._get_str(key, NO_DEFAULT) + if is_non_tensor(data): + return data[self.idx] + return _SubTensorDict(data, self.idx) return self._source._get_at_str(key, self.idx, default=default) def _get_tuple(self, key, default): diff --git 
a/tensordict/base.py b/tensordict/base.py index 3fe1ac63b..33fa7b239 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -2321,12 +2321,9 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): return subtd._get_non_tensor(key[1:], default=default) value = self._get_str(key, default=default) - from .tensorclass import NonTensorData - + from .tensorclass import NonTensorData, NonTensorStack if isinstance(value, NonTensorData): return value.data - from tensordict.tensorclass import NonTensorStack - if isinstance(value, NonTensorStack): return value.tolist() return value From 4f665201d32ec5d4d4dd70d3ebd072475fdbaa1a Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 25 Feb 2024 15:26:58 -0800 Subject: [PATCH 21/32] amend --- tensordict/_lazy.py | 155 +++++++++++++++++++++++++++----------- tensordict/base.py | 1 + tensordict/tensorclass.py | 78 +++++++++++++++++-- tensordict/utils.py | 18 ++--- test/test_tensordict.py | 19 ++++- 5 files changed, 206 insertions(+), 65 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 8059b6819..df256f265 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -200,6 +200,10 @@ def __init__( ) _batch_size = tensordicts[0].batch_size device = tensordicts[0].device + if stack_dim > len(_batch_size): + raise RuntimeError( + f"Stack dim {stack_dim} is too big for batch size {_batch_size}." + ) for td in tensordicts[1:]: if not is_tensor_collection(td): @@ -487,9 +491,10 @@ def _split_index(self, index): isinteger = False is_nd_tensor = False cursor = 0 # the dimension cursor - selected_td_idx = range(len(self.tensordicts)) + selected_td_idx = torch.arange(len(self.tensordicts)) has_bool = False num_squash = 0 + encountered_tensor = False for i, idx in enumerate(index): # noqa: B007 cursor_incr = 1 if idx is None: @@ -509,10 +514,8 @@ def _split_index(self, index): if not isinstance(selected_td_idx, range): isinteger = True selected_td_idx = [selected_td_idx] - elif isinstance(idx, (list, range)): - selected_td_idx = idx - elif isinstance(idx, (torch.Tensor, np.ndarray)): - if idx.dtype in (np.dtype("bool"), torch.bool): + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: # we mark that we need to dispatch the indices across stack idx has_bool = True # split mask along dim @@ -522,11 +525,14 @@ def _split_index(self, index): split_dim = self.stack_dim - num_single mask_loc = i else: - if isinstance(idx, np.ndarray): - idx = torch.tensor(idx) is_nd_tensor = True - selected_td_idx = range(len(idx)) - out.append(idx.unbind(0)) + if not encountered_tensor: + # num_single -= idx.ndim - 1 + encountered_tensor = True + else: + num_single += 1 + selected_td_idx = idx + # out.append(idx.unbind(0)) else: raise TypeError(f"Invalid index type: {type(idx)}.") else: @@ -537,13 +543,11 @@ def _split_index(self, index): ( ftdim.Dim, slice, - list, - range, ), ): out.append(idx) - elif isinstance(idx, (np.ndarray, torch.Tensor)): - if idx.dtype in (np.dtype("bool"), torch.bool): + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: cursor_incr = idx.ndim if cursor < self.stack_dim: num_squash += cursor_incr - 1 @@ -568,7 +572,11 @@ def _split_index(self, index): # smth[torch.tensor(1)].ndim = smth.ndim-1 # smth[torch.tensor([1])].ndim = smth.ndim # smth[torch.tensor([[1]])].ndim = smth.ndim+1 - num_single -= idx.ndim - 1 + if not encountered_tensor: + num_single -= idx.ndim - 1 + encountered_tensor = True + else: + num_single += 1 out.append(idx) else: raise TypeError(f"Invalid index type: {type(idx)}.") @@ 
-593,20 +601,45 @@ def _split_index(self, index): elif is_nd_tensor: def isindexable(idx): - if isinstance(idx, (torch.Tensor, np.ndarray)): - if idx.dtype in (torch.bool, np.dtype("bool")): + if isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: return False return True if isinstance(idx, (tuple, list, range)): return True return False - out = tuple( - tuple(idx if not isindexable(idx) else idx[i] for idx in out) - for i in selected_td_idx - ) + def outer_list(tensor_index, tuple_index): + """Converts a tensor and a tuple to a nested list where each leaf is a (int, index) tuple where the index only points to one element.""" + if isinstance(tensor_index, torch.Tensor): + list_index = tensor_index.tolist() + else: + list_index = tensor_index + list_result = [] + + def index_tuple_index(i, convert=False): + for idx in tuple_index: + if isindexable(idx): + if convert: + yield int(idx[i]) + else: + yield idx[i] + else: + yield idx + + for i, idx in enumerate(list_index): + if isinstance(idx, int): + list_result.append( + (idx, tuple(index_tuple_index(i, convert=True))) + ) + elif isinstance(idx, list): + list_result.append(outer_list(idx, tuple(index_tuple_index(i)))) + else: + raise NotImplementedError + return list_result + return { - "index_dict": dict(enumerate(out)), + "index_dict": outer_list(selected_td_idx, out), "num_single": num_single, "isinteger": isinteger, "has_bool": has_bool, @@ -646,8 +679,19 @@ def _set_at_str(self, key, value, index, *, validated): if is_nd_tensor: unbind_dim = self.stack_dim - num_single + num_none - num_squash value_unbind = value.unbind(unbind_dim) - for idx, _value in zip(converted_idx.values(), value_unbind): - self._set_at_str(key, _value, idx, validated=validated) + + def set_at_str(converted_idx): + for i, item in enumerate(converted_idx): + if isinstance(item, list): + set_at_str(item) + else: + _value = value_unbind[i] + stack_idx, idx = item + self.tensordicts[stack_idx]._set_at_str( + key, _value, idx, validated=validated + ) + + set_at_str(converted_idx) return self elif not has_bool: unbind_dim = self.stack_dim - num_single + num_none - num_squash @@ -1460,10 +1504,12 @@ def __setitem__(self, index: IndexType, value: T) -> T: ) return - if any(isinstance(sub_index, (list, range)) for sub_index in index): + if any( + isinstance(sub_index, (list, range, np.ndarray)) for sub_index in index + ): index = tuple( - torch.tensor(sub_index, device=self.device) - if isinstance(sub_index, (list, range)) + torch.as_tensor(sub_index, device=self.device) + if isinstance(sub_index, (list, range, np.ndarray)) else sub_index for sub_index in index ) @@ -1471,7 +1517,7 @@ def __setitem__(self, index: IndexType, value: T) -> T: if index is Ellipsis or (isinstance(index, tuple) and Ellipsis in index): index = convert_ellipsis_to_idx(index, self.batch_size) elif isinstance(index, (list, range)): - index = torch.tensor(index, device=self.device) + index = torch.as_tensor(index, device=self.device) if is_tensor_collection(value) or isinstance(value, dict): indexed_bs = _getitem_batch_size(self.batch_size, index) @@ -1500,17 +1546,25 @@ def __setitem__(self, index: IndexType, value: T) -> T: if isinteger: # this will break if the index along the stack dim is [0] or :1 or smth for i, _idx in converted_idx.items(): - if _idx != (): - self.tensordicts[i][_idx] = value + if _idx == (): + self.tensordicts[i].update(value, inplace=True) else: - self.tensordicts[i] = value - + self.tensordicts[i][_idx] = value return self if is_nd_tensor: - raise RuntimeError( - 
"Indexing along stack dim with a non-boolean tensor is not supported yet. " - "Use SubTensorDict instead." - ) + unbind_dim = self.stack_dim - num_single + num_none - num_squash + # converted_idx is a nested list with (int, index) items + def assign(converted_idx, value=value): + value = value.unbind(unbind_dim) + for i, item in enumerate(converted_idx): + if isinstance(item, list): + assign(item) + else: + stack_item, idx = item + self.tensordicts[stack_item][idx] = value[i] + + assign(converted_idx) + return self if not has_bool: unbind_dim = self.stack_dim - num_single + num_none - num_squash value_unbind = value.unbind(unbind_dim) @@ -1518,10 +1572,10 @@ def __setitem__(self, index: IndexType, value: T) -> T: converted_idx.items(), value_unbind, ): - if _idx != (): - self.tensordicts[i][_idx] = _value + if _idx == (): + self.tensordicts[i].update(_value, inplace=True) else: - self.tensordicts[i] = _value + self.tensordicts[i][_idx] = _value else: # we must split, not unbind mask_unbind = split_index["individual_masks"] @@ -1589,11 +1643,21 @@ def __getitem__(self, index: IndexType) -> T: return torch.cat(result, cat_dim) elif is_nd_tensor: new_stack_dim = self.stack_dim - num_single + num_none - out = LazyStackedTensorDict.lazy_stack( - [self[idx] for idx in converted_idx.values()], new_stack_dim - ) - out._td_dim_name = self._td_dim_name - return out + + def recompose(converted_idx, stack_dim=new_stack_dim): + stack = [] + for item in converted_idx: + if isinstance(item, list): + stack.append(recompose(item, stack_dim=stack_dim)) + else: + stack_elt, idx = item + stack.append(self.tensordicts[stack_elt][idx]) + result = LazyStackedTensorDict.lazy_stack(stack, stack_dim) + # TODO: this produces multiple dims with the same name + result._td_dim_name = self._td_dim_name + return result + + return recompose(converted_idx) else: if isinteger: for ( @@ -1610,7 +1674,10 @@ def __getitem__(self, index: IndexType) -> T: result = [] new_stack_dim = self.stack_dim - num_single + num_none - num_squash for i, _idx in converted_idx.items(): - result.append(self.tensordicts[i][_idx]) + if _idx == (): + result.append(self.tensordicts[i]) + else: + result.append(self.tensordicts[i][_idx]) result = LazyStackedTensorDict.lazy_stack(result, new_stack_dim) result._td_dim_name = self._td_dim_name return result diff --git a/tensordict/base.py b/tensordict/base.py index 33fa7b239..e08ba35f9 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -2322,6 +2322,7 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): value = self._get_str(key, default=default) from .tensorclass import NonTensorData, NonTensorStack + if isinstance(value, NonTensorData): return value.data if isinstance(value, NonTensorStack): diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index c468697b4..4421aa318 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -28,7 +28,7 @@ from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS -from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class +from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class, CompatibleType from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( @@ -1319,6 +1319,23 @@ def __or__(self, other): self.__class__.__or__ = __or__ + def update( + self, + input_dict_or_td: dict[str, CompatibleType] | T, + clone: 
bool = False, + inplace: bool = False, + *, + keys_to_update: Sequence[NestedKey] | None = None, + ) -> T: + if isinstance(input_dict_or_td, NonTensorData): + data = input_dict_or_td.data + if clone: + data = deepcopy(data) + self.data = data + elif not input_dict_or_td.is_empty(): + raise RuntimeError(f"Unexpected type {type(input_dict_or_td)}") + return self + def empty(self, recurse=False): return NonTensorData( data=self.data, @@ -1450,10 +1467,59 @@ def clone(self, recurse: bool = True): class NonTensorStack(LazyStackedTensorDict): - """A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable.""" + """A thin wrapper around LazyStackedTensorDict to make stack on non-tensor data easily recognizable. + + A ``NonTensorStack`` is returned whenever :func:`~torch.stack` is called on + a list of :class:`~tensordict.NonTensorData` or ``NonTensorStack``. + + Examples: + >>> from tensordict import NonTensorData + >>> import torch + >>> data = torch.stack([ + ... torch.stack([NonTensorData(data=(i, j), batch_size=[]) for i in range(2)]) + ... for j in range(3)]) + >>> print(data) + NonTensorStack( + [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, ..., + batch_size=torch.Size([3, 2]), + device=None) + + To obtain the values stored in a ``NonTensorStack``, call :class:`~.tolist`. + + """ def tolist(self): - if self.stack_dim == 0: - return [td.tolist() for td in self.tensordicts] - else: - return [td.tolist() for td in self.unbind(0)] + """Extracts the content of a :class:`tensordict.tensorclass.NonTensorStack` in a nested list. + + Examples: + >>> from tensordict import NonTensorData + >>> import torch + >>> data = torch.stack([ + ... torch.stack([NonTensorData(data=(i, j), batch_size=[]) for i in range(2)]) + ... for j in range(3)]) + >>> data.tolist() + [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, 2)]] + + """ + iterator = self.tensordicts if self.stack_dim == 0 else self.unbind(0) + return [td.tolist() for td in iterator] + + @classmethod + def from_nontensordata(cls, non_tensor: NonTensorData): + data = non_tensor.data + prev = NonTensorData(data, batch_size=[], device=non_tensor.device) + for dim in reversed(non_tensor.shape): + prev = cls(*[prev.clone(False) for _ in range(dim)], stack_dim=0) + return prev + + def __repr__(self): + selfrepr = str(self.tolist()) + if len(selfrepr) > 50: + selfrepr = f"{selfrepr[:50]}..." 
+ selfrepr = indent(selfrepr, prefix=4 * " ") + batch_size = indent(f"batch_size={self.batch_size}", prefix=4 * " ") + device = indent(f"device={self.device}", prefix=4 * " ") + return f"NonTensorStack(\n{selfrepr}," f"\n{batch_size}," f"\n{device})" + + def to_dict(self) -> dict[str, Any]: + return self.tolist() diff --git a/tensordict/utils.py b/tensordict/utils.py index 7192c66b3..bb80ca1de 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -668,18 +668,9 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> and tensor.data == value.data ): return tensor - if isinstance(index, tuple): - if len(index) == 1: - index = index[0] - else: - idx = index[0] - tensor_idx = tensor[idx] - tensor_idx = _set_item(tensor_idx, index[1:], value, validated=True) - tensor = _set_item(tensor, idx, tensor_idx, validated=True) - return tensor - if isinstance(tensor, NonTensorData): - tensor = NonTensorStack(*[tensor[0]] * tensor.shape[0], stack_dim=0) - elif tensor.stack_dim != 0: + elif isinstance(tensor, NonTensorData): + tensor = NonTensorStack.from_nontensordata(tensor) + if tensor.stack_dim != 0: tensor = NonTensorStack(*tensor.unbind(0), stack_dim=0) tensor[index] = value return tensor @@ -1610,8 +1601,9 @@ def wrapper(*args, **kwargs): def _broadcast_tensors(index): # tensors and range need to be broadcast + assert isinstance(index, tuple) tensors = { - i: tensor if isinstance(tensor, Tensor) else torch.tensor(tensor) + i: torch.as_tensor(tensor) for i, tensor in enumerate(index) if isinstance(tensor, (range, list, np.ndarray, Tensor)) } diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 79cbd092e..01bdb358d 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2774,6 +2774,7 @@ def test_index_tensor_nd_names(self, td_name, device, npy): index = index.numpy() td_idx = td[:, index] assert tensor_example[:, index].shape == td_idx.shape + # TODO: this multiple dims with identical names should not be allowed assert td_idx.names == [names[0], names[1], names[1], *names[2:]] td_idx = td[0, index] assert tensor_example[0, index].shape == td_idx.shape @@ -6100,10 +6101,24 @@ def test_lazy_indexing(self, pos1, pos2, pos3): pos2 = self._idx_list[pos2] pos3 = self._idx_list[pos3] index = (pos1, pos2, pos3) + print( + "index", + tuple( + index if not hasattr(index, "shape") else index.shape for index in index + ), + ) result = outer[index] ref_tensor = torch.zeros(outer.shape) - assert result.batch_size == ref_tensor[index].shape, index - assert result.batch_size == outer_dense[index].shape, index + assert result.batch_size == ref_tensor[index].shape, ( + result.batch_size, + ref_tensor[index].shape, + index, + ) + assert result.batch_size == outer_dense[index].shape, ( + result.batch_size, + outer_dense[index].shape, + index, + ) @pytest.mark.parametrize("stack_dim", [0, 1, 2]) @pytest.mark.parametrize("mask_dim", [0, 1, 2]) From a06aa778a6cf302ff72504ed117a90f18939a112 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 25 Feb 2024 15:28:52 -0800 Subject: [PATCH 22/32] amend --- tensordict/_lazy.py | 1 + tensordict/nn/params.py | 8 -------- tensordict/utils.py | 1 - test/test_tensordict.py | 6 ------ 4 files changed, 1 insertion(+), 15 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index df256f265..02c37d491 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1553,6 +1553,7 @@ def __setitem__(self, index: IndexType, value: T) -> T: return self if is_nd_tensor: unbind_dim = self.stack_dim - num_single + 
num_none - num_squash + # converted_idx is a nested list with (int, index) items def assign(converted_idx, value=value): value = value.unbind(unbind_dim) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 81a124c81..e5714bf4e 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -1117,8 +1117,6 @@ def compute_should_use_set_data(tensor, tensor_applied): param.data = param_applied out_param = param else: - assert isinstance(param, nn.Parameter) - assert param.is_leaf out_param = nn.Parameter(param_applied, param.requires_grad) self._parameters[key] = out_param @@ -1129,10 +1127,8 @@ def compute_should_use_set_data(tensor, tensor_applied): param.grad, grad_applied ) if should_use_set_data: - assert out_param.grad is not None out_param.grad.data = grad_applied else: - assert param.grad.is_leaf out_param.grad = grad_applied.requires_grad_( param.grad.requires_grad ) @@ -1150,8 +1146,6 @@ def compute_should_use_set_data(tensor, tensor_applied): buffer.data = buffer_applied out_buffer = buffer else: - assert isinstance(buffer, Buffer) - assert buffer.is_leaf out_buffer = Buffer(buffer_applied, buffer.requires_grad) self._buffers[key] = out_buffer @@ -1162,10 +1156,8 @@ def compute_should_use_set_data(tensor, tensor_applied): buffer.grad, grad_applied ) if should_use_set_data: - assert out_buffer.grad is not None out_buffer.grad.data = grad_applied else: - assert buffer.grad.is_leaf out_buffer.grad = grad_applied.requires_grad_( buffer.grad.requires_grad ) diff --git a/tensordict/utils.py b/tensordict/utils.py index bb80ca1de..c23c68ec9 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1601,7 +1601,6 @@ def wrapper(*args, **kwargs): def _broadcast_tensors(index): # tensors and range need to be broadcast - assert isinstance(index, tuple) tensors = { i: torch.as_tensor(tensor) for i, tensor in enumerate(index) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 01bdb358d..9ffb602a6 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -6101,12 +6101,6 @@ def test_lazy_indexing(self, pos1, pos2, pos3): pos2 = self._idx_list[pos2] pos3 = self._idx_list[pos3] index = (pos1, pos2, pos3) - print( - "index", - tuple( - index if not hasattr(index, "shape") else index.shape for index in index - ), - ) result = outer[index] ref_tensor = torch.zeros(outer.shape) assert result.batch_size == ref_tensor[index].shape, ( From d50104126b355ff2efef87e0a92dd1a8055421dd Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 09:48:51 -0500 Subject: [PATCH 23/32] amend --- tensordict/_td.py | 32 ++++++++++++++++++++------------ test/test_nn.py | 2 ++ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index aeca5001a..8ccf1388b 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -3043,7 +3043,10 @@ def __contains__(self, key: NestedKey) -> bool: if isinstance(key, str): if key in self._keys(): if self.leaves_only: - return not _is_tensor_collection(self.tensordict.entry_class(key)) + # TODO: make this faster for LazyStacked without compromising regular + return not _is_tensor_collection( + type(self.tensordict._get_str(key)) + ) return True return False else: @@ -3051,25 +3054,30 @@ def __contains__(self, key: NestedKey) -> bool: if len(key) == 1: return key[0] in self._keys() elif self.include_nested: - if key[0] in self._keys(): - entry_type = self.tensordict.entry_class(key[0]) - if entry_type in (Tensor, _MemmapTensor): + item_root = self.tensordict._get_str(key[0], 
default=None) + if item_root is not None: + entry_type = type(item_root) + if issubclass(entry_type, (Tensor, _MemmapTensor)): return False - if entry_type is KeyedJaggedTensor: + elif entry_type is KeyedJaggedTensor: if len(key) > 2: return False - return key[1] in self.tensordict.get(key[0]).keys() + return key[1] in item_root.keys() + # TODO: make this faster for LazyStacked without compromising regular _is_tensordict = _is_tensor_collection(entry_type) if _is_tensordict: # # this will call _unravel_key_to_tuple many times # return key[1:] in self.tensordict._get_str(key[0], NO_DEFAULT).keys(include_nested=self.include_nested) # this won't call _unravel_key_to_tuple but requires to get the default which can be suboptimal - leaf_td = self.tensordict._get_tuple(key[:-1], None) - if leaf_td is None or ( - not _is_tensor_collection(leaf_td.__class__) - and not isinstance(leaf_td, KeyedJaggedTensor) - ): - return False + if len(key) >= 3: + leaf_td = item_root._get_tuple(key[1:-1], None) + if leaf_td is None or ( + not _is_tensor_collection(leaf_td.__class__) + and not isinstance(leaf_td, KeyedJaggedTensor) + ): + return False + else: + leaf_td = item_root return key[-1] in leaf_td.keys() return False # this is reached whenever there is more than one key but include_nested is False diff --git a/test/test_nn.py b/test/test_nn.py index 37d5975b0..016237077 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -415,6 +415,7 @@ def test_functional_before(self): tensordict_module = TensorDictModule( module=net, in_keys=["in"], out_keys=["out"] ) + make_functional(tensordict_module, return_params=False) td = TensorDict({"in": torch.randn(3, 3)}, [3]) tensordict_module(td, params=TensorDict({"module": params}, [])) @@ -580,6 +581,7 @@ def test_functional_with_buffer(self): tdmodule = TensorDictModule(module=net, in_keys=["in"], out_keys=["out"]) td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) + make_functional(tdmodule, return_params=False) tdmodule(td, params=TensorDict({"module": params}, [])) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 32]) From c23e2577e482ba8060185841001fbddddbb81d33 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 10:25:28 -0500 Subject: [PATCH 24/32] amend --- tensordict/_lazy.py | 7 ++++--- tensordict/_td.py | 5 +---- tensordict/_torch_func.py | 13 ++++++++----- tensordict/base.py | 30 +++++++++--------------------- tensordict/tensorclass.py | 14 +++++++++++--- tensordict/utils.py | 9 ++++----- 6 files changed, 37 insertions(+), 41 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 02c37d491..96827fe44 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -29,6 +29,7 @@ _register_tensor_class, BEST_ATTEMPT_INPLACE, CompatibleType, + is_non_tensor, is_tensor_collection, NO_DEFAULT, T, @@ -931,15 +932,15 @@ def lazy_stack( if not items: raise RuntimeError("items cannot be empty") - from .tensorclass import NonTensorData - if all(isinstance(item, torch.Tensor) for item in items): return torch.stack(items, dim=dim, out=out) if all( is_tensorclass(item) and type(item) == type(items[0]) # noqa: E721 for item in items ): - if all(isinstance(tensordict, NonTensorData) for tensordict in items): + if all(is_non_tensor(tensordict) for tensordict in items): + from .tensorclass import NonTensorData + return NonTensorData._stack_non_tensor(items, dim=dim) lazy_stack = cls.lazy_stack( [item._tensordict for item in items], dim=dim, out=out diff --git a/tensordict/_td.py b/tensordict/_td.py 
index 8ccf1388b..d2f5bc327 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2456,11 +2456,8 @@ def get( def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): out = super()._get_non_tensor(key, default=default) - from tensordict.tensorclass import NonTensorData, NonTensorStack - if isinstance(out, _SubTensorDict) and isinstance( - out._source, (NonTensorData, NonTensorStack) - ): + if isinstance(out, _SubTensorDict) and is_non_tensor(out._source): return out._source return out diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index d7b9a12e5..6d3280958 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -13,7 +13,12 @@ import torch from tensordict._lazy import LazyStackedTensorDict from tensordict._td import TensorDict -from tensordict.base import _is_leaf_nontensor, NO_DEFAULT, TensorDictBase +from tensordict.base import ( + _is_leaf_nontensor, + is_non_tensor, + NO_DEFAULT, + TensorDictBase, +) from tensordict.persistent import PersistentTensorDict from tensordict.utils import ( _check_keys, @@ -381,11 +386,9 @@ def _stack( if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") - from tensordict.tensorclass import NonTensorData, NonTensorStack + if all(is_non_tensor(td) for td in list_of_tensordicts): + from tensordict.tensorclass import NonTensorData - if all( - isinstance(td, (NonTensorData, NonTensorStack)) for td in list_of_tensordicts - ): return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) batch_size = list_of_tensordicts[0].batch_size diff --git a/tensordict/base.py b/tensordict/base.py index e08ba35f9..5ac9bf5bf 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -242,14 +242,11 @@ def __getitem__(self, index: IndexType) -> T: idx_unravel = _unravel_key_to_tuple(index) if idx_unravel: result = self._get_tuple(idx_unravel, NO_DEFAULT) - from .tensorclass import NonTensorData - - if isinstance(result, NonTensorData): - return result.data - from tensordict.tensorclass import NonTensorStack - - if isinstance(result, NonTensorStack): - return result.tolist() + if is_non_tensor(result): + result_data = getattr(result, "data", NO_DEFAULT) + if result_data is NO_DEFAULT: + return result_data.tolist() + return result_data return result if (istuple and not index) or (not istuple and index is Ellipsis): @@ -2321,20 +2318,15 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): return subtd._get_non_tensor(key[1:], default=default) value = self._get_str(key, default=default) - from .tensorclass import NonTensorData, NonTensorStack - - if isinstance(value, NonTensorData): - return value.data - if isinstance(value, NonTensorStack): + if is_non_tensor(value): return value.tolist() return value def filter_non_tensor_data(self) -> T: """Filters out all non-tensor-data.""" - from tensordict.tensorclass import NonTensorData def _filter(x): - if not isinstance(x, NonTensorData): + if not is_non_tensor(x): if is_tensor_collection(x): return x.filter_non_tensor_data() return x @@ -5388,9 +5380,7 @@ def is_tensor_collection(datatype: type | Any) -> bool: def is_non_tensor(data): """Checks if an item is a non-tensor.""" - from tensordict.tensorclass import NonTensorData, NonTensorStack - - return isinstance(data, (NonTensorData, NonTensorStack)) + return type(data).__dict__.get("_non_tensor", False) def _default_is_leaf(cls: Type) -> bool: @@ -5398,10 +5388,8 @@ def _default_is_leaf(cls: Type) -> bool: def _is_leaf_nontensor(cls: Type) -> bool: - from tensordict.tensorclass import 
NonTensorData, NonTensorStack - if issubclass(cls, KeyedJaggedTensor): return False if _is_tensor_collection(cls): - return issubclass(cls, (NonTensorData, NonTensorStack)) + return cls.__dict__.get("_non_tensor", False) return issubclass(cls, torch.Tensor) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 4421aa318..0c0dfdbf4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -28,7 +28,12 @@ from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS -from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class, CompatibleType +from tensordict.base import ( + _ACCEPTED_CLASSES, + _register_tensor_class, + CompatibleType, + is_non_tensor, +) from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( @@ -1261,10 +1266,11 @@ class NonTensorData: # to patch tensordict with additional checks that will encur unwanted overhead # and all the overhead falls back on this class. data: Any + _non_tensor: bool = True def __post_init__(self): - if isinstance(self.data, NonTensorData): - self.data = self.data.data + if is_non_tensor(self.data): + self.data = self.data.tolist() old_eq = self.__class__.__eq__ if old_eq is _eq: @@ -1488,6 +1494,8 @@ class NonTensorStack(LazyStackedTensorDict): """ + _non_tensor: bool = True + def tolist(self): """Extracts the content of a :class:`tensordict.tensorclass.NonTensorStack` in a nested list. diff --git a/tensordict/utils.py b/tensordict/utils.py index c23c68ec9..0dd27a153 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -49,6 +49,8 @@ unravel_key_list, unravel_keys, ) + +from tensordict.base import is_non_tensor from torch import Tensor from torch._C import _disabled_torch_function_impl from torch.nn.parameter import ( @@ -59,7 +61,6 @@ ) from torch.utils.data._utils.worker import _generate_state - if TYPE_CHECKING: from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.tensordict import TensorDictBase @@ -661,7 +662,7 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> return tensor from tensordict.tensorclass import NonTensorData, NonTensorStack - if isinstance(tensor, (NonTensorData, NonTensorStack)): + if is_non_tensor(tensor): if ( isinstance(value, NonTensorData) and isinstance(tensor, NonTensorData) @@ -1521,9 +1522,7 @@ def _expand_to_match_shape( def _set_max_batch_size(source: T, batch_dims=None): """Updates a tensordict with its maximium batch size.""" - from tensordict import NonTensorData - - tensor_data = [val for val in source.values() if not isinstance(val, NonTensorData)] + tensor_data = [val for val in source.values() if not is_non_tensor(val)] for val in tensor_data: from tensordict.base import _is_tensor_collection From 4bfe951dddd2bc085012a205e81a645adbe3182c Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 10:40:15 -0500 Subject: [PATCH 25/32] amend --- tensordict/_lazy.py | 2 +- tensordict/_td.py | 2 +- tensordict/_torch_func.py | 8 ++------ tensordict/base.py | 6 +----- tensordict/tensorclass.py | 8 ++------ tensordict/utils.py | 6 +++++- 6 files changed, 12 insertions(+), 20 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 96827fe44..3f9765246 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -29,7 +29,6 @@ _register_tensor_class, BEST_ATTEMPT_INPLACE, CompatibleType, - is_non_tensor, 
is_tensor_collection, NO_DEFAULT, T, @@ -54,6 +53,7 @@ expand_right, IndexType, infer_size_impl, + is_non_tensor, is_tensorclass, KeyedJaggedTensor, lock_blocked, diff --git a/tensordict/_td.py b/tensordict/_td.py index d2f5bc327..021c9151e 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -29,7 +29,6 @@ _register_tensor_class, BEST_ATTEMPT_INPLACE, CompatibleType, - is_non_tensor, is_tensor_collection, NO_DEFAULT, T, @@ -72,6 +71,7 @@ DeviceType, expand_as_right, IndexType, + is_non_tensor, is_tensorclass, KeyedJaggedTensor, lock_blocked, diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 6d3280958..fbabc77e4 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -13,17 +13,13 @@ import torch from tensordict._lazy import LazyStackedTensorDict from tensordict._td import TensorDict -from tensordict.base import ( - _is_leaf_nontensor, - is_non_tensor, - NO_DEFAULT, - TensorDictBase, -) +from tensordict.base import _is_leaf_nontensor, NO_DEFAULT, TensorDictBase from tensordict.persistent import PersistentTensorDict from tensordict.utils import ( _check_keys, _ErrorInteceptor, DeviceType, + is_non_tensor, lazy_legacy, set_lazy_legacy, ) diff --git a/tensordict/base.py b/tensordict/base.py index 5ac9bf5bf..0700ec4e4 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -60,6 +60,7 @@ IndexType, infer_size_impl, int_generator, + is_non_tensor, KeyedJaggedTensor, lazy_legacy, lock_blocked, @@ -5378,11 +5379,6 @@ def is_tensor_collection(datatype: type | Any) -> bool: return _is_tensor_collection(datatype) -def is_non_tensor(data): - """Checks if an item is a non-tensor.""" - return type(data).__dict__.get("_non_tensor", False) - - def _default_is_leaf(cls: Type) -> bool: return not _is_tensor_collection(cls) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 0c0dfdbf4..f64927c1b 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -28,12 +28,7 @@ from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS -from tensordict.base import ( - _ACCEPTED_CLASSES, - _register_tensor_class, - CompatibleType, - is_non_tensor, -) +from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class, CompatibleType from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( @@ -42,6 +37,7 @@ _LOCK_ERROR, DeviceType, IndexType, + is_non_tensor, is_tensorclass, NestedKey, ) diff --git a/tensordict/utils.py b/tensordict/utils.py index 0dd27a153..c64c30a7a 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -50,7 +50,6 @@ unravel_keys, ) -from tensordict.base import is_non_tensor from torch import Tensor from torch._C import _disabled_torch_function_impl from torch.nn.parameter import ( @@ -2170,3 +2169,8 @@ def __call__(self, mod: torch.nn.Module, args, kwargs): return else: raise RuntimeError("did not find pre-hook") + + +def is_non_tensor(data): + """Checks if an item is a non-tensor.""" + return type(data).__dict__.get("_non_tensor", False) From 856079865a830d9041781d3b40a80d412a73646b Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 11:49:37 -0500 Subject: [PATCH 26/32] amend --- tensordict/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensordict/base.py b/tensordict/base.py index 0700ec4e4..93216ca12 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -2320,7 +2320,10 @@ 
def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): value = self._get_str(key, default=default) if is_non_tensor(value): - return value.tolist() + data = getattr(value, "data", None) + if data is None: + return value.tolist() + return data return value def filter_non_tensor_data(self) -> T: From 43ddacd2bd2324b5f3d681604c6ef284aa2c3c6f Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 13:57:34 -0500 Subject: [PATCH 27/32] amend --- tensordict/base.py | 2 +- tensordict/tensorclass.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 93216ca12..bbc3cbf0c 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -246,7 +246,7 @@ def __getitem__(self, index: IndexType) -> T: if is_non_tensor(result): result_data = getattr(result, "data", NO_DEFAULT) if result_data is NO_DEFAULT: - return result_data.tolist() + return result.tolist() return result_data return result diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index f64927c1b..cac696893 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1266,7 +1266,10 @@ class NonTensorData: def __post_init__(self): if is_non_tensor(self.data): - self.data = self.data.tolist() + data = getattr(self.data, "data", None) + if data is None: + data = self.data.tolist() + self.data = data old_eq = self.__class__.__eq__ if old_eq is _eq: From 7c13cfaa435ae2ccac3467bc9eabe649d7161f77 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 14:31:46 -0500 Subject: [PATCH 28/32] amend --- tensordict/_lazy.py | 58 ++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 3f9765246..00410a6dd 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -135,6 +135,8 @@ class LazyStackedTensorDict(TensorDictBase): `td.ndimension()-1` along which the stack should be performed. hook_out (callable, optional): a callable to execute after :meth:`~.get`. hook_in (callable, optional): a callable to execute before :meth:`~.set`. + stack_dim_name (str, optional): the name of the stack dimension. + Defaults to ``None``. 
Examples: >>> from tensordict import TensorDict @@ -185,6 +187,7 @@ def __init__( hook_out: callable | None = None, hook_in: callable | None = None, batch_size: Sequence[int] | None = None, # TODO: remove + stack_dim_name: str | None = None, ) -> None: self._is_locked = None @@ -229,6 +232,8 @@ def __init__( self.hook_in = hook_in if batch_size is not None and batch_size != self.batch_size: raise RuntimeError("batch_size does not match self.batch_size.") + if stack_dim_name is not None: + self._td_dim_name = stack_dim_name # These attributes should never be set @property @@ -855,7 +860,9 @@ def _get_str( # then we consider this default as non-stackable and return prematurly return default try: - out = self.lazy_stack(tensors, self.stack_dim) + out = self.lazy_stack( + tensors, self.stack_dim, stack_dim_name=self._td_dim_name + ) if _is_tensor_collection(out.__class__): if isinstance(out, LazyStackedTensorDict): # then it's a LazyStackedTD @@ -867,8 +874,6 @@ def _get_str( self._batch_size + out.batch_size[(len(self._batch_size) + incr) :] ) - if self._td_dim_name is not None: - out._td_dim_name = self._td_dim_name elif is_tensorclass(out): # then it's a tensorclass out._tensordict.hook_out = self.hook_out @@ -879,8 +884,6 @@ def _get_str( self._batch_size + out._tensordict.batch_size[(len(self._batch_size) + incr) :] ) - if self._td_dim_name is not None: - out._tensordict._td_dim_name = self._td_dim_name else: raise RuntimeError elif self.hook_out is not None: @@ -925,8 +928,10 @@ def lazy_stack( cls, items: Sequence[TensorDictBase], dim: int = 0, + *, device: DeviceType | None = None, out: T | None = None, + stack_dim_name: str | None = None, ) -> T: """Stacks tensordicts in a LazyStackedTensorDict.""" if not items: @@ -943,7 +948,10 @@ def lazy_stack( return NonTensorData._stack_non_tensor(items, dim=dim) lazy_stack = cls.lazy_stack( - [item._tensordict for item in items], dim=dim, out=out + [item._tensordict for item in items], + dim=dim, + out=out, + stack_dim_name=stack_dim_name, ) # we take the first non_tensordict by convention return type(items[0])._from_tensordict( @@ -968,7 +976,9 @@ def lazy_stack( # The first case is handled within _check_keys which fails if keys # don't match exactly. # The second requires a check over the tensor shapes. 
- return LazyStackedTensorDict(*items, stack_dim=dim) + return LazyStackedTensorDict( + *items, stack_dim=dim, stack_dim_name=stack_dim_name + ) else: batch_size = list(batch_size) batch_size.insert(dim, len(items)) @@ -1293,14 +1303,14 @@ def _clone(self, recurse: bool = True) -> T: result = type(self)( *[td._clone() for td in self.tensordicts], stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) else: result = type(self)( *[td._clone(recurse=False) for td in self.tensordicts], stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) - if self._td_dim_name is not None: - result._td_dim_name = self._td_dim_name return result def pin_memory(self) -> T: @@ -1451,13 +1461,12 @@ def _apply_nest( out = type(self)( *results, stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) else: out = self if names is not None: out.names = names - else: - out._td_dim_name = self._td_dim_name return out def _select( @@ -1654,9 +1663,10 @@ def recompose(converted_idx, stack_dim=new_stack_dim): else: stack_elt, idx = item stack.append(self.tensordicts[stack_elt][idx]) - result = LazyStackedTensorDict.lazy_stack(stack, stack_dim) # TODO: this produces multiple dims with the same name - result._td_dim_name = self._td_dim_name + result = LazyStackedTensorDict.lazy_stack( + stack, stack_dim, stack_dim=self._td_dim_name + ) return result return recompose(converted_idx) @@ -1680,8 +1690,9 @@ def recompose(converted_idx, stack_dim=new_stack_dim): result.append(self.tensordicts[i]) else: result.append(self.tensordicts[i][_idx]) - result = LazyStackedTensorDict.lazy_stack(result, new_stack_dim) - result._td_dim_name = self._td_dim_name + result = LazyStackedTensorDict.lazy_stack( + result, new_stack_dim, stack_dim_name=self._td_dim_name + ) return result def __eq__(self, other): @@ -2427,8 +2438,8 @@ def _transpose(self, dim0, dim1): result = type(self)( *(td.transpose(dim0, dim1) for td in self.tensordicts), stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = self._td_dim_name return result def _permute( @@ -2446,9 +2457,10 @@ def _permute( d if d < self.stack_dim else d - 1 for d in dims_list if d != self.stack_dim ] result = LazyStackedTensorDict.lazy_stack( - [td.permute(dims_list) for td in self.tensordicts], stack_dim + [td.permute(dims_list) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = self._td_dim_name return result def _squeeze(self, dim=None): @@ -2471,9 +2483,10 @@ def _squeeze(self, dim=None): else: stack_dim = self.stack_dim - 1 result = LazyStackedTensorDict.lazy_stack( - [td.squeeze(dim) for td in self.tensordicts], stack_dim + [td.squeeze(dim) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = result._td_dim_name else: result = self for dim in range(self.batch_dims - 1, -1, -1): @@ -2496,9 +2509,10 @@ def _unsqueeze(self, dim): else: stack_dim = self.stack_dim + 1 result = LazyStackedTensorDict.lazy_stack( - [td.unsqueeze(dim) for td in self.tensordicts], stack_dim + [td.unsqueeze(dim) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = result._td_dim_name return result lock_ = TensorDictBase.lock_ From 82eab7c7d84f921912700f4d15a13cbc3adfc0b9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 14:53:04 -0500 Subject: [PATCH 29/32] fix --- tensordict/_lazy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py 
index 00410a6dd..971ef142d 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1665,7 +1665,7 @@ def recompose(converted_idx, stack_dim=new_stack_dim): stack.append(self.tensordicts[stack_elt][idx]) # TODO: this produces multiple dims with the same name result = LazyStackedTensorDict.lazy_stack( - stack, stack_dim, stack_dim=self._td_dim_name + stack, stack_dim, stack_dim_name=self._td_dim_name ) return result From d5c299a5f13962614e8c746597baaf1486a4de9a Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 16:42:42 -0500 Subject: [PATCH 30/32] fix --- tensordict/tensorclass.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index cac696893..c2029f4fa 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -221,6 +221,7 @@ def __torch_function__( cls.to_tensordict = _to_tensordict cls.device = property(_device, _device_setter) cls.batch_size = property(_batch_size, _batch_size_setter) + cls.names = property(_names, _names_setter) cls.__doc__ = f"{cls.__name__}{inspect.signature(cls)}" @@ -502,10 +503,12 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 """ __dict__ = self.__dict__ - if "_tensordict" not in __dict__ or "_non_tensordict" not in __dict__: + if ( + "_tensordict" not in __dict__ + or "_non_tensordict" not in __dict__ + or key in SET_ATTRIBUTES + ): return setattr_(self, key, value) - if key in SET_ATTRIBUTES: - return setattr(self._tensordict, key, value) out = self.set(key, value) if out is not self: @@ -844,6 +847,26 @@ def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417 self._tensordict._batch_size_setter(new_size) +def _names(self) -> torch.Size: + """Retrieves the dim names for the tensor class. + + Returns: + names (list of str) + + """ + return self._tensordict.names + + +def _names_setter(self, names: str) -> None: # noqa: D417 + """Set the value of ``tensorclass.names``. + + Args: + names (sequence of str) + + """ + self._tensordict.names = names + + def _state_dict( self, destination=None, prefix="", keep_vars=False, flatten=False ) -> dict[str, Any]: From 26b427dcc5016cabe3f2c1203b25414407928319 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 18:11:55 -0500 Subject: [PATCH 31/32] amend --- test/test_tensordict.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 9ffb602a6..66e2875c8 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -7797,6 +7797,29 @@ def test_map_with_out(self, mmap, chunksize, tmpdir): input.map(self.selectfn, num_workers=2, chunksize=chunksize, out=out) assert (out["a"] == torch.arange(10)).all(), (chunksize, mmap) + @classmethod + def nontensor_check(cls, td): + td["check"] = td["non_tensor"] == ( + "a string!" if (td["tensor"] % 2) == 0 else "another string!" 
+ ) + return td + + def test_non_tensor(self): + # with NonTensorStack + td = TensorDict( + {"tensor": torch.arange(10), "non_tensor": "a string!"}, batch_size=[10] + ) + td[1::2] = TensorDict({"non_tensor": "another string!"}, [5]) + td = td.map(self.nontensor_check, chunksize=0) + assert td["check"].all() + # with NonTensorData + td = TensorDict( + {"tensor": torch.zeros(10, dtype=torch.int), "non_tensor": "a string!"}, + batch_size=[10], + ) + td = td.map(self.nontensor_check, chunksize=0) + assert td["check"].all() + # class TestNonTensorData: class TestNonTensorData: From 137d89da3df83f7b0de573cd4bc2ef00aee687f1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 18:12:38 -0500 Subject: [PATCH 32/32] amend --- .github/workflows/test-rl-gpu.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test-rl-gpu.yml b/.github/workflows/test-rl-gpu.yml index ab284d754..dea057e8b 100644 --- a/.github/workflows/test-rl-gpu.yml +++ b/.github/workflows/test-rl-gpu.yml @@ -31,6 +31,7 @@ jobs: repository: pytorch/tensordict gpu-arch-type: cuda gpu-arch-version: ${{ matrix.cuda_arch_version }} + timeout: 120 script: | # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }}
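A short, illustrative usage sketch (not part of any commit in this series) ties the changes above together. It assumes the APIs as they appear in the diffs: the `is_non_tensor` helper relocated to `tensordict/utils.py`, non-tensor leaves unwrapped by `__getitem__`, and the new `stack_dim_name` keyword of `LazyStackedTensorDict.lazy_stack`.

    >>> # Illustrative only; import paths and keyword names are taken from the diffs above.
    >>> import torch
    >>> from tensordict import TensorDict
    >>> from tensordict._lazy import LazyStackedTensorDict
    >>> from tensordict.utils import is_non_tensor
    >>> # Non-tensor leaves are detected through the `_non_tensor` class flag
    >>> # rather than isinstance checks against NonTensorData/NonTensorStack.
    >>> td = TensorDict(
    ...     {"tensor": torch.arange(10), "non_tensor": "a string!"}, batch_size=[10]
    ... )
    >>> assert is_non_tensor(td.get("non_tensor"))
    >>> # __getitem__ unwraps the non-tensor leaf and returns its data
    >>> # (a nested list when the leaf is a stack of non-tensor data).
    >>> assert td["non_tensor"] == "a string!"
    >>> # The stack dimension can now be named at construction time.
    >>> tds = [TensorDict({"x": torch.randn(4)}, batch_size=[4]) for _ in range(2)]
    >>> stacked = LazyStackedTensorDict.lazy_stack(tds, dim=0, stack_dim_name="model")
    >>> assert stacked.names[0] == "model"

Passing the name at construction time is what lets the patch drop the repeated `result._td_dim_name = self._td_dim_name` assignments after each internal `lazy_stack` call.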