Skip to content

Commit

Permalink
[Feature] sorted keys, values and items
Browse files Browse the repository at this point in the history
ghstack-source-id: af5b112bc5a3dbabf5c490026f883bac89a6a052
Pull Request resolved: #965
  • Loading branch information
vmoens committed Sep 16, 2024
1 parent 48d52d2 commit 257c859
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 137 deletions.
40 changes: 35 additions & 5 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,45 +1550,70 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _LazyStackedTensorDictKeysView:
keys = _LazyStackedTensorDictKeysView(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
return keys

def values(self, include_nested=False, leaves_only=False, is_leaf=None):
def values(
self,
include_nested=False,
leaves_only=False,
is_leaf=None,
*,
sort: bool = False,
):
if is_leaf not in (
_NESTED_TENSORS_AS_LISTS,
_NESTED_TENSORS_AS_LISTS_NONTENSOR,
):
yield from super().values(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
else:
for td in self.tensordicts:
yield from td.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def items(self, include_nested=False, leaves_only=False, is_leaf=None):
def items(
self,
include_nested=False,
leaves_only=False,
is_leaf=None,
*,
sort: bool = False,
):
if is_leaf not in (
_NESTED_TENSORS_AS_LISTS,
_NESTED_TENSORS_AS_LISTS_NONTENSOR,
):
yield from super().items(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
else:
for i, td in enumerate(self.tensordicts):
for key, val in td.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
):
if isinstance(key, str):
key = (str(i), key)
Expand Down Expand Up @@ -3383,9 +3408,14 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
return self._source.keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def _select(
Expand Down
149 changes: 83 additions & 66 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -3154,12 +3154,23 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
if not include_nested and not leaves_only and is_leaf is None:
return _StringKeys(self._tensordict.keys())
if not sort:
return _StringKeys(self._tensordict.keys())
else:
return sorted(
_StringKeys(self._tensordict.keys()),
key=lambda x: ".".join(x) if isinstance(x, tuple) else x,
)
else:
return self._nested_keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

@cache # noqa: B019
Expand All @@ -3168,12 +3179,15 @@ def _nested_keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
return _TensorDictKeysView(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

# some custom methods for efficiency
Expand All @@ -3182,81 +3196,68 @@ def items(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> Iterator[tuple[str, CompatibleType]]:
if not include_nested and not leaves_only:
return self._tensordict.items()
elif include_nested and leaves_only:
if not sort:
return self._tensordict.items()
return sorted(self._tensordict.items(), key=lambda x: x[0])
elif include_nested and leaves_only and not sort:
is_leaf = _default_is_leaf if is_leaf is None else is_leaf
result = []
if is_dynamo_compiling():

def fast_iter():
for key, val in self._tensordict.items():
if not is_leaf(type(val)):
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
):
result.append(
(
(
key,
*(
(_key,)
if isinstance(_key, str)
else _key
),
),
_val,
)
)
else:
result.append((key, val))
return result

else:
# dynamo doesn't like generators
def fast_iter():
for key, val in self._tensordict.items():
if not is_leaf(type(val)):
yield from (
(
(
key,
*((_key,) if isinstance(_key, str) else _key),
),
_val,
)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
)
else:
yield (key, val)
def fast_iter():
for key, val in self._tensordict.items():
# We could easily make this faster, here we're iterating twice over the keys,
# but we could iterate just once.
# Ideally we should make a "dirty" list of items then call unravel_key on all of them.
if not is_leaf(type(val)):
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
):
if isinstance(_key, str):
_key = (key, _key)
else:
_key = (key, *_key)
result.append((_key, _val))
else:
result.append((key, val))
return result

return fast_iter()
else:
return super().items(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def values(
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> Iterator[tuple[str, CompatibleType]]:
if not include_nested and not leaves_only:
return self._tensordict.values()
if not sort:
return self._tensordict.values()
else:
return list(zip(*sorted(self._tensordict.items(), key=lambda x: x[0])))[
1
]
else:
return TensorDictBase.values(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)


Expand Down Expand Up @@ -3545,9 +3546,14 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
return self._source.keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def entry_class(self, key: NestedKey) -> type:
Expand Down Expand Up @@ -4184,29 +4190,40 @@ def __init__(
include_nested: bool,
leaves_only: bool,
is_leaf: Callable[[Type], bool] = None,
sort: bool = False,
) -> None:
self.tensordict = tensordict
self.include_nested = include_nested
self.leaves_only = leaves_only
if is_leaf is None:
is_leaf = _default_is_leaf
self.is_leaf = is_leaf
self.sort = sort

def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]:
if not self.include_nested:
if self.leaves_only:
for key in self._keys():
target_class = self.tensordict.entry_class(key)
if _is_tensor_collection(target_class):
continue
yield key
def _iter():
if not self.include_nested:
if self.leaves_only:
for key in self._keys():
target_class = self.tensordict.entry_class(key)
if _is_tensor_collection(target_class):
continue
yield key
else:
yield from self._keys()
else:
yield from self._keys()
else:
yield from (
key if len(key) > 1 else key[0]
for key in self._iter_helper(self.tensordict)
yield from (
key if len(key) > 1 else key[0]
for key in self._iter_helper(self.tensordict)
)

if self.sort:
yield from sorted(
_iter(),
key=lambda key: ".".join(key) if isinstance(key, tuple) else key,
)
else:
yield from _iter()

def _iter_helper(
self, tensordict: T, prefix: str | None = None
Expand Down
Loading

0 comments on commit 257c859

Please sign in to comment.