diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index bc12dc015..a4969ce03 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1550,22 +1550,35 @@ def keys( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> _LazyStackedTensorDictKeysView: keys = _LazyStackedTensorDictKeysView( self, include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + sort=sort, ) return keys - def values(self, include_nested=False, leaves_only=False, is_leaf=None): + def values( + self, + include_nested=False, + leaves_only=False, + is_leaf=None, + *, + sort: bool = False, + ): if is_leaf not in ( _NESTED_TENSORS_AS_LISTS, _NESTED_TENSORS_AS_LISTS_NONTENSOR, ): yield from super().values( - include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, ) else: for td in self.tensordicts: @@ -1573,15 +1586,26 @@ def values(self, include_nested=False, leaves_only=False, is_leaf=None): include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + sort=sort, ) - def items(self, include_nested=False, leaves_only=False, is_leaf=None): + def items( + self, + include_nested=False, + leaves_only=False, + is_leaf=None, + *, + sort: bool = False, + ): if is_leaf not in ( _NESTED_TENSORS_AS_LISTS, _NESTED_TENSORS_AS_LISTS_NONTENSOR, ): yield from super().items( - include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, ) else: for i, td in enumerate(self.tensordicts): @@ -1589,6 +1613,7 @@ def items(self, include_nested=False, leaves_only=False, is_leaf=None): include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + sort=sort, ): if isinstance(key, str): key = (str(i), key) @@ -3383,9 +3408,14 @@ def keys( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> _TensorDictKeysView: return self._source.keys( - include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, ) def _select( diff --git a/tensordict/_td.py b/tensordict/_td.py index 72e3efbe3..376c93b97 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -3154,12 +3154,23 @@ def keys( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> _TensorDictKeysView: if not include_nested and not leaves_only and is_leaf is None: - return _StringKeys(self._tensordict.keys()) + if not sort: + return _StringKeys(self._tensordict.keys()) + else: + return sorted( + _StringKeys(self._tensordict.keys()), + key=lambda x: ".".join(x) if isinstance(x, tuple) else x, + ) else: return self._nested_keys( - include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, ) @cache # noqa: B019 @@ -3168,12 +3179,15 @@ def _nested_keys( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> _TensorDictKeysView: return _TensorDictKeysView( self, include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + sort=sort, ) # some custom methods for efficiency @@ -3182,65 +3196,44 @@ def items( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> Iterator[tuple[str, CompatibleType]]: if not include_nested and not leaves_only: - return self._tensordict.items() - elif include_nested and leaves_only: + if not sort: + return self._tensordict.items() + return sorted(self._tensordict.items(), key=lambda x: x[0]) + elif include_nested and leaves_only and not sort: is_leaf = _default_is_leaf if is_leaf is None else is_leaf result = [] - if is_dynamo_compiling(): - - def fast_iter(): - for key, val in self._tensordict.items(): - if not is_leaf(type(val)): - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ): - result.append( - ( - ( - key, - *( - (_key,) - if isinstance(_key, str) - else _key - ), - ), - _val, - ) - ) - else: - result.append((key, val)) - return result - else: - # dynamo doesn't like generators - def fast_iter(): - for key, val in self._tensordict.items(): - if not is_leaf(type(val)): - yield from ( - ( - ( - key, - *((_key,) if isinstance(_key, str) else _key), - ), - _val, - ) - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ) - ) - else: - yield (key, val) + def fast_iter(): + for key, val in self._tensordict.items(): + # We could easily make this faster, here we're iterating twice over the keys, + # but we could iterate just once. + # Ideally we should make a "dirty" list of items then call unravel_key on all of them. + if not is_leaf(type(val)): + for _key, _val in val.items( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + ): + if isinstance(_key, str): + _key = (key, _key) + else: + _key = (key, *_key) + result.append((_key, _val)) + else: + result.append((key, val)) + return result return fast_iter() else: return super().items( - include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, ) def values( @@ -3248,15 +3241,23 @@ def values( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> Iterator[tuple[str, CompatibleType]]: if not include_nested and not leaves_only: - return self._tensordict.values() + if not sort: + return self._tensordict.values() + else: + return list(zip(*sorted(self._tensordict.items(), key=lambda x: x[0])))[ + 1 + ] else: return TensorDictBase.values( self, include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + sort=sort, ) @@ -3545,9 +3546,14 @@ def keys( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> _TensorDictKeysView: return self._source.keys( - include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, ) def entry_class(self, key: NestedKey) -> type: @@ -4184,6 +4190,7 @@ def __init__( include_nested: bool, leaves_only: bool, is_leaf: Callable[[Type], bool] = None, + sort: bool = False, ) -> None: self.tensordict = tensordict self.include_nested = include_nested @@ -4191,22 +4198,32 @@ def __init__( if is_leaf is None: is_leaf = _default_is_leaf self.is_leaf = is_leaf + self.sort = sort def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]: - if not self.include_nested: - if self.leaves_only: - for key in self._keys(): - target_class = self.tensordict.entry_class(key) - if _is_tensor_collection(target_class): - continue - yield key + def _iter(): + if not self.include_nested: + if self.leaves_only: + for key in self._keys(): + target_class = self.tensordict.entry_class(key) + if _is_tensor_collection(target_class): + continue + yield key + else: + yield from self._keys() else: - yield from self._keys() - else: - yield from ( - key if len(key) > 1 else key[0] - for key in self._iter_helper(self.tensordict) + yield from ( + key if len(key) > 1 else key[0] + for key in self._iter_helper(self.tensordict) + ) + + if self.sort: + yield from sorted( + _iter(), + key=lambda key: ".".join(key) if isinstance(key, tuple) else key, ) + else: + yield from _iter() def _iter_helper( self, tensordict: T, prefix: str | None = None diff --git a/tensordict/base.py b/tensordict/base.py index 3b326caaa..918878227 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -5151,8 +5151,13 @@ def setdefault( return self.get(key) def items( - self, include_nested: bool = False, leaves_only: bool = False, is_leaf=None - ) -> Iterator[tuple[str, CompatibleType]]: + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf=None, + *, + sort: bool = False, + ) -> Iterator[tuple[str, CompatibleType]]: # noqa: D417 """Returns a generator of key-value pairs for the tensordict. Args: @@ -5163,46 +5168,64 @@ def items( is_leaf: an optional callable that indicates if a class is to be considered a leaf or not. + Keyword Args: + sort (bool, optional): whether the keys should be sorted. For nested keys, + the keys are sorted according to their joined name (ie, ``("a", "key")`` will + be counted as ``"a.key"`` for sorting). Be mindful that sorting may incur + significant overhead when dealing with large tensordicts. + Defaults to ``False``. + """ if is_leaf is None: is_leaf = _default_is_leaf - # check the conditions once only - if include_nested and leaves_only: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - if not is_leaf(type(val)): - yield from ( - (_unravel_key_to_tuple((k, _key)), _val) - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, + def _items(): + if include_nested and leaves_only: + # check the conditions once only + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) + if not is_leaf(type(val)): + yield from ( + (_unravel_key_to_tuple((k, _key)), _val) + for _key, _val in val.items( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + ) ) - ) - else: + else: + yield k, val + elif include_nested: + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) yield k, val - elif include_nested: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - yield k, val - if not is_leaf(type(val)): - yield from ( - (_unravel_key_to_tuple((k, _key)), _val) - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, + if not is_leaf(type(val)): + yield from ( + (_unravel_key_to_tuple((k, _key)), _val) + for _key, _val in val.items( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + ) ) - ) - elif leaves_only: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - if is_leaf(type(val)): - yield k, val + elif leaves_only: + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) + if is_leaf(type(val)): + yield k, val + else: + for k in self.keys(): + yield k, self._get_str(k, NO_DEFAULT) + + if sort: + yield from sorted( + _items(), + key=lambda item: ( + item[0] if isinstance(item[0], str) else ".".join(item[0]) + ), + ) else: - for k in self.keys(): - yield k, self._get_str(k, NO_DEFAULT) + yield from _items() def non_tensor_items(self, include_nested: bool = False): """Returns all non-tensor leaves, maybe recursively.""" @@ -5219,7 +5242,9 @@ def values( include_nested: bool = False, leaves_only: bool = False, is_leaf=None, - ) -> Iterator[CompatibleType]: + *, + sort: bool = False, + ) -> Iterator[CompatibleType]: # noqa: D417 """Returns a generator representing the values for the tensordict. Args: @@ -5230,39 +5255,54 @@ def values( is_leaf: an optional callable that indicates if a class is to be considered a leaf or not. + Keyword Args: + sort (bool, optional): whether the keys should be sorted. For nested keys, + the keys are sorted according to their joined name (ie, ``("a", "key")`` will + be counted as ``"a.key"`` for sorting). Be mindful that sorting may incur + significant overhead when dealing with large tensordicts. + Defaults to ``False``. + """ if is_leaf is None: is_leaf = _default_is_leaf - # check the conditions once only - if include_nested and leaves_only: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - if not is_leaf(type(val)): - yield from val.values( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ) - else: - yield val - elif include_nested: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - yield val - if not is_leaf(type(val)): - yield from val.values( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ) - elif leaves_only: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - if is_leaf(type(val)): + + def _values(): + # check the conditions once only + if include_nested and leaves_only: + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) + if not is_leaf(type(val)): + yield from val.values( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + ) + else: + yield val + elif include_nested: + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) yield val + if not is_leaf(type(val)): + yield from val.values( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + ) + elif leaves_only: + for k in self.keys(sort=sort): + val = self._get_str(k, NO_DEFAULT) + if is_leaf(type(val)): + yield val + else: + for k in self.keys(sort=sort): + yield self._get_str(k, NO_DEFAULT) + + if not sort or not include_nested: + yield from _values() else: - for k in self.keys(): - yield self._get_str(k, NO_DEFAULT) + for _, value in self.items(include_nested, leaves_only, is_leaf, sort=sort): + yield value @cache # noqa: B019 def _values_list( @@ -5360,7 +5400,9 @@ def keys( self, include_nested: bool = False, leaves_only: bool = False, - is_leaf: Callable[[Type], bool] = None, + is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ): """Returns a generator of tensordict keys. @@ -5372,6 +5414,13 @@ def keys( is_leaf: an optional callable that indicates if a class is to be considered a leaf or not. + Keyword Args: + sort (bool, optional): whether the keys shoulbe sorted. For nested keys, + the keys are sorted according to their joined name (ie, ``("a", "key")`` will + be counted as ``"a.key"`` for sorting). Be mindful that sorting may incur + significant overhead when dealing with large tensordicts. + Defaults to ``False``. + Examples: >>> from tensordict import TensorDict >>> data = TensorDict({"0": 0, "1": {"2": 2}}, batch_size=[]) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 66e7153a9..2b778815f 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -1014,10 +1014,12 @@ def values( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> Iterator[CompatibleType]: if is_leaf is None: is_leaf = _default_is_leaf - for v in self._param_td.values(include_nested, leaves_only): + for v in self._param_td.values(include_nested, leaves_only, sort=sort): if not is_leaf(type(v)): yield v continue @@ -1082,10 +1084,12 @@ def items( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> Iterator[CompatibleType]: if is_leaf is None: is_leaf = _default_is_leaf - for k, v in self._param_td.items(include_nested, leaves_only): + for k, v in self._param_td.items(include_nested, leaves_only, sort=sort): if not is_leaf(type(v)): yield k, v continue diff --git a/tensordict/persistent.py b/tensordict/persistent.py index 4da8c118d..54b0375a5 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -471,6 +471,8 @@ def keys( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> _PersistentTDKeysView: if is_leaf not in (None, _default_is_leaf, _is_leaf_nontensor): raise ValueError( @@ -481,6 +483,7 @@ def keys( include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + sort=sort, ) def _items_metadata(self, include_nested=False, leaves_only=False): diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 583c592cf..5ff085150 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2239,6 +2239,37 @@ def assert_shared(td0): td0 = td0.squeeze(0) assert_shared(td0) + def test_sorted_keys(self): + td = TensorDict( + { + "a": {"b": 0, "c": 1}, + "d": 2, + "e": {"f": 3, "g": {"h": 4, "i": 5}, "j": 6}, + } + ) + tdflat = td.flatten_keys() + tdflat["d"] = tdflat.pop("d") + tdflat["a.b"] = tdflat.pop("a.b") + for key1, key2 in zip( + td.keys(True, True, sort=True), tdflat.keys(True, True, sort=True) + ): + if isinstance(key1, str): + assert key1 == key2 + else: + assert ".".join(key1) == key2 + for v1, v2 in zip( + td.values(True, True, sort=True), tdflat.values(True, True, sort=True) + ): + assert v1 == v2 + for (k1, v1), (k2, v2) in zip( + td.items(True, True, sort=True), tdflat.items(True, True, sort=True) + ): + if isinstance(k1, str): + assert k1 == k2 + else: + assert ".".join(k1) == k2 + assert v1 == v2 + def test_split_keys(self): td = TensorDict( {