Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] sorted keys, values and items #965

Merged
merged 4 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,45 +1550,70 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _LazyStackedTensorDictKeysView:
keys = _LazyStackedTensorDictKeysView(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
return keys

def values(self, include_nested=False, leaves_only=False, is_leaf=None):
def values(
self,
include_nested=False,
leaves_only=False,
is_leaf=None,
*,
sort: bool = False,
):
if is_leaf not in (
_NESTED_TENSORS_AS_LISTS,
_NESTED_TENSORS_AS_LISTS_NONTENSOR,
):
yield from super().values(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
else:
for td in self.tensordicts:
yield from td.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def items(self, include_nested=False, leaves_only=False, is_leaf=None):
def items(
self,
include_nested=False,
leaves_only=False,
is_leaf=None,
*,
sort: bool = False,
):
if is_leaf not in (
_NESTED_TENSORS_AS_LISTS,
_NESTED_TENSORS_AS_LISTS_NONTENSOR,
):
yield from super().items(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
else:
for i, td in enumerate(self.tensordicts):
for key, val in td.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
):
if isinstance(key, str):
key = (str(i), key)
Expand Down Expand Up @@ -3381,9 +3406,14 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
return self._source.keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def _select(
Expand Down
149 changes: 83 additions & 66 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -3144,12 +3144,23 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
if not include_nested and not leaves_only and is_leaf is None:
return _StringKeys(self._tensordict.keys())
if not sort:
return _StringKeys(self._tensordict.keys())
else:
return sorted(
_StringKeys(self._tensordict.keys()),
key=lambda x: ".".join(x) if isinstance(x, tuple) else x,
)
else:
return self._nested_keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

@cache # noqa: B019
Expand All @@ -3158,12 +3169,15 @@ def _nested_keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
return _TensorDictKeysView(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

# some custom methods for efficiency
Expand All @@ -3172,81 +3186,68 @@ def items(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> Iterator[tuple[str, CompatibleType]]:
if not include_nested and not leaves_only:
return self._tensordict.items()
elif include_nested and leaves_only:
if not sort:
return self._tensordict.items()
return sorted(self._tensordict.items(), key=lambda x: x[0])
elif include_nested and leaves_only and not sort:
is_leaf = _default_is_leaf if is_leaf is None else is_leaf
result = []
if is_dynamo_compiling():

def fast_iter():
for key, val in self._tensordict.items():
if not is_leaf(type(val)):
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
):
result.append(
(
(
key,
*(
(_key,)
if isinstance(_key, str)
else _key
),
),
_val,
)
)
else:
result.append((key, val))
return result

else:
# dynamo doesn't like generators
def fast_iter():
for key, val in self._tensordict.items():
if not is_leaf(type(val)):
yield from (
(
(
key,
*((_key,) if isinstance(_key, str) else _key),
),
_val,
)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
)
else:
yield (key, val)
def fast_iter():
for key, val in self._tensordict.items():
# We could easily make this faster, here we're iterating twice over the keys,
# but we could iterate just once.
# Ideally we should make a "dirty" list of items then call unravel_key on all of them.
if not is_leaf(type(val)):
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
):
if isinstance(_key, str):
_key = (key, _key)
else:
_key = (key, *_key)
result.append((_key, _val))
else:
result.append((key, val))
return result

return fast_iter()
else:
return super().items(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def values(
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> Iterator[tuple[str, CompatibleType]]:
if not include_nested and not leaves_only:
return self._tensordict.values()
if not sort:
return self._tensordict.values()
else:
return list(zip(*sorted(self._tensordict.items(), key=lambda x: x[0])))[
1
]
else:
return TensorDictBase.values(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)


Expand Down Expand Up @@ -3535,9 +3536,14 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
return self._source.keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def entry_class(self, key: NestedKey) -> type:
Expand Down Expand Up @@ -4172,29 +4178,40 @@ def __init__(
include_nested: bool,
leaves_only: bool,
is_leaf: Callable[[Type], bool] = None,
sort: bool = False,
) -> None:
self.tensordict = tensordict
self.include_nested = include_nested
self.leaves_only = leaves_only
if is_leaf is None:
is_leaf = _default_is_leaf
self.is_leaf = is_leaf
self.sort = sort

def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]:
if not self.include_nested:
if self.leaves_only:
for key in self._keys():
target_class = self.tensordict.entry_class(key)
if _is_tensor_collection(target_class):
continue
yield key
def _iter():
if not self.include_nested:
if self.leaves_only:
for key in self._keys():
target_class = self.tensordict.entry_class(key)
if _is_tensor_collection(target_class):
continue
yield key
else:
yield from self._keys()
else:
yield from self._keys()
else:
yield from (
key if len(key) > 1 else key[0]
for key in self._iter_helper(self.tensordict)
yield from (
key if len(key) > 1 else key[0]
for key in self._iter_helper(self.tensordict)
)

if self.sort:
yield from sorted(
_iter(),
key=lambda key: ".".join(key) if isinstance(key, tuple) else key,
)
else:
yield from _iter()

def _iter_helper(
self, tensordict: T, prefix: str | None = None
Expand Down
Loading
Loading