Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add recursion guard and tidy tests #220

Merged
merged 1 commit into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,22 @@ def is_memmap(datatype: type) -> bool:
_NestedKey = namedtuple("_NestedKey", ["root_key", "nested_key"])


def _recursion_guard(fn):
# catches RecursionError and warns of auto-nesting
@functools.wraps(fn)
def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except RecursionError as e:
raise RecursionError(
f"{fn.__name__.lstrip('_')} failed due to a recursion error. It's possible the "
"TensorDict has auto-nested values, which are not supported by this "
f"function."
) from e

return wrapper


class _TensorDictKeysView:
"""
_TensorDictKeysView is returned when accessing tensordict.keys() and holds a
Expand Down Expand Up @@ -635,6 +651,7 @@ def apply_(self, fn: Callable) -> TensorDictBase:
"""
return _apply_safe(lambda _, value: fn(value), self, inplace=True)

@_recursion_guard
def apply(
self,
fn: Callable,
Expand Down Expand Up @@ -1249,6 +1266,7 @@ def zero_(self) -> TensorDictBase:
self.get(key).zero_()
return self

@_recursion_guard
def unbind(self, dim: int) -> Tuple[TensorDictBase, ...]:
"""Returns a tuple of indexed tensordicts unbound along the indicated dimension.

Expand Down Expand Up @@ -1668,6 +1686,7 @@ def split(
for i in range(len(dictionaries))
]

@_recursion_guard
def gather(self, dim: int, index: torch.Tensor, out=None):
"""Gathers values along an axis specified by `dim`.

Expand Down Expand Up @@ -1925,6 +1944,7 @@ def __iter__(self) -> Generator:
for i in range(length):
yield self[i]

@_recursion_guard
def flatten_keys(
self, separator: str = ".", inplace: bool = False
) -> TensorDictBase:
Expand Down Expand Up @@ -3086,7 +3106,12 @@ def masked_fill(self, mask: Tensor, value: Union[float, bool]) -> TensorDictBase
return td_copy.masked_fill_(mask, value)

def is_contiguous(self) -> bool:
return all([value.is_contiguous() for _, value in self.items()])
return all(
self.get(key).is_contiguous()
for key in _TensorDictKeysView(
self, include_nested=True, leaves_only=True, error_on_loop=False
)
)

def contiguous(self) -> TensorDictBase:
if not self.is_contiguous():
Expand Down Expand Up @@ -3122,8 +3147,17 @@ def select(
d[key] = value
except KeyError:
if strict:
# TODO: in the case of auto-nesting, this error will not list all of
# the (infinitely many) keys, and so there would be valid keys for
# selection that do not appear in the error message.
keys_view = _TensorDictKeysView(
self,
include_nested=True,
leaves_only=False,
error_on_loop=False,
)
raise KeyError(
f"Key '{key}' was not found among keys {set(self.keys(True))}."
f"Key '{key}' was not found among keys {set(keys_view)}."
)
else:
continue
Expand Down Expand Up @@ -3295,11 +3329,13 @@ def assert_allclose_td(


@implements_for_td(torch.unbind)
@_recursion_guard
def _unbind(td: TensorDictBase, *args, **kwargs) -> Tuple[TensorDictBase, ...]:
return td.unbind(*args, **kwargs)


@implements_for_td(torch.gather)
@_recursion_guard
def _gather(
input: TensorDictBase,
dim: int,
Expand Down Expand Up @@ -3627,6 +3663,7 @@ def recurse(list_of_tds, out, dim, prefix=()):
return out


@_recursion_guard
def pad(tensordict: TensorDictBase, pad_size: Sequence[int], value: float = 0.0):
"""Pads all tensors in a tensordict along the batch dimensions with a constant value, returning a new tensordict.

Expand Down
Loading