diff --git a/.github/workflows/test-rl-gpu.yml b/.github/workflows/test-rl-gpu.yml index ab284d754..dea057e8b 100644 --- a/.github/workflows/test-rl-gpu.yml +++ b/.github/workflows/test-rl-gpu.yml @@ -31,6 +31,7 @@ jobs: repository: pytorch/tensordict gpu-arch-type: cuda gpu-arch-version: ${{ matrix.cuda_arch_version }} + timeout: 120 script: | # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }} diff --git a/docs/source/reference/tensorclass.rst b/docs/source/reference/tensorclass.rst index 17518df1d..8e6e4e907 100644 --- a/docs/source/reference/tensorclass.rst +++ b/docs/source/reference/tensorclass.rst @@ -273,3 +273,4 @@ Here is an example: tensorclass NonTensorData + NonTensorStack diff --git a/tensordict/__init__.py b/tensordict/__init__.py index c71661ceb..5e6dc8761 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -16,7 +16,7 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.memmap_deprec import is_memmap, MemmapTensor, set_transfer_ownership from tensordict.persistent import PersistentTensorDict -from tensordict.tensorclass import NonTensorData, tensorclass +from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass from tensordict.utils import ( assert_allclose_td, is_batchedtensor, @@ -43,6 +43,7 @@ "TensorDict", "TensorDictBase", "merge_tensordicts", + "NonTensorStack", "set_transfer_ownership", "pad_sequence", "is_memmap", diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index b0c6ee6cc..971ef142d 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -53,6 +53,7 @@ expand_right, IndexType, infer_size_impl, + is_non_tensor, is_tensorclass, KeyedJaggedTensor, lock_blocked, @@ -134,6 +135,8 @@ class LazyStackedTensorDict(TensorDictBase): `td.ndimension()-1` along which the stack should be performed. hook_out (callable, optional): a callable to execute after :meth:`~.get`. hook_in (callable, optional): a callable to execute before :meth:`~.set`. + stack_dim_name (str, optional): the name of the stack dimension. + Defaults to ``None``. Examples: >>> from tensordict import TensorDict @@ -184,6 +187,7 @@ def __init__( hook_out: callable | None = None, hook_in: callable | None = None, batch_size: Sequence[int] | None = None, # TODO: remove + stack_dim_name: str | None = None, ) -> None: self._is_locked = None @@ -200,6 +204,10 @@ def __init__( ) _batch_size = tensordicts[0].batch_size device = tensordicts[0].device + if stack_dim > len(_batch_size): + raise RuntimeError( + f"Stack dim {stack_dim} is too big for batch size {_batch_size}." 
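A brief usage sketch (not part of this patch; variable names are illustrative) of the two additions above, the new ``stack_dim_name`` keyword and the eager bounds check on ``stack_dim``:

    import torch
    from tensordict import LazyStackedTensorDict, TensorDict

    tds = [TensorDict({"a": torch.zeros(3)}, batch_size=[3]) for _ in range(2)]

    # naming the stack dimension at construction time should show up in .names
    stacked = LazyStackedTensorDict(*tds, stack_dim=0, stack_dim_name="batch")
    print(stacked.names)  # expected: ['batch', None]

    # a stack_dim larger than the shared batch size now raises eagerly
    try:
        LazyStackedTensorDict(*tds, stack_dim=2)
    except RuntimeError as err:
        print(err)  # Stack dim 2 is too big for batch size torch.Size([3]).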
+ ) for td in tensordicts[1:]: if not is_tensor_collection(td): @@ -224,6 +232,8 @@ def __init__( self.hook_in = hook_in if batch_size is not None and batch_size != self.batch_size: raise RuntimeError("batch_size does not match self.batch_size.") + if stack_dim_name is not None: + self._td_dim_name = stack_dim_name # These attributes should never be set @property @@ -487,9 +497,10 @@ def _split_index(self, index): isinteger = False is_nd_tensor = False cursor = 0 # the dimension cursor - selected_td_idx = range(len(self.tensordicts)) + selected_td_idx = torch.arange(len(self.tensordicts)) has_bool = False num_squash = 0 + encountered_tensor = False for i, idx in enumerate(index): # noqa: B007 cursor_incr = 1 if idx is None: @@ -509,10 +520,8 @@ def _split_index(self, index): if not isinstance(selected_td_idx, range): isinteger = True selected_td_idx = [selected_td_idx] - elif isinstance(idx, (list, range)): - selected_td_idx = idx - elif isinstance(idx, (torch.Tensor, np.ndarray)): - if idx.dtype in (np.dtype("bool"), torch.bool): + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: # we mark that we need to dispatch the indices across stack idx has_bool = True # split mask along dim @@ -522,11 +531,14 @@ def _split_index(self, index): split_dim = self.stack_dim - num_single mask_loc = i else: - if isinstance(idx, np.ndarray): - idx = torch.tensor(idx) is_nd_tensor = True - selected_td_idx = range(len(idx)) - out.append(idx.unbind(0)) + if not encountered_tensor: + # num_single -= idx.ndim - 1 + encountered_tensor = True + else: + num_single += 1 + selected_td_idx = idx + # out.append(idx.unbind(0)) else: raise TypeError(f"Invalid index type: {type(idx)}.") else: @@ -537,13 +549,11 @@ def _split_index(self, index): ( ftdim.Dim, slice, - list, - range, ), ): out.append(idx) - elif isinstance(idx, (np.ndarray, torch.Tensor)): - if idx.dtype in (np.dtype("bool"), torch.bool): + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: cursor_incr = idx.ndim if cursor < self.stack_dim: num_squash += cursor_incr - 1 @@ -568,7 +578,11 @@ def _split_index(self, index): # smth[torch.tensor(1)].ndim = smth.ndim-1 # smth[torch.tensor([1])].ndim = smth.ndim # smth[torch.tensor([[1]])].ndim = smth.ndim+1 - num_single -= idx.ndim - 1 + if not encountered_tensor: + num_single -= idx.ndim - 1 + encountered_tensor = True + else: + num_single += 1 out.append(idx) else: raise TypeError(f"Invalid index type: {type(idx)}.") @@ -593,20 +607,45 @@ def _split_index(self, index): elif is_nd_tensor: def isindexable(idx): - if isinstance(idx, (torch.Tensor, np.ndarray)): - if idx.dtype in (torch.bool, np.dtype("bool")): + if isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: return False return True if isinstance(idx, (tuple, list, range)): return True return False - out = tuple( - tuple(idx if not isindexable(idx) else idx[i] for idx in out) - for i in selected_td_idx - ) + def outer_list(tensor_index, tuple_index): + """Converts a tensor and a tuple to a nested list where each leaf is a (int, index) tuple where the index only points to one element.""" + if isinstance(tensor_index, torch.Tensor): + list_index = tensor_index.tolist() + else: + list_index = tensor_index + list_result = [] + + def index_tuple_index(i, convert=False): + for idx in tuple_index: + if isindexable(idx): + if convert: + yield int(idx[i]) + else: + yield idx[i] + else: + yield idx + + for i, idx in enumerate(list_index): + if isinstance(idx, int): + list_result.append( + (idx, tuple(index_tuple_index(i, 
convert=True))) + ) + elif isinstance(idx, list): + list_result.append(outer_list(idx, tuple(index_tuple_index(i)))) + else: + raise NotImplementedError + return list_result + return { - "index_dict": dict(enumerate(out)), + "index_dict": outer_list(selected_td_idx, out), "num_single": num_single, "isinteger": isinteger, "has_bool": has_bool, @@ -646,8 +685,19 @@ def _set_at_str(self, key, value, index, *, validated): if is_nd_tensor: unbind_dim = self.stack_dim - num_single + num_none - num_squash value_unbind = value.unbind(unbind_dim) - for idx, _value in zip(converted_idx.values(), value_unbind): - self._set_at_str(key, _value, idx, validated=validated) + + def set_at_str(converted_idx): + for i, item in enumerate(converted_idx): + if isinstance(item, list): + set_at_str(item) + else: + _value = value_unbind[i] + stack_idx, idx = item + self.tensordicts[stack_idx]._set_at_str( + key, _value, idx, validated=validated + ) + + set_at_str(converted_idx) return self elif not has_bool: unbind_dim = self.stack_dim - num_single + num_none - num_squash @@ -718,7 +768,7 @@ def _legacy_unsqueeze(self, dim: int) -> T: else: dim = dim - 1 stack_dim = self.stack_dim - return LazyStackedTensorDict( + return type(self)( *(tensordict.unsqueeze(dim) for tensordict in self.tensordicts), stack_dim=stack_dim, ) @@ -756,7 +806,7 @@ def _legacy_squeeze(self, dim: int | None = None) -> T: else: dim = dim - 1 stack_dim = self.stack_dim - return LazyStackedTensorDict( + return type(self)( *(tensordict.squeeze(dim) for tensordict in self.tensordicts), stack_dim=stack_dim, ) @@ -810,7 +860,9 @@ def _get_str( # then we consider this default as non-stackable and return prematurly return default try: - out = self.lazy_stack(tensors, self.stack_dim) + out = self.lazy_stack( + tensors, self.stack_dim, stack_dim_name=self._td_dim_name + ) if _is_tensor_collection(out.__class__): if isinstance(out, LazyStackedTensorDict): # then it's a LazyStackedTD @@ -822,8 +874,6 @@ def _get_str( self._batch_size + out.batch_size[(len(self._batch_size) + incr) :] ) - if self._td_dim_name is not None: - out._td_dim_name = self._td_dim_name elif is_tensorclass(out): # then it's a tensorclass out._tensordict.hook_out = self.hook_out @@ -834,8 +884,6 @@ def _get_str( self._batch_size + out._tensordict.batch_size[(len(self._batch_size) + incr) :] ) - if self._td_dim_name is not None: - out._tensordict._td_dim_name = self._td_dim_name else: raise RuntimeError elif self.hook_out is not None: @@ -880,25 +928,30 @@ def lazy_stack( cls, items: Sequence[TensorDictBase], dim: int = 0, + *, device: DeviceType | None = None, out: T | None = None, + stack_dim_name: str | None = None, ) -> T: """Stacks tensordicts in a LazyStackedTensorDict.""" if not items: raise RuntimeError("items cannot be empty") - from .tensorclass import NonTensorData - if all(isinstance(item, torch.Tensor) for item in items): return torch.stack(items, dim=dim, out=out) if all( is_tensorclass(item) and type(item) == type(items[0]) # noqa: E721 for item in items ): - if all(isinstance(tensordict, NonTensorData) for tensordict in items): + if all(is_non_tensor(tensordict) for tensordict in items): + from .tensorclass import NonTensorData + return NonTensorData._stack_non_tensor(items, dim=dim) lazy_stack = cls.lazy_stack( - [item._tensordict for item in items], dim=dim, out=out + [item._tensordict for item in items], + dim=dim, + out=out, + stack_dim_name=stack_dim_name, ) # we take the first non_tensordict by convention return type(items[0])._from_tensordict( @@ -923,7 +976,9 
@@ def lazy_stack( # The first case is handled within _check_keys which fails if keys # don't match exactly. # The second requires a check over the tensor shapes. - return LazyStackedTensorDict(*items, stack_dim=dim) + return LazyStackedTensorDict( + *items, stack_dim=dim, stack_dim_name=stack_dim_name + ) else: batch_size = list(batch_size) batch_size.insert(dim, len(items)) @@ -1236,7 +1291,7 @@ def contiguous(self) -> T: return out def empty(self, recurse=False) -> T: - return LazyStackedTensorDict( + return type(self)( *[td.empty(recurse=recurse) for td in self.tensordicts], stack_dim=self.stack_dim, ) @@ -1245,17 +1300,17 @@ def _clone(self, recurse: bool = True) -> T: if recurse: # This could be optimized using copy but we must be careful with # metadata (_is_shared etc) - result = LazyStackedTensorDict( + result = type(self)( *[td._clone() for td in self.tensordicts], stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) else: - result = LazyStackedTensorDict( + result = type(self)( *[td._clone(recurse=False) for td in self.tensordicts], stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) - if self._td_dim_name is not None: - result._td_dim_name = self._td_dim_name return result def pin_memory(self) -> T: @@ -1274,7 +1329,7 @@ def to(self, *args, **kwargs) -> T: if device is not None and dtype is None and device == self.device: return result - return LazyStackedTensorDict( + return type(self)( *[td.to(*args, **kwargs) for td in self.tensordicts], stack_dim=self.stack_dim, hook_out=self.hook_out, @@ -1403,16 +1458,15 @@ def _apply_nest( if filter_empty and all(r is None for r in results): return if not inplace: - out = LazyStackedTensorDict( + out = type(self)( *results, stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) else: out = self if names is not None: out.names = names - else: - out._td_dim_name = self._td_dim_name return out def _select( @@ -1429,7 +1483,7 @@ def _select( ] if inplace: return self - result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim) + result = type(self)(*tensordicts, stack_dim=self.stack_dim) return result def _exclude( @@ -1442,7 +1496,7 @@ def _exclude( if inplace: self.tensordicts = tensordicts return self - result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim) + result = type(self)(*tensordicts, stack_dim=self.stack_dim) return result def __setitem__(self, index: IndexType, value: T) -> T: @@ -1460,10 +1514,12 @@ def __setitem__(self, index: IndexType, value: T) -> T: ) return - if any(isinstance(sub_index, (list, range)) for sub_index in index): + if any( + isinstance(sub_index, (list, range, np.ndarray)) for sub_index in index + ): index = tuple( - torch.tensor(sub_index, device=self.device) - if isinstance(sub_index, (list, range)) + torch.as_tensor(sub_index, device=self.device) + if isinstance(sub_index, (list, range, np.ndarray)) else sub_index for sub_index in index ) @@ -1471,9 +1527,9 @@ def __setitem__(self, index: IndexType, value: T) -> T: if index is Ellipsis or (isinstance(index, tuple) and Ellipsis in index): index = convert_ellipsis_to_idx(index, self.batch_size) elif isinstance(index, (list, range)): - index = torch.tensor(index, device=self.device) + index = torch.as_tensor(index, device=self.device) - if isinstance(value, (TensorDictBase, dict)): + if is_tensor_collection(value) or isinstance(value, dict): indexed_bs = _getitem_batch_size(self.batch_size, index) if isinstance(value, dict): value = TensorDict( @@ -1500,13 +1556,26 @@ def __setitem__(self, index: 
IndexType, value: T) -> T: if isinteger: # this will break if the index along the stack dim is [0] or :1 or smth for i, _idx in converted_idx.items(): - self.tensordicts[i][_idx] = value + if _idx == (): + self.tensordicts[i].update(value, inplace=True) + else: + self.tensordicts[i][_idx] = value return self if is_nd_tensor: - raise RuntimeError( - "Indexing along stack dim with a non-boolean tensor is not supported yet. " - "Use SubTensorDict instead." - ) + unbind_dim = self.stack_dim - num_single + num_none - num_squash + + # converted_idx is a nested list with (int, index) items + def assign(converted_idx, value=value): + value = value.unbind(unbind_dim) + for i, item in enumerate(converted_idx): + if isinstance(item, list): + assign(item) + else: + stack_item, idx = item + self.tensordicts[stack_item][idx] = value[i] + + assign(converted_idx) + return self if not has_bool: unbind_dim = self.stack_dim - num_single + num_none - num_squash value_unbind = value.unbind(unbind_dim) @@ -1514,7 +1583,10 @@ def __setitem__(self, index: IndexType, value: T) -> T: converted_idx.items(), value_unbind, ): - self.tensordicts[i][_idx] = _value + if _idx == (): + self.tensordicts[i].update(_value, inplace=True) + else: + self.tensordicts[i][_idx] = _value else: # we must split, not unbind mask_unbind = split_index["individual_masks"] @@ -1582,11 +1654,22 @@ def __getitem__(self, index: IndexType) -> T: return torch.cat(result, cat_dim) elif is_nd_tensor: new_stack_dim = self.stack_dim - num_single + num_none - out = LazyStackedTensorDict.lazy_stack( - [self[idx] for idx in converted_idx.values()], new_stack_dim - ) - out._td_dim_name = self._td_dim_name - return out + + def recompose(converted_idx, stack_dim=new_stack_dim): + stack = [] + for item in converted_idx: + if isinstance(item, list): + stack.append(recompose(item, stack_dim=stack_dim)) + else: + stack_elt, idx = item + stack.append(self.tensordicts[stack_elt][idx]) + # TODO: this produces multiple dims with the same name + result = LazyStackedTensorDict.lazy_stack( + stack, stack_dim, stack_dim_name=self._td_dim_name + ) + return result + + return recompose(converted_idx) else: if isinteger: for ( @@ -1603,9 +1686,13 @@ def __getitem__(self, index: IndexType) -> T: result = [] new_stack_dim = self.stack_dim - num_single + num_none - num_squash for i, _idx in converted_idx.items(): - result.append(self.tensordicts[i][_idx]) - result = LazyStackedTensorDict.lazy_stack(result, new_stack_dim) - result._td_dim_name = self._td_dim_name + if _idx == (): + result.append(self.tensordicts[i]) + else: + result.append(self.tensordicts[i][_idx]) + result = LazyStackedTensorDict.lazy_stack( + result, new_stack_dim, stack_dim_name=self._td_dim_name + ) return result def __eq__(self, other): @@ -2329,9 +2416,9 @@ def _transpose(self, dim0, dim1): # example: shape = [5, 4, 3, 2, 1], stack_dim=1, dim0=1, dim1=4 # resulting shape: [5, 1, 3, 2, 4] if dim1 == dim0 + 1: - result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim1) + result = type(self)(*self.tensordicts, stack_dim=dim1) else: - result = LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0, dim1 - 1) for td in self.tensordicts), stack_dim=dim1, ) @@ -2339,20 +2426,20 @@ def _transpose(self, dim0, dim1): # example: shape = [5, 4, 3, 2, 1], stack_dim=3, dim0=1, dim1=3 # resulting shape: [5, 2, 3, 4, 1] if dim0 + 1 == dim1: - result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim0) + result = type(self)(*self.tensordicts, stack_dim=dim0) else: - result = 
LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0 + 1, dim1) for td in self.tensordicts), stack_dim=dim0, ) else: dim0 = dim0 if dim0 < self.stack_dim else dim0 - 1 dim1 = dim1 if dim1 < self.stack_dim else dim1 - 1 - result = LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0, dim1) for td in self.tensordicts), stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = self._td_dim_name return result def _permute( @@ -2370,9 +2457,10 @@ def _permute( d if d < self.stack_dim else d - 1 for d in dims_list if d != self.stack_dim ] result = LazyStackedTensorDict.lazy_stack( - [td.permute(dims_list) for td in self.tensordicts], stack_dim + [td.permute(dims_list) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = self._td_dim_name return result def _squeeze(self, dim=None): @@ -2395,9 +2483,10 @@ def _squeeze(self, dim=None): else: stack_dim = self.stack_dim - 1 result = LazyStackedTensorDict.lazy_stack( - [td.squeeze(dim) for td in self.tensordicts], stack_dim + [td.squeeze(dim) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = result._td_dim_name else: result = self for dim in range(self.batch_dims - 1, -1, -1): @@ -2420,9 +2509,10 @@ def _unsqueeze(self, dim): else: stack_dim = self.stack_dim + 1 result = LazyStackedTensorDict.lazy_stack( - [td.unsqueeze(dim) for td in self.tensordicts], stack_dim + [td.unsqueeze(dim) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = result._td_dim_name return result lock_ = TensorDictBase.lock_ diff --git a/tensordict/_td.py b/tensordict/_td.py index df0660191..021c9151e 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -38,6 +38,9 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( + _add_batch_dim_pre_hook, + _BatchedUninitializedBuffer, + _BatchedUninitializedParameter, _clone_value, _expand_to_match_shape, _get_item, @@ -68,6 +71,7 @@ DeviceType, expand_as_right, IndexType, + is_non_tensor, is_tensorclass, KeyedJaggedTensor, lock_blocked, @@ -305,9 +309,8 @@ def is_empty(self): if _is_tensor_collection(type(item)): if not item.is_empty(): return False - from tensordict.tensorclass import NonTensorData - if isinstance(item, NonTensorData): + if is_non_tensor(item): return False else: return False @@ -327,7 +330,7 @@ def _to_module( if not use_state_dict and isinstance(module, TensorDictBase): if return_swap: swap = module.copy() - module.update(self) + module._param_td = getattr(self, "_param_td", self) return swap else: module.update(self) @@ -403,7 +406,17 @@ def convert_type(x, y): continue child = __dict__["_modules"][key] local_out = memo.get(id(child), NO_DEFAULT) + if local_out is NO_DEFAULT: + # if isinstance(child, TensorDictBase): + # # then child is a TensorDictParams + # from tensordict.nn import TensorDictParams + # + # local_out = child + # if not isinstance(value, TensorDictParams): + # value = TensorDictParams(value, no_convert=True) + # __dict__["_modules"][key] = value + # else: local_out = value._to_module( child, inplace=inplace, @@ -676,7 +689,11 @@ def make_result(): any_set = False for key, item in self.items(): - if not call_on_nested and _is_tensor_collection(item.__class__): + if ( + not call_on_nested + and _is_tensor_collection(item.__class__) + # and not is_non_tensor(item) + ): if default is not 
NO_DEFAULT: _others = [_other._get_str(key, default=None) for _other in others] _others = [ @@ -750,15 +767,26 @@ def make_result(): @cache # noqa: B019 def _add_batch_dim(self, *, in_dim, vmap_level): td = self + + def _add_batch_dim_wrapper(key, value): + if is_tensor_collection(value): + return value._add_batch_dim(in_dim=in_dim, vmap_level=vmap_level) + + if isinstance( + value, (_BatchedUninitializedParameter, _BatchedUninitializedBuffer) + ): + value.in_dim = in_dim + value.vmap_level = vmap_level + return value + return _add_batch_dim(value, in_dim, vmap_level) + out = TensorDict( - { - key: value._add_batch_dim(in_dim=in_dim, vmap_level=vmap_level) - if is_tensor_collection(value) - else _add_batch_dim(value, in_dim, vmap_level) - for key, value in td.items() - }, - batch_size=[b for i, b in enumerate(td.batch_size) if i != in_dim], + {key: _add_batch_dim_wrapper(key, value) for key, value in td.items()}, + batch_size=torch.Size( + [b for i, b in enumerate(td.batch_size) if i != in_dim] + ), names=[name for i, name in enumerate(td.names) if i != in_dim], + _run_checks=False, ) return out @@ -805,12 +833,14 @@ def _index_tensordict( raise RuntimeError( f"indexing a tensordict with td.batch_dims==0 is not permitted. Got index {index}." ) - if names is None: - names = self._get_names_idx(index) if new_batch_size is not None: batch_size = new_batch_size else: batch_size = _getitem_batch_size(batch_size, index) + + if names is None: + names = self._get_names_idx(index) + source = {} for key, item in self.items(): if isinstance(item, TensorDict): @@ -1347,14 +1377,20 @@ def is_boolean(idx): # this will convert a [None, :, :, 0, None, 0] in [None, 0, 1, None, 3] count = 0 idx_to_take = [] + no_more_tensors = False for _idx in idx_names: if _idx is None: idx_to_take.append(None) elif _is_number(_idx): count += 1 elif isinstance(_idx, (torch.Tensor, np.ndarray)): - idx_to_take.extend([count] * _idx.ndim) - count += 1 + if not no_more_tensors: + idx_to_take.extend([count] * _idx.ndim) + count += 1 + no_more_tensors = True + else: + # skip this one + count += 1 else: idx_to_take.append(count) count += 1 @@ -1368,6 +1404,7 @@ def names(self, value): self._rename_subtds(value) self._erase_names() return + value = list(value) num_none = sum(v is None for v in value) if num_none: num_none -= 1 @@ -1524,7 +1561,9 @@ def _set_at_str(self, key, value, idx, *, validated): tensor_in = _sub_index(tensor_in, idx) tensor_in.copy_(value) else: - _set_item(tensor_in, idx, value, validated=validated) + tensor_out = _set_item(tensor_in, idx, value, validated=validated) + if tensor_in is not tensor_out: + self._set_str(key, tensor_out, validated=True, inplace=False) return self @@ -2352,10 +2391,11 @@ def _set_at_str(self, key, value, idx, *, validated): ) tensor_in = _sub_index(tensor_in, idx) tensor_in.copy_(value) + tensor_out = tensor_in else: - _set_item(tensor_in, idx, value, validated=validated) + tensor_out = _set_item(tensor_in, idx, value, validated=validated) # make sure that the value is updated - self._source._set_at_str(key, tensor_in, self.idx, validated=validated) + self._source._set_at_str(key, tensor_out, self.idx, validated=validated) return self def _set_at_tuple(self, key, value, idx, *, validated): @@ -2416,15 +2456,17 @@ def get( def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): out = super()._get_non_tensor(key, default=default) - from tensordict.tensorclass import NonTensorData - if isinstance(out, _SubTensorDict) and isinstance(out._source, NonTensorData): - return 
out._source.data + if isinstance(out, _SubTensorDict) and is_non_tensor(out._source): + return out._source return out def _get_str(self, key, default): if key in self.keys() and _is_tensor_collection(self.entry_class(key)): - return _SubTensorDict(self._source._get_str(key, NO_DEFAULT), self.idx) + data = self._source._get_str(key, NO_DEFAULT) + if is_non_tensor(data): + return data[self.idx] + return _SubTensorDict(data, self.idx) return self._source._get_at_str(key, self.idx, default=default) def _get_tuple(self, key, default): @@ -2998,7 +3040,10 @@ def __contains__(self, key: NestedKey) -> bool: if isinstance(key, str): if key in self._keys(): if self.leaves_only: - return not _is_tensor_collection(self.tensordict.entry_class(key)) + # TODO: make this faster for LazyStacked without compromising regular + return not _is_tensor_collection( + type(self.tensordict._get_str(key)) + ) return True return False else: @@ -3006,25 +3051,30 @@ def __contains__(self, key: NestedKey) -> bool: if len(key) == 1: return key[0] in self._keys() elif self.include_nested: - if key[0] in self._keys(): - entry_type = self.tensordict.entry_class(key[0]) - if entry_type in (Tensor, _MemmapTensor): + item_root = self.tensordict._get_str(key[0], default=None) + if item_root is not None: + entry_type = type(item_root) + if issubclass(entry_type, (Tensor, _MemmapTensor)): return False - if entry_type is KeyedJaggedTensor: + elif entry_type is KeyedJaggedTensor: if len(key) > 2: return False - return key[1] in self.tensordict.get(key[0]).keys() + return key[1] in item_root.keys() + # TODO: make this faster for LazyStacked without compromising regular _is_tensordict = _is_tensor_collection(entry_type) if _is_tensordict: # # this will call _unravel_key_to_tuple many times # return key[1:] in self.tensordict._get_str(key[0], NO_DEFAULT).keys(include_nested=self.include_nested) # this won't call _unravel_key_to_tuple but requires to get the default which can be suboptimal - leaf_td = self.tensordict._get_tuple(key[:-1], None) - if leaf_td is None or ( - not _is_tensor_collection(leaf_td.__class__) - and not isinstance(leaf_td, KeyedJaggedTensor) - ): - return False + if len(key) >= 3: + leaf_td = item_root._get_tuple(key[1:-1], None) + if leaf_td is None or ( + not _is_tensor_collection(leaf_td.__class__) + and not isinstance(leaf_td, KeyedJaggedTensor) + ): + return False + else: + leaf_td = item_root return key[-1] in leaf_td.keys() return False # this is reached whenever there is more than one key but include_nested is False @@ -3040,7 +3090,7 @@ def __repr__(self): def _set_tensor_dict( # noqa: F811 module_dict, hooks, - module, + module: torch.nn.Module, name: str, tensor: torch.Tensor, inplace: bool, @@ -3066,6 +3116,14 @@ def _set_tensor_dict( # noqa: F811 if output is not None: tensor = output module_dict["_parameters"][name] = tensor + + if isinstance( + tensor, (_BatchedUninitializedParameter, _BatchedUninitializedBuffer) + ): + module.register_forward_pre_hook( + _add_batch_dim_pre_hook(), with_kwargs=True + ) + elif was_buffer and isinstance(tensor, torch.Tensor): module_dict["_buffers"][name] = tensor else: diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 087b49534..fbabc77e4 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -11,7 +11,6 @@ from typing import Any, Callable, Sequence, TypeVar import torch - from tensordict._lazy import LazyStackedTensorDict from tensordict._td import TensorDict from tensordict.base import _is_leaf_nontensor, NO_DEFAULT, 
TensorDictBase @@ -20,10 +19,16 @@ _check_keys, _ErrorInteceptor, DeviceType, + is_non_tensor, lazy_legacy, set_lazy_legacy, ) from torch import Tensor +from torch.nn.parameter import ( + UninitializedBuffer, + UninitializedParameter, + UninitializedTensorMixin, +) TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {} LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {} @@ -92,8 +97,10 @@ def _gather_tensor(tensor, dest=None): return out if out is None: - names = input.names if input._has_names() else None - + if len(index.shape) == input.ndim and input._has_names(): + names = input.names + else: + names = None return TensorDict( { key: _gather_tensor(value) @@ -157,6 +164,32 @@ def _ones_like(td: T, **kwargs: Any) -> T: return td_clone +@implements_for_td(torch.rand_like) +def _rand_like(td: T, **kwargs: Any) -> T: + td_clone = td._fast_apply(lambda x: torch.rand_like(x)) + if "device" in kwargs: + td_clone = td_clone.to(kwargs.pop("device")) + if len(kwargs): + raise RuntimeError( + f"keyword arguments {list(kwargs.keys())} are not " + f"supported with full_like with TensorDict" + ) + return td_clone + + +@implements_for_td(torch.randn_like) +def _randn_like(td: T, **kwargs: Any) -> T: + td_clone = td._fast_apply(lambda x: torch.randn_like(x)) + if "device" in kwargs: + td_clone = td_clone.to(kwargs.pop("device")) + if len(kwargs): + raise RuntimeError( + f"keyword arguments {list(kwargs.keys())} are not " + f"supported with full_like with TensorDict" + ) + return td_clone + + @implements_for_td(torch.empty_like) def _empty_like(td: T, *args, **kwargs) -> T: try: @@ -349,9 +382,9 @@ def _stack( if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") - from tensordict.tensorclass import NonTensorData + if all(is_non_tensor(td) for td in list_of_tensordicts): + from tensordict.tensorclass import NonTensorData - if all(isinstance(tensordict, NonTensorData) for tensordict in list_of_tensordicts): return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) batch_size = list_of_tensordicts[0].batch_size @@ -376,7 +409,7 @@ def _stack( From v0.4 onward, a dense stack will be returned and to build a lazy stack, an explicit call to LazyStackedTensorDict.lazy_stack will be required. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. 
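A hedged sketch (not part of this diff) of what the ``torch.rand_like`` / ``torch.randn_like`` handlers registered above enable, using an illustrative tensordict:

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"obs": torch.zeros(4, 3), "reward": torch.zeros(4, 1)}, batch_size=[4]
    )
    noisy = torch.randn_like(td)  # same keys and shapes, standard-normal entries
    flat = torch.rand_like(td)    # same keys and shapes, uniform [0, 1) entries
    print(noisy["obs"].shape)     # expected: torch.Size([4, 3])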
- set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -406,24 +439,32 @@ def _stack( out = {} for key in keys: out[key] = [] + is_lazy = False tensor_shape = None for _tensordict in list_of_tensordicts: tensor = _tensordict._get_str(key, default=NO_DEFAULT) - if tensor_shape is None: + if isinstance(tensor, UninitializedTensorMixin): + is_lazy = True + elif tensor_shape is None: tensor_shape = tensor.shape elif tensor.shape != tensor_shape: with set_lazy_legacy(True): return _stack(list_of_tensordicts, dim=dim) out[key].append(tensor) + out[key] = (out[key], is_lazy) - def stack_fn(key_values): - key, values = key_values + def stack_fn(key, values, is_lazy): + if is_lazy: + return _stack_uninit_params(values, dim) with _ErrorInteceptor( key, "Attempted to stack tensors on different devices at key" ): return torch.stack(values, dim) - out = {key: stack_fn((key, value)) for key, value in out.items()} + out = { + key: stack_fn(key, values, is_lazy) + for key, (values, is_lazy) in out.items() + } return TensorDict( out, @@ -498,3 +539,29 @@ def where(condition, input, other, *, out=None): "Cannot use a persistent tensordict as output of torch.where." ) return input.where(condition, other, out=out) + + +def _stack_uninit_params(list_of_params, dim=0, out=None): + if out is not None: + raise NotImplementedError + if dim > 0: + raise NotImplementedError + from tensordict.utils import ( + _BatchedUninitializedBuffer, + _BatchedUninitializedParameter, + ) + + if isinstance(list_of_params[0], UninitializedParameter): + out = _BatchedUninitializedParameter( + requires_grad=list_of_params[0].requires_grad, + device=list_of_params[0].device, + dtype=list_of_params[0].dtype, + ) + elif isinstance(list_of_params[0], UninitializedBuffer): + out = _BatchedUninitializedBuffer( + requires_grad=list_of_params[0].requires_grad, + device=list_of_params[0].device, + dtype=list_of_params[0].dtype, + ) + out.batch_size = torch.Size([len(list_of_params)]) + return out diff --git a/tensordict/base.py b/tensordict/base.py index d26c2c900..682e80afd 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -52,6 +52,7 @@ _td_fields, _unravel_key_to_tuple, as_decorator, + Buffer, cache, convert_ellipsis_to_idx, DeviceType, @@ -59,16 +60,19 @@ IndexType, infer_size_impl, int_generator, + is_non_tensor, KeyedJaggedTensor, lazy_legacy, lock_blocked, NestedKey, prod, + set_lazy_legacy, TensorDictFuture, unravel_key, unravel_key_list, ) from torch import distributed as dist, multiprocessing as mp, nn, Tensor +from torch.nn.parameter import UninitializedTensorMixin from torch.utils._pytree import tree_map # NO_DEFAULT is used as a placeholder whenever the default is not provided. @@ -239,10 +243,11 @@ def __getitem__(self, index: IndexType) -> T: idx_unravel = _unravel_key_to_tuple(index) if idx_unravel: result = self._get_tuple(idx_unravel, NO_DEFAULT) - from .tensorclass import NonTensorData - - if isinstance(result, NonTensorData): - return result.data + if is_non_tensor(result): + result_data = getattr(result, "data", NO_DEFAULT) + if result_data is NO_DEFAULT: + return result.tolist() + return result_data return result if (istuple and not index) or (not istuple and index is Ellipsis): @@ -376,6 +381,7 @@ def from_module( """Copies the params and buffers of a module in a tensordict. Args: + module (nn.Module): the module to get the parameters from. 
as_module (bool, optional): if ``True``, a :class:`~tensordict.nn.TensorDictParams` instance will be returned which can be used to store parameters within a :class:`torch.nn.Module`. Defaults to ``False``. @@ -385,7 +391,7 @@ def from_module( module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to ``False``. .. note:: - This is particularily useful when state-dict hooks have to be + This is particularly useful when state-dict hooks have to be used. Examples: @@ -405,6 +411,146 @@ def from_module( """ ... + @classmethod + def from_modules( + cls, + *modules, + as_module: bool = False, + lock: bool = True, + use_state_dict: bool = False, + lazy_stack: bool = False, + ): + """Retrieves the parameters of several modules for ensebmle learning/feature of expects applications through vmap. + + Args: + modules (sequence of nn.Module): the modules to get the parameters from. + If the modules differ in their structure, a lazy stack is needed + (see the ``lazy_stack`` argument below). + + Keyword Args: + as_module (bool, optional): if ``True``, a :class:`~tensordict.nn.TensorDictParams` + instance will be returned which can be used to store parameters + within a :class:`torch.nn.Module`. Defaults to ``False``. + lock (bool, optional): if ``True``, the resulting tensordict will be locked. + Defaults to ``True``. + use_state_dict (bool, optional): if ``True``, the state-dict from the + module will be used and unflattened into a TensorDict with + the tree structure of the model. Defaults to ``False``. + .. note:: + This is particularly useful when state-dict hooks have to be + used. + lazy_stack (bool, optional): whether parameters should be densly or + lazily stacked. Defaults to ``False`` (dense stack). + + .. note:: ``lazy_stack`` and ``as_module`` are exclusive features. + + .. warning:: + There is a crucial difference between lazy and non-lazy outputs + in that non-lazy output will reinstantiate parameters with the + desired batch-size, while ``lazy_stack`` will just represent + the parameters as lazily stacked. This means that whilst the + original parameters can safely be passed to an optimizer + when ``lazy_stack=True``, the new parameters need to be passed + when it is set to ``True``. + + .. warning:: + Whilst it can be tempting to use a lazy stack to keep the + orignal parameter references, remember that lazy stack + perform a stack each time :meth:`~.get` is called. This will + require memory (N times the size of the parameters, more if a + graph is built) and time to be computed. + It also means that the optimizer(s) will contain more + parameters, and operations like :meth:`~torch.optim.Optimizer.step` + or :meth:`~torch.optim.Optimizer.zero_grad` will take longer + to be executed. In general, ``lazy_stack`` should be reserved + to very few use cases. + + Examples: + >>> from torch import nn + >>> from tensordict import TensorDict + >>> torch.manual_seed(0) + >>> empty_module = nn.Linear(3, 4, device="meta") + >>> n_models = 2 + >>> modules = [nn.Linear(3, 4) for _ in range(n_models)] + >>> params = TensorDict.from_modules(*modules) + >>> print(params) + TensorDict( + fields={ + bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), + weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False) + >>> # example of batch execution + >>> def exec_module(params, x): + ... 
with params.to_module(empty_module): + ... return empty_module(x) + >>> x = torch.randn(3) + >>> y = torch.vmap(exec_module, (0, None))(params, x) + >>> assert y.shape == (n_models, 4) + >>> # since lazy_stack = False, backprop leaves the original params untouched + >>> y.sum().backward() + >>> assert params["weight"].grad.norm() > 0 + >>> assert modules[0].weight.grad is None + + With ``lazy_stack=True``, things are slightly different: + + >>> params = TensorDict.from_modules(*modules, lazy_stack=True) + >>> print(params) + LazyStackedTensorDict( + fields={ + bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), + weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([2]), + device=None, + is_shared=False, + stack_dim=0) + >>> # example of batch execution + >>> y = torch.vmap(exec_module, (0, None))(params, x) + >>> assert y.shape == (n_models, 4) + >>> y.sum().backward() + >>> assert modules[0].weight.grad is not None + + + """ + param_list = [ + cls.from_module(module, use_state_dict=use_state_dict) for module in modules + ] + if lazy_stack: + from tensordict._lazy import LazyStackedTensorDict + + for param in param_list: + if any( + isinstance(tensor, UninitializedTensorMixin) + for tensor in param.values(True, True) + ): + raise RuntimeError( + "lasy_stack=True is not compatible with lazy modules." + ) + params = LazyStackedTensorDict.lazy_stack(param_list) + else: + with set_lazy_legacy(False), torch.no_grad(): + params = torch.stack(param_list) + + # Make sure params are params, buffers are buffers + def make_param(param, orig_param): + if isinstance(param, UninitializedTensorMixin): + return param + if isinstance(orig_param, nn.Parameter): + return nn.Parameter(param.detach(), orig_param.requires_grad) + return Buffer(param) + + params = params._fast_apply(make_param, param_list[0]) + if as_module: + from tensordict.nn import TensorDictParams + + params = TensorDictParams(params, no_convert=True) + if lock: + params.lock_() + return params + @as_decorator() def to_module( self, @@ -706,7 +852,7 @@ def unsqueeze(self, *args, **kwargs): version of the unsqueezed tensordict. Up until v0.3 included, a lazy unsqueezed tensordict was returned. From v0.4 onward, a dense unsqueeze will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -795,7 +941,7 @@ def squeeze(self, *args, **kwargs): version of the squeezed tensordict. Up until v0.3 included, a lazy squeezed tensordict was returned. From v0.4 onward, a dense squeeze will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. 
- set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -1007,7 +1153,7 @@ def view( version of the viewed tensordict. Up until v0.3 included, a lazy view of the tensordict was returned. From v0.4 onward, a proper view will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -1076,7 +1222,7 @@ def transpose(self, dim0, dim1): version of the transposed tensordict. Up until v0.3 included, a lazy transpose of the tensordict was returned. From v0.4 onward, a proper transpose will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -1191,7 +1337,7 @@ def permute(self, *args, **kwargs): version of the permuted tensordict. Up until v0.3 included, a lazy permute of the tensordict was returned. From v0.4 onward, a proper permute will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. 
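The ``is_cuda`` / ``is_cpu`` properties introduced earlier in this hunk mirror their ``torch.Tensor`` counterparts; a small sketch (assuming a CPU-only setup, not part of the patch):

    import torch
    from tensordict import TensorDict

    td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
    print(td.is_cuda, td.is_cpu)  # expected: False False (device is None)
    td_cpu = td.to("cpu")
    print(td_cpu.is_cpu)          # expected: True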
- set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -1511,6 +1657,14 @@ def cuda(self, device: int = None) -> T: return self.to(torch.device("cuda")) return self.to(f"cuda:{device}") + @property + def is_cuda(self): + return self.device is not None and self.device.type == "cuda" + + @property + def is_cpu(self): + return self.device is not None and self.device.type == "cpu" + # Serialization functionality def state_dict( self, @@ -2162,18 +2316,19 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): return subtd return subtd._get_non_tensor(key[1:], default=default) value = self._get_str(key, default=default) - from tensordict.tensorclass import NonTensorData - if isinstance(value, NonTensorData): - return value.data + if is_non_tensor(value): + data = getattr(value, "data", None) + if data is None: + return value.tolist() + return data return value def filter_non_tensor_data(self) -> T: """Filters out all non-tensor-data.""" - from tensordict.tensorclass import NonTensorData def _filter(x): - if not isinstance(x, NonTensorData): + if not is_non_tensor(x): if is_tensor_collection(x): return x.filter_non_tensor_data() return x @@ -5230,10 +5385,8 @@ def _default_is_leaf(cls: Type) -> bool: def _is_leaf_nontensor(cls: Type) -> bool: - from tensordict.tensorclass import NonTensorData - if issubclass(cls, KeyedJaggedTensor): return False if _is_tensor_collection(cls): - return issubclass(cls, NonTensorData) + return cls.__dict__.get("_non_tensor", False) return issubclass(cls, torch.Tensor) diff --git a/tensordict/functional.py b/tensordict/functional.py index 1cbf168b7..cfb7bff4e 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + from typing import Sequence import torch @@ -77,7 +79,8 @@ def pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0) -> T: def pad_sequence( list_of_tensordicts: Sequence[T], - batch_first: bool = True, + batch_first: bool | None = None, + pad_dim: int = 0, padding_value: float = 0.0, out: T | None = None, device: DeviceType | None = None, @@ -87,76 +90,117 @@ def pad_sequence( Args: list_of_tensordicts (List[TensorDictBase]): the list of instances to pad and stack. - batch_first (bool, optional): the ``batch_first`` correspondant of :func:`torch.nn.utils.rnn.pad_sequence`. - Defaults to ``True``. + pad_dim (int, optional): the ``pad_dim`` indicates the dimension to pad all the keys in the tensordict. + Defaults to ``0``. padding_value (number, optional): the padding value. Defaults to ``0.0``. out (TensorDictBase, optional): if provided, the destination where the data will be written. device (device compatible type, optional): if provded, the device where the TensorDict output will be created. - return_mask (bool, optional): if ``True``, a "mask" entry will be returned. - It contains the mask of valid values in the stacked tensordict. + return_mask (bool, optional): if ``True``, a "masks" entry will be returned. + It contains a tensordict with the same structure as the stacked tensordict where every entry contains the mask of valid values with size ``torch.Size([stack_len, *new_shape])``, + where `new_shape[pad_dim] = max_seq_length` and the rest of the `new_shape` matches the previous shape of the contained tensors. Examples: >>> list_td = [ - ... TensorDict({"a": torch.zeros((3,))}, []), - ... TensorDict({"a": torch.zeros((4,))}, []), + ... 
TensorDict({"a": torch.zeros((3, 8)), "b": torch.zeros((6, 8))}, batch_size=[]), + ... TensorDict({"a": torch.zeros((5, 8)), "b": torch.zeros((6, 8))}, batch_size=[]), ... ] - >>> padded_td = pad_sequence(list_td) + >>> padded_td = pad_sequence(list_td, return_mask=True) >>> print(padded_td) TensorDict( fields={ - a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([]), + a: Tensor(shape=torch.Size([2, 4, 8]), device=cpu, dtype=torch.float32, is_shared=False), + b: Tensor(shape=torch.Size([2, 5, 8]), device=cpu, dtype=torch.float32, is_shared=False), + masks: TensorDict( + fields={ + a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.bool, is_shared=False), + b: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False)}, + batch_size=torch.Size([2]), device=None, is_shared=False) """ + if batch_first is not None: + warnings.warn( + "The batch_first argument is deprecated and will be removed in a future release. The output will always be batch_first.", + category=DeprecationWarning, + ) + if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") + # check that all tensordict match - if return_mask: - list_of_tensordicts = [ - td.clone(False).set("mask", torch.ones(td.shape, dtype=torch.bool)) - for td in list_of_tensordicts - ] + update_batch_size = True + max_seq_length = float("-inf") keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True) - shape = max(len(td) for td in list_of_tensordicts) - if shape == 0: - shape = [ - len(list_of_tensordicts), - ] - elif batch_first: - shape = [len(list_of_tensordicts), shape] - else: - shape = [shape, len(list_of_tensordicts)] - if out is None: - out = TensorDict( - {}, batch_size=torch.Size(shape), device=device, _run_checks=False - ) + tmp_list_of_tensordicts = [] + for td in list_of_tensordicts: + + if return_mask: + tmp_list_of_tensordicts.append(td.clone(False)) + for key in keys: - try: - out.set( - key, - torch.nn.utils.rnn.pad_sequence( - [td.get(key) for td in list_of_tensordicts], - batch_first=batch_first, - padding_value=padding_value, - ), + tensor_shape = td.get(key).shape + pos_pad_dim = pad_dim if pad_dim >= 0 else len(tensor_shape) + pad_dim + + # track the maximum sequence length to update batch_size accordingly + if tensor_shape[pos_pad_dim] > max_seq_length: + max_seq_length = tensor_shape[pos_pad_dim] + + # The mask should always contain the batch_size of the TensorDict + mask_shape = td.shape + + # if the pad_dim is past the batch_size of the TensorDict, we need to add the new dimension to the mask + if pos_pad_dim >= td.ndim: + mask_shape += torch.Size([tensor_shape[pos_pad_dim]]) + update_batch_size = False + + if return_mask: + tmp_list_of_tensordicts[-1].set( + ("masks", key), + torch.ones(mask_shape, dtype=torch.bool), ) - except Exception as err: - raise RuntimeError(f"pad_sequence failed for key {key}") from err - return out - else: - for key in keys: - out.set_( + if return_mask: + list_of_tensordicts = tmp_list_of_tensordicts + + keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True) + + old_batch_size = list(list_of_tensordicts[0].batch_size) + if update_batch_size and len(old_batch_size) > 0: + old_batch_size[pad_dim] = max_seq_length + shape = [ + len(list_of_tensordicts), + ] + old_batch_size + + if out is None: + out = list_of_tensordicts[0].empty(recurse=True).reshape(torch.Size(shape)) 
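A hedged usage sketch of the reworked ``pad_sequence`` (not part of the patch; variable names are illustrative and the expected shapes are inferred from the code rather than executed):

    import torch
    from tensordict import TensorDict, pad_sequence

    list_td = [
        TensorDict({"a": torch.zeros(3, 8)}, batch_size=[]),
        TensorDict({"a": torch.zeros(5, 8)}, batch_size=[]),
    ]
    padded = pad_sequence(list_td, pad_dim=0, padding_value=0.0, return_mask=True)
    print(padded["a"].shape)           # expected: torch.Size([2, 5, 8])
    print(padded["masks", "a"].shape)  # expected: torch.Size([2, 5]), dtype torch.bool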
+ + for key in keys: + try: + tensor_shape = list_of_tensordicts[0].get(key).shape + pos_pad_dim = ( + (pad_dim if pad_dim >= 0 else len(tensor_shape) + pad_dim) + if len(tensor_shape) > 1 + else 0 # handles the case when the masks are 1-dimensional + ) + out.set( key, torch.nn.utils.rnn.pad_sequence( - [td.get(key) for td in list_of_tensordicts], - batch_first=batch_first, + [ + td.get(key).transpose(0, pos_pad_dim) + for td in list_of_tensordicts + ], + batch_first=True, padding_value=padding_value, - ), + ).transpose(1, pos_pad_dim + 1), + inplace=True, ) - return out + except Exception as err: + raise RuntimeError(f"pad_sequence failed for key {key}") from err + return out def merge_tensordicts(*tensordicts: T) -> T: diff --git a/tensordict/memmap.py b/tensordict/memmap.py index ff948e1c7..c5b2489a9 100644 --- a/tensordict/memmap.py +++ b/tensordict/memmap.py @@ -22,20 +22,8 @@ from tensordict.utils import implement_for -from torch import distributed as dist - from torch.multiprocessing.reductions import ForkingPickler -try: - if dist.is_available(): - from torch.distributed._tensor.api import DTensor - else: - raise ImportError -except ImportError: - - class DTensor(torch.Tensor): # noqa: D101 - ... - class MemoryMappedTensor(torch.Tensor): """A Memory-mapped Tensor. @@ -204,7 +192,7 @@ def from_tensor( out.index = None out.parent_shape = input.shape if copy_data: - if isinstance(input, DTensor): + if hasattr(input, "full_tensor"): input = input.full_tensor() out.copy_(input) return out diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index f51f1a1ca..e5714bf4e 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -990,18 +990,17 @@ def _load_from_state_dict( unexpected_keys, error_msgs, ): - data = ( - TensorDict( - { - key: val - for key, val in state_dict.items() - if key.startswith(prefix) and val is not None - }, - [], - ) - .unflatten_keys(".") - .get(prefix[:-1]) - ) + data = TensorDict( + { + key: val + for key, val in state_dict.items() + if key.startswith(prefix) and val is not None + }, + [], + ).unflatten_keys(".") + prefix = tuple(key for key in prefix.split(".") if key) + if prefix: + data = data.get(prefix) self.data.load_state_dict(data) def items( @@ -1118,8 +1117,6 @@ def compute_should_use_set_data(tensor, tensor_applied): param.data = param_applied out_param = param else: - assert isinstance(param, nn.Parameter) - assert param.is_leaf out_param = nn.Parameter(param_applied, param.requires_grad) self._parameters[key] = out_param @@ -1130,10 +1127,8 @@ def compute_should_use_set_data(tensor, tensor_applied): param.grad, grad_applied ) if should_use_set_data: - assert out_param.grad is not None out_param.grad.data = grad_applied else: - assert param.grad.is_leaf out_param.grad = grad_applied.requires_grad_( param.grad.requires_grad ) @@ -1151,8 +1146,6 @@ def compute_should_use_set_data(tensor, tensor_applied): buffer.data = buffer_applied out_buffer = buffer else: - assert isinstance(buffer, Buffer) - assert buffer.is_leaf out_buffer = Buffer(buffer_applied, buffer.requires_grad) self._buffers[key] = out_buffer @@ -1163,10 +1156,8 @@ def compute_should_use_set_data(tensor, tensor_applied): buffer.grad, grad_applied ) if should_use_set_data: - assert out_buffer.grad is not None out_buffer.grad.data = grad_applied else: - assert buffer.grad.is_leaf out_buffer.grad = grad_applied.requires_grad_( buffer.grad.requires_grad ) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index bb422cc02..510610d9c 100644 --- 
a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -546,8 +546,22 @@ def log_prob( return self.module[-1].log_prob(tensordict_out) def build_dist_from_params(self, tensordict: TensorDictBase) -> D.Distribution: - """Construct a distribution from the input parameters. Other modules in the sequence are not evaluated.""" - return self.module[-1].get_dist(tensordict) + """Construct a distribution from the input parameters. Other modules in the sequence are not evaluated. + + This method will look for the last ProbabilisticTensorDictModule contained in the + sequence and use it to build the distribution. + + """ + dest_module = None + for module in reversed(list(self.modules())): + if isinstance(module, ProbabilisticTensorDictModule): + dest_module = module + break + if dest_module is None: + raise RuntimeError( + "Could not find any ProbabilisticTensorDictModule in the sequence." + ) + return dest_module.get_dist(tensordict) @dispatch(auto_batch_size=False) @set_skip_existing(None) diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 3e1d698bd..105be6538 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -14,7 +14,7 @@ import torch from torch import nn -AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "True")) +AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "False")) DISPATCH_TDNN_MODULES = strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True")) diff --git a/tensordict/persistent.py b/tensordict/persistent.py index 0c605074a..09c0fb176 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -46,14 +46,7 @@ ) from torch import multiprocessing as mp -H5_ERR = None -try: - import h5py - - _has_h5 = True -except ModuleNotFoundError as err: - H5_ERR = err - _has_h5 = False +_has_h5 = importlib.util.find_spec("h5py", None) is not None class _Visitor: @@ -151,7 +144,9 @@ def __init__( self._locked_tensordicts = [] self._lock_id = set() if not _has_h5: - raise ModuleNotFoundError("Could not load h5py.") from H5_ERR + raise ModuleNotFoundError("Could not load h5py.") + import h5py + super().__init__() self.filename = filename self.mode = mode @@ -213,6 +208,8 @@ def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs) A :class:`PersitentTensorDict` instance linked to the newly created file. """ + import h5py + file = h5py.File(filename, "w", locking=cls.LOCKING) _has_batch_size = True if batch_size is None: @@ -260,6 +257,8 @@ def _get_array(self, key, default=NO_DEFAULT): raise KeyError(f"key {key} not found in PersistentTensorDict {self}") def _process_array(self, key, array): + import h5py + if isinstance(array, (h5py.Dataset,)): if self.device is not None: device = self.device @@ -294,6 +293,8 @@ def get(self, key, default=NO_DEFAULT): def get_at( self, key: NestedKey, idx: IndexType, default: CompatibleType = NO_DEFAULT ) -> CompatibleType: + import h5py + array = self._get_array(key, default) if isinstance(array, (h5py.Dataset,)): if self.device is not None: @@ -339,6 +340,8 @@ def _get_metadata(self, key): This method avoids creating a tensor from scratch, and just reads the metadata of the array. 
""" + import h5py + array = self._get_array(key) if ( isinstance(array, (h5py.Dataset,)) @@ -1048,6 +1051,8 @@ def _set_metadata(self, orig_metadata_container: PersistentTensorDict): self._nested_tensordicts[key]._set_metadata(td) def _clone(self, recurse: bool = True, newfile=None) -> PersistentTensorDict: + import h5py + if recurse: # this should clone the h5 to a new location indicated by newfile if newfile is None: @@ -1106,6 +1111,8 @@ def __getstate__(self): return state def __setstate__(self, state): + import h5py + state["file"] = h5py.File( state["filename"], mode=state["mode"], locking=self.LOCKING ) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 509ccfb2a..0af45cea4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -15,7 +15,7 @@ import re import sys import warnings -from copy import copy +from copy import copy, deepcopy from dataclasses import dataclass from pathlib import Path from textwrap import indent @@ -24,10 +24,11 @@ import tensordict as tensordict_lib import torch +from tensordict import LazyStackedTensorDict from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS -from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class +from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class, CompatibleType from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( @@ -36,6 +37,7 @@ _LOCK_ERROR, DeviceType, IndexType, + is_non_tensor, is_tensorclass, NestedKey, ) @@ -56,9 +58,9 @@ torch.full_like, torch.zeros_like, torch.ones_like, + torch.rand_like, torch.empty_like, torch.randn_like, - torch.rand_like, torch.clone, torch.squeeze, torch.unsqueeze, @@ -199,6 +201,9 @@ def __torch_function__( cls.load_state_dict = _load_state_dict cls._memmap_ = _memmap_ + cls.__enter__ = __enter__ + cls.__exit__ = __exit__ + # Memmap cls.memmap_like = TensorDictBase.memmap_like cls.memmap_ = TensorDictBase.memmap_ @@ -216,6 +221,7 @@ def __torch_function__( cls.to_tensordict = _to_tensordict cls.device = property(_device, _device_setter) cls.batch_size = property(_batch_size, _batch_size_setter) + cls.names = property(_names, _names_setter) cls.__doc__ = f"{cls.__name__}{inspect.signature(cls)}" @@ -424,6 +430,14 @@ def _load_memmap(cls, prefix: Path, metadata: dict): return cls._from_tensordict(td, non_tensordict) +def __enter__(self, *args, **kwargs): + return self._tensordict.__enter__(*args, **kwargs) + + +def __exit__(self, *args, **kwargs): + return self._tensordict.__exit__(*args, **kwargs) + + def _getstate(self) -> dict[str, Any]: """Returns a state dict which consists of tensor and non_tensor dicts for serialization. 
@@ -475,7 +489,7 @@ def wrapper(self, item: str) -> Any: return wrapper -SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts") +SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names") def _setattr_wrapper(setattr_: Callable, expected_keys: set[str]) -> Callable: @@ -714,6 +728,9 @@ def _set(self, key: NestedKey, value: Any, inplace: bool = False): __dict__ = self.__dict__ if __dict__["_tensordict"].is_locked: raise RuntimeError(_LOCK_ERROR) + if key in ("batch_size", "names", "device"): + # handled by setattr + return expected_keys = self.__dataclass_fields__ if key not in expected_keys: raise AttributeError( @@ -830,6 +847,26 @@ def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417 self._tensordict._batch_size_setter(new_size) +def _names(self) -> torch.Size: + """Retrieves the dim names for the tensor class. + + Returns: + names (list of str) + + """ + return self._tensordict.names + + +def _names_setter(self, names: str) -> None: # noqa: D417 + """Set the value of ``tensorclass.names``. + + Args: + names (sequence of str) + + """ + self._tensordict.names = names + + def _state_dict( self, destination=None, prefix="", keep_vars=False, flatten=False ) -> dict[str, Any]: @@ -1248,6 +1285,7 @@ class NonTensorData: # to patch tensordict with additional checks that will encur unwanted overhead # and all the overhead falls back on this class. data: Any + _non_tensor: bool = True @classmethod def from_tensor(cls, value: torch.Tensor, batch_size, device=None, names=None): @@ -1257,8 +1295,11 @@ def from_tensor(cls, value: torch.Tensor, batch_size, device=None, names=None): return out def __post_init__(self): - if isinstance(self.data, NonTensorData): - self.data = self.data.data + if is_non_tensor(self.data): + data = getattr(self.data, "data", None) + if data is None: + data = self.data.tolist() + self.data = data old_eq = self.__class__.__eq__ if old_eq is _eq: @@ -1313,6 +1354,23 @@ def __or__(self, other): self.__class__.__or__ = __or__ + def update( + self, + input_dict_or_td: dict[str, CompatibleType] | T, + clone: bool = False, + inplace: bool = False, + *, + keys_to_update: Sequence[NestedKey] | None = None, + ) -> T: + if isinstance(input_dict_or_td, NonTensorData): + data = input_dict_or_td.data + if clone: + data = deepcopy(data) + self.data = data + elif not input_dict_or_td.is_empty(): + raise RuntimeError(f"Unexpected type {type(input_dict_or_td)}") + return self + def empty(self, recurse=False): return type(self)( data=self.data, @@ -1339,7 +1397,9 @@ def _check_equal(a, b): iseq = False return iseq - if all(_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]): + if all(isinstance(data, NonTensorData) for data in list_of_non_tensor) and all( + _check_equal(data.data, first.data) for data in list_of_non_tensor[1:] + ): batch_size = list(first.batch_size) batch_size.insert(dim, len(list_of_non_tensor)) return type(cls)( @@ -1349,9 +1409,7 @@ def _check_equal(a, b): device=first.device, ) - from tensordict._lazy import LazyStackedTensorDict - - return LazyStackedTensorDict(*list_of_non_tensor, stack_dim=dim) + return NonTensorStack(*list_of_non_tensor, stack_dim=dim) @classmethod def __torch_function__( @@ -1412,3 +1470,100 @@ def _fast_apply(self, *args, **kwargs): return _wrap_method(self, "_fast_apply", self._tensordict._fast_apply)( *args, **kwargs ) + + def tolist(self): + """Converts the data in a list if the batch-size is non-empty. + + If the batch-size is empty, returns the data. 
+ + """ + if not self.batch_size: + return self.data + return [ntd.tolist() for ntd in self.unbind(0)] + + def copy_(self, src: NonTensorData | NonTensorStack, non_blocking: bool = False): + if isinstance(src, NonTensorStack): + raise RuntimeError( + "Cannot update a NonTensorData with a NonTensorStack object." + ) + if not isinstance(src, NonTensorData): + raise RuntimeError( + "NonTensorData.copy_ requires the source to be a NonTensorData object." + ) + self._non_tensordict["data"] = src.data + + def clone(self, recurse: bool = True): + if recurse: + return type(self)( + data=deepcopy(self.data), + batch_size=self.batch_size, + device=self.device, + names=self.names, + ) + return type(self)( + data=self.data, + batch_size=self.batch_size, + device=self.device, + names=self.names, + ) + + +class NonTensorStack(LazyStackedTensorDict): + """A thin wrapper around LazyStackedTensorDict to make stack on non-tensor data easily recognizable. + + A ``NonTensorStack`` is returned whenever :func:`~torch.stack` is called on + a list of :class:`~tensordict.NonTensorData` or ``NonTensorStack``. + + Examples: + >>> from tensordict import NonTensorData + >>> import torch + >>> data = torch.stack([ + ... torch.stack([NonTensorData(data=(i, j), batch_size=[]) for i in range(2)]) + ... for j in range(3)]) + >>> print(data) + NonTensorStack( + [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, ..., + batch_size=torch.Size([3, 2]), + device=None) + + To obtain the values stored in a ``NonTensorStack``, call :class:`~.tolist`. + + """ + + _non_tensor: bool = True + + def tolist(self): + """Extracts the content of a :class:`tensordict.tensorclass.NonTensorStack` in a nested list. + + Examples: + >>> from tensordict import NonTensorData + >>> import torch + >>> data = torch.stack([ + ... torch.stack([NonTensorData(data=(i, j), batch_size=[]) for i in range(2)]) + ... for j in range(3)]) + >>> data.tolist() + [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, 2)]] + + """ + iterator = self.tensordicts if self.stack_dim == 0 else self.unbind(0) + return [td.tolist() for td in iterator] + + @classmethod + def from_nontensordata(cls, non_tensor: NonTensorData): + data = non_tensor.data + prev = NonTensorData(data, batch_size=[], device=non_tensor.device) + for dim in reversed(non_tensor.shape): + prev = cls(*[prev.clone(False) for _ in range(dim)], stack_dim=0) + return prev + + def __repr__(self): + selfrepr = str(self.tolist()) + if len(selfrepr) > 50: + selfrepr = f"{selfrepr[:50]}..." 
+ selfrepr = indent(selfrepr, prefix=4 * " ") + batch_size = indent(f"batch_size={self.batch_size}", prefix=4 * " ") + device = indent(f"device={self.device}", prefix=4 * " ") + return f"NonTensorStack(\n{selfrepr}," f"\n{batch_size}," f"\n{device})" + + def to_dict(self) -> dict[str, Any]: + return self.tolist() diff --git a/tensordict/utils.py b/tensordict/utils.py index 1384b62e4..c64c30a7a 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -49,9 +49,15 @@ unravel_key_list, unravel_keys, ) + from torch import Tensor from torch._C import _disabled_torch_function_impl -from torch.nn.parameter import _ParameterMeta +from torch.nn.parameter import ( + _ParameterMeta, + UninitializedBuffer, + UninitializedParameter, + UninitializedTensorMixin, +) from torch.utils.data._utils.worker import _generate_state if TYPE_CHECKING: @@ -299,7 +305,10 @@ def expand_as_right( f" tensor.ndimension()={tensor.ndimension()} and " f"dest.ndimension()={dest.ndimension()}" ) - if not (tensor.shape == dest.shape[: tensor.ndimension()]): + if any( + tensor.shape[i] != dest.shape[i] and tensor.shape[i] != 1 + for i in range(tensor.ndimension()) + ): raise RuntimeError( f"tensor shape is incompatible with dest shape, " f"got: tensor.shape={tensor.shape}, dest={dest.shape}" @@ -559,7 +568,9 @@ def _ndimension(tensor: Tensor) -> int: def _shape(tensor: Tensor) -> torch.Size: - if not isinstance(tensor, Tensor): + if isinstance(tensor, UninitializedTensorMixin): + return torch.Size([*getattr(tensor, "batch_size", ()), -1]) + elif not isinstance(tensor, Tensor): if type(tensor) is KeyedJaggedTensor: return torch.Size([len(tensor.lengths()) // len(tensor.keys())]) return tensor.shape @@ -648,6 +659,21 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> elif isinstance(tensor, KeyedJaggedTensor): tensor = setitem_keyedjaggedtensor(tensor, index, value) return tensor + from tensordict.tensorclass import NonTensorData, NonTensorStack + + if is_non_tensor(tensor): + if ( + isinstance(value, NonTensorData) + and isinstance(tensor, NonTensorData) + and tensor.data == value.data + ): + return tensor + elif isinstance(tensor, NonTensorData): + tensor = NonTensorStack.from_nontensordata(tensor) + if tensor.stack_dim != 0: + tensor = NonTensorStack(*tensor.unbind(0), stack_dim=0) + tensor[index] = value + return tensor else: tensor[index] = value return tensor @@ -1424,12 +1450,17 @@ def _td_fields(td: T, keys=None) -> str: if isinstance(tensor, TensorDictBase): substr = _td_fields(tensor) else: + is_shared = ( + tensor.is_shared() + if not isinstance(tensor, UninitializedTensorMixin) + else None + ) substr = _get_repr_custom( tensor.__class__, shape=shape, device=tensor.device, dtype=tensor.dtype, - is_shared=tensor.is_shared(), + is_shared=is_shared, ) strs.append(f"{key}: {substr}") @@ -1490,9 +1521,7 @@ def _expand_to_match_shape( def _set_max_batch_size(source: T, batch_dims=None): """Updates a tensordict with its maximium batch size.""" - from tensordict import NonTensorData - - tensor_data = [val for val in source.values() if not isinstance(val, NonTensorData)] + tensor_data = [val for val in source.values() if not is_non_tensor(val)] for val in tensor_data: from tensordict.base import _is_tensor_collection @@ -1571,7 +1600,7 @@ def wrapper(*args, **kwargs): def _broadcast_tensors(index): # tensors and range need to be broadcast tensors = { - i: tensor if isinstance(tensor, Tensor) else torch.tensor(tensor) + i: torch.as_tensor(tensor) for i, tensor in enumerate(index) if 
isinstance(tensor, (range, list, np.ndarray, Tensor)) } @@ -1743,7 +1772,7 @@ def _getitem_batch_size(batch_size, index): # Lazy classes control (legacy feature) -_DEFAULT_LAZY_OP = True +_DEFAULT_LAZY_OP = False _LAZY_OP = os.environ.get("LAZY_LEGACY_OP", None) @@ -2101,3 +2130,47 @@ def __setstate__(self, ob: bytes): def __call__(self, *args, **kwargs): return self.fn(*args, **kwargs) + + +class _BatchedUninitializedParameter(UninitializedParameter): + batch_size: torch.Size + in_dim: int | None = None + vmap_level: int | None = None + + def materialize(self, shape, device=None, dtype=None): + UninitializedParameter.materialize( + self, (*self.batch_size, *shape), device=device, dtype=dtype + ) + + +class _BatchedUninitializedBuffer(UninitializedBuffer): + batch_size: torch.Size + in_dim: int | None = None + vmap_level: int | None = None + + def materialize(self, shape, device=None, dtype=None): + UninitializedBuffer.materialize( + self, (*self.batch_size, *shape), device=device, dtype=dtype + ) + + +class _add_batch_dim_pre_hook: + def __call__(self, mod: torch.nn.Module, args, kwargs): + for name, param in list(mod.named_parameters(recurse=False)): + if hasattr(param, "in_dim") and hasattr(param, "vmap_level"): + from torch._C._functorch import _add_batch_dim + + param = _add_batch_dim(param, param.in_dim, param.vmap_level) + delattr(mod, name) + setattr(mod, name, param) + for key, val in list(mod._forward_pre_hooks.items()): + if val is self: + del mod._forward_pre_hooks[key] + return + else: + raise RuntimeError("did not find pre-hook") + + +def is_non_tensor(data): + """Checks if an item is a non-tensor.""" + return type(data).__dict__.get("_non_tensor", False) diff --git a/test/test_nn.py b/test/test_nn.py index fd5aefeff..016237077 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8,6 +8,7 @@ import pickle import unittest import warnings +import weakref import pytest import torch @@ -414,6 +415,7 @@ def test_functional_before(self): tensordict_module = TensorDictModule( module=net, in_keys=["in"], out_keys=["out"] ) + make_functional(tensordict_module, return_params=False) td = TensorDict({"in": torch.randn(3, 3)}, [3]) tensordict_module(td, params=TensorDict({"module": params}, [])) @@ -579,6 +581,7 @@ def test_functional_with_buffer(self): tdmodule = TensorDictModule(module=net, in_keys=["in"], out_keys=["out"]) td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) + make_functional(tdmodule, return_params=False) tdmodule(td, params=TensorDict({"module": params}, [])) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 32]) @@ -2972,7 +2975,8 @@ def test_tdparams_clone_tying(self): td_clone = td.clone() assert td_clone["c"] is td_clone["a", "b", "c"] - def test_func_on_tdparams(self): + @pytest.mark.parametrize("with_batch", [False, True]) + def test_func_on_tdparams(self, with_batch): # tdparams isn't represented in a nested way, so we must check that calling to_module on it works ok net = nn.Sequential( nn.Linear(2, 2), @@ -2987,9 +2991,13 @@ def test_func_on_tdparams(self): ), ) - params = TensorDict.from_module(net, as_module=True) + if with_batch: + params = TensorDict.from_modules(net, net, as_module=True) + params0 = params[0].expand(3).clone().apply(lambda x: x.data * 0) + else: + params = TensorDict.from_module(net, as_module=True) + params0 = params.apply(lambda x: x.data * 0) - params0 = params.apply(lambda x: x.data * 0) assert (params0 == 0).all() with params0.to_module(params): assert (params == 0).all() @@ -3002,12 
+3010,52 @@ class MyModule(nn.Module): m = MyModule() m.params = params params_m = TensorDict.from_module(m, as_module=True) - params_m0 = params_m.apply(lambda x: x.data * 0) + if with_batch: + params_m0 = params_m.clone() + params_m0["params"] = params_m0["params"][0].expand(3).clone() + else: + params_m0 = params_m + params_m0 = params_m0.apply(lambda x: x.data * 0) assert (params_m0 == 0).all() with params_m0.to_module(m): assert (params_m == 0).all() assert not (params_m == 0).all() + def test_load_state_dict(self): + net = nn.Sequential( + nn.Linear(2, 2), + nn.Sequential( + nn.Linear(2, 2), + nn.Dropout(), + nn.BatchNorm1d(2), + nn.Sequential( + nn.Tanh(), + nn.Linear(2, 2), + ), + ), + ) + + params = TensorDict.from_module(net, as_module=True) + assert any(isinstance(p, nn.Parameter) for p in params.values(True, True)) + weakrefs = {weakref.ref(t) for t in params.values(True, True)} + + # Now with a module around it + class MyModule(nn.Module): + pass + + module = MyModule() + module.model = MyModule() + module.model.params = params + sd = module.state_dict() + sd = { + key: val * 0 if isinstance(val, torch.Tensor) else val + for key, val in sd.items() + } + module.load_state_dict(sd) + assert (params == 0).all() + assert any(isinstance(p, nn.Parameter) for p in params.values(True, True)) + assert weakrefs == {weakref.ref(t) for t in params.values(True, True)} + def test_inplace_ops(self): td = TensorDict( { diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index b430674b8..d1195c40e 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -1785,6 +1785,22 @@ def test_to(self): assert isinstance(td.get("c")[0], self.TensorClass) +def test_decorator(): + @tensorclass + class MyClass: + X: torch.Tensor + y: Any + + obj = MyClass(X=torch.zeros(2), y="a string!", batch_size=[]) + assert not obj.is_locked + with obj.lock_(): + assert obj.is_locked + with obj.unlock_(): + assert not obj.is_locked + assert obj.is_locked + assert not obj.is_locked + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index ee404b2df..66e2875c8 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -50,6 +50,7 @@ set_lazy_legacy, ) from torch import multiprocessing as mp, nn +from torch.nn.parameter import UninitializedTensorMixin try: import torchsnapshot @@ -714,6 +715,90 @@ def remover(module, *args, **kwargs): sd = net.state_dict() assert_allclose_td(params_sd.flatten_keys("."), TensorDict(sd, [])) + @pytest.mark.parametrize("as_module", [False, True]) + @pytest.mark.parametrize("lazy_stack", [False, True]) + def test_from_modules(self, as_module, lazy_stack): + empty_module = nn.Linear(3, 4, device="meta") + modules = [nn.Linear(3, 4) for _ in range(3)] + if as_module and lazy_stack: + with pytest.raises(RuntimeError, match="within a TensorDictParams"): + params = TensorDict.from_modules( + *modules, as_module=as_module, lazy_stack=lazy_stack + ) + return + + params = TensorDict.from_modules( + *modules, as_module=as_module, lazy_stack=lazy_stack + ) + + def exec_module(params, x): + with params.to_module(empty_module): + return empty_module(x) + + x = torch.zeros(3) + y = torch.vmap(exec_module, (0, None))(params, x) + y.sum().backward() + if lazy_stack: + leaves = [] + + def get_leaf(leaf): + leaves.append(leaf) + + params.apply(get_leaf) + assert all(param.grad is not None for param in 
leaves) + assert all(param.grad is None for param in params.values(True, True)) + else: + for p in modules[0].parameters(): + assert p.grad is None + assert all(param.grad is not None for param in params.values(True, True)) + + @pytest.mark.parametrize("as_module", [False, True]) + @pytest.mark.parametrize("lazy_stack", [False, True]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_from_modules_lazy(self, as_module, lazy_stack, device): + empty_module = nn.LazyLinear(4, device="meta") + modules = [nn.LazyLinear(4, device=device) for _ in range(3)] + if lazy_stack: + with pytest.raises( + RuntimeError, + match="lasy_stack=True is not compatible with lazy modules.", + ): + params = TensorDict.from_modules( + *modules, as_module=as_module, lazy_stack=lazy_stack + ) + return + + params = TensorDict.from_modules( + *modules, as_module=as_module, lazy_stack=lazy_stack + ) + optim = torch.optim.Adam(list(params.values(True, True))) + + def exec_module(params, x): + with params.to_module(empty_module): + return empty_module(x) + + x = torch.zeros(3, device=device) + assert isinstance(params["weight"], UninitializedTensorMixin) + y = torch.vmap(exec_module, (0, None), randomness="same")(params, x) + assert params["weight"].shape == torch.Size((3, 4, 3)) + + y.sum().backward() + optim.step() + + if lazy_stack: + leaves = [] + + def get_leaf(leaf): + leaves.append(leaf) + + params.apply(get_leaf) + assert all(param.grad is not None for param in leaves) + assert all(param.grad is None for param in params.values(True, True)) + else: + for p in modules[0].parameters(): + assert p.grad is None + assert all(param.grad is not None for param in params.values(True, True)) + @pytest.mark.parametrize( "idx", [ @@ -911,33 +996,103 @@ def test_pad(self): assert torch.equal(padded_td["a"], expected_a) padded_td._check_batch_size() - @pytest.mark.parametrize("batch_first", [True, False]) @pytest.mark.parametrize("make_mask", [True, False]) - def test_pad_sequence(self, batch_first, make_mask): + def test_pad_sequence_pad_dim0(self, make_mask): + pad_dim = 0 list_td = [ - TensorDict({"a": torch.ones((2,)), ("b", "c"): torch.ones((2, 3))}, [2]), - TensorDict({"a": torch.ones((4,)), ("b", "c"): torch.ones((4, 3))}, [4]), + TensorDict( + {"a": torch.ones((2, 8, 8)), ("b", "c"): torch.ones((2, 3))}, [2] + ), + TensorDict( + {"a": torch.full((4, 8, 8), 2), ("b", "c"): torch.full((4, 3), 2)}, + [4], + ), ] - padded_td = pad_sequence( - list_td, batch_first=batch_first, return_mask=make_mask - ) - if batch_first: - assert padded_td.shape == torch.Size([2, 4]) - assert padded_td["a"].shape == torch.Size([2, 4]) - assert padded_td["a"][0, -1] == 0 - assert padded_td["b", "c"].shape == torch.Size([2, 4, 3]) - assert padded_td["b", "c"][0, -1, 0] == 0 + padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask) + assert padded_td.shape == torch.Size( + [2, 4] + ) # check the shape of the padded tensordict + assert torch.all( + padded_td["a"][0, :2, :, :] == 1 + ) # check the values of the first tensor + assert torch.all( + padded_td["a"][1, :, :, :] == 2 + ) # check the values of the second tensor + assert padded_td["a"].shape == torch.Size( + [2, 4, 8, 8] + ) # check the shape of the padded tensor + assert torch.all(padded_td["a"][0, 2:, :, :] == 0) # check the padding + assert padded_td["b", "c"].shape == torch.Size( + [2, 4, 3] + ) # check the shape of the padded tensor + assert torch.all(padded_td["b", "c"][0, 2:, :] == 0) # check the padding + if make_mask: + 
padded_td_without_masks = pad_sequence( + list_td, pad_dim=pad_dim, return_mask=False + ) + assert "masks" in padded_td.keys() + assert set( + padded_td_without_masks.keys(include_nested=True, leaves_only=True) + ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True)) + assert not padded_td["masks", "a"].all() + assert padded_td["masks", "a"].ndim == pad_dim + 2 + assert (padded_td["a"][padded_td["masks", "a"]] != 0).all() + assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all() + assert not padded_td["masks", "b", "c"].all() + assert padded_td["masks", "b", "c"].ndim == pad_dim + 2 + assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all() + assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all() else: - assert padded_td.shape == torch.Size([4, 2]) - assert padded_td["a"].shape == torch.Size([4, 2]) - assert padded_td["a"][-1, 0] == 0 - assert padded_td["b", "c"].shape == torch.Size([4, 2, 3]) - assert padded_td["b", "c"][-1, 0, 0] == 0 + assert "masks" not in padded_td.keys() + + @pytest.mark.parametrize("make_mask", [True, False]) + def test_pad_sequence_pad_dim1(self, make_mask): + pad_dim = 1 + list_td = [ + TensorDict( + {"a": torch.ones((6, 3, 8)), ("b", "c"): torch.ones((6, 3))}, [6] + ), + TensorDict( + {"a": torch.full((6, 5, 8), 2), ("b", "c"): torch.full((6, 7), 2)}, + [6], + ), + ] + padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask) + assert padded_td.shape == torch.Size( + [2, 6] + ) # check the shape of the padded tensordict + assert padded_td["a"].shape == torch.Size( + [2, 6, 5, 8] + ) # check the shape of the padded tensor + assert torch.all( + padded_td["a"][0, :, :3, :] == 1 + ) # check the values of the first tensor + assert torch.all(padded_td["a"][0, :, 3:, :] == 0) # check the padding + assert torch.all( + padded_td["a"][1, :, :, :] == 2 + ) # check the values of the second tensor + assert padded_td["b", "c"].shape == torch.Size( + [2, 6, 7] + ) # check the shape of the padded tensor + assert torch.all(padded_td["b", "c"][0, :, 3:] == 0) # check the padding if make_mask: - assert "mask" in padded_td.keys() - assert not padded_td["mask"].all() + padded_td_without_masks = pad_sequence( + list_td, pad_dim=pad_dim, return_mask=False + ) + assert "masks" in padded_td.keys() + assert set( + padded_td_without_masks.keys(include_nested=True, leaves_only=True) + ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True)) + assert not padded_td["masks", "a"].all() + assert padded_td["masks", "a"].ndim == pad_dim + 2 + assert (padded_td["a"][padded_td["masks", "a"]] != 0).all() + assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all() + assert not padded_td["masks", "b", "c"].all() + assert padded_td["masks", "b", "c"].ndim == pad_dim + 2 + assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all() + assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all() else: - assert "mask" not in padded_td.keys() + assert "masks" not in padded_td.keys() @pytest.mark.parametrize("device", get_available_devices()) def test_permute(self, device): @@ -2619,6 +2774,7 @@ def test_index_tensor_nd_names(self, td_name, device, npy): index = index.numpy() td_idx = td[:, index] assert tensor_example[:, index].shape == td_idx.shape + # TODO: this multiple dims with identical names should not be allowed assert td_idx.names == [names[0], names[1], names[1], *names[2:]] td_idx = td[0, index] assert tensor_example[0, index].shape == td_idx.shape @@ -5947,8 +6103,16 @@ def 
test_lazy_indexing(self, pos1, pos2, pos3): index = (pos1, pos2, pos3) result = outer[index] ref_tensor = torch.zeros(outer.shape) - assert result.batch_size == ref_tensor[index].shape, index - assert result.batch_size == outer_dense[index].shape, index + assert result.batch_size == ref_tensor[index].shape, ( + result.batch_size, + ref_tensor[index].shape, + index, + ) + assert result.batch_size == outer_dense[index].shape, ( + result.batch_size, + outer_dense[index].shape, + index, + ) @pytest.mark.parametrize("stack_dim", [0, 1, 2]) @pytest.mark.parametrize("mask_dim", [0, 1, 2]) @@ -7633,6 +7797,29 @@ def test_map_with_out(self, mmap, chunksize, tmpdir): input.map(self.selectfn, num_workers=2, chunksize=chunksize, out=out) assert (out["a"] == torch.arange(10)).all(), (chunksize, mmap) + @classmethod + def nontensor_check(cls, td): + td["check"] = td["non_tensor"] == ( + "a string!" if (td["tensor"] % 2) == 0 else "another string!" + ) + return td + + def test_non_tensor(self): + # with NonTensorStack + td = TensorDict( + {"tensor": torch.arange(10), "non_tensor": "a string!"}, batch_size=[10] + ) + td[1::2] = TensorDict({"non_tensor": "another string!"}, [5]) + td = td.map(self.nontensor_check, chunksize=0) + assert td["check"].all() + # with NonTensorData + td = TensorDict( + {"tensor": torch.zeros(10, dtype=torch.int), "non_tensor": "a string!"}, + batch_size=[10], + ) + td = td.map(self.nontensor_check, chunksize=0) + assert td["check"].all() + # class TestNonTensorData: class TestNonTensorData: @@ -7700,6 +7887,43 @@ def test_stack(self, non_tensor_data): LazyStackedTensorDict, ) + def test_assign_non_tensor(self): + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + assert data["b"] == "a string!" + assert data.get("b").tolist() == [["a string!"] * 10] + data[0, 1] = TensorDict({"a": 0, "b": "another string!"}, []) + assert data.get("b").tolist() == [ + ["a string!"] + ["another string!"] + ["a string!"] * 8 + ] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0, 5:] = TensorDict({"a": torch.zeros(5), "b": "another string!"}, [5]) + assert data.get("b").tolist() == [["a string!"] * 5 + ["another string!"] * 5] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0, 0::2] = TensorDict( + {"a": torch.zeros(5, dtype=torch.long), "b": "another string!"}, [5] + ) + assert data.get("b").tolist() == [["another string!", "a string!"] * 5] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0] = TensorDict( + {"a": torch.zeros(10, dtype=torch.long), "b": "another string!"}, [10] + ) + assert data.get("b").tolist() == [["another string!"] * 10] + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args()
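For reference, a condensed sketch of the non-tensor assignment behavior that `test_assign_non_tensor` above asserts (values copied from the test): assigning a different value to a sub-index promotes the `NonTensorData` entry to a `NonTensorStack`.

```python
import torch
from tensordict import TensorDict

data = TensorDict({}, [1, 10])
data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, [])
# every element currently holds the same value, so indexing returns it directly
assert data["b"] == "a string!"
# tolist() exposes the per-element view of the underlying stack
assert data.get("b").tolist() == [["a string!"] * 10]

# assigning a different string to a slice turns the entry into a NonTensorStack
data[0, 5:] = TensorDict({"a": torch.zeros(5), "b": "another string!"}, [5])
assert data.get("b").tolist() == [["a string!"] * 5 + ["another string!"] * 5]
```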