Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Allowing for auto-nested tensordict #119

Closed
wants to merge 7 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(
self.tensordict = tensordict
self.include_nested = include_nested
self.leaves_only = leaves_only
self.loop_set = set()

def __iter__(self):
if not self.include_nested:
Expand All @@ -153,12 +154,13 @@ def __iter__(self):

def _iter_helper(self, tensordict, prefix=None):
items_iter = self._items(tensordict)

self.loop_set.add(id(tensordict))
vmoens marked this conversation as resolved.
Show resolved Hide resolved
for key, value in items_iter:
full_key = self._combine_keys(prefix, key)
if (
isinstance(value, (TensorDictBase, KeyedJaggedTensor))
and self.include_nested
and id(value) not in self.loop_set
tcbegley marked this conversation as resolved.
Show resolved Hide resolved
):
subkeys = tuple(
self._iter_helper(
Expand All @@ -169,6 +171,7 @@ def _iter_helper(self, tensordict, prefix=None):
yield from subkeys
if not (isinstance(value, TensorDictBase) and self.leaves_only):
yield full_key
self.loop_set.remove(id(tensordict))

def _combine_keys(self, prefix, key):
if prefix is not None:
Expand Down Expand Up @@ -265,6 +268,8 @@ def __setstate__(self, state: Dict[str, Any]) -> Dict[str, Any]:

def __init__(self):
self._dict_meta = KeyDependentDefaultDict(self._make_meta)
self._being_called = False
self._being_flattened = False

@abc.abstractmethod
def _make_meta(self, key: str) -> MetaTensor:
Expand Down Expand Up @@ -1717,7 +1722,11 @@ def permute(
)

def __repr__(self) -> str:
if self._being_called:
Zooll marked this conversation as resolved.
Show resolved Hide resolved
Zooll marked this conversation as resolved.
Show resolved Hide resolved
return "Auto-nested"
self._being_called = True
fields = _td_fields(self)
self._being_called = False
field_str = indent(f"fields={{{fields}}}", 4 * " ")
batch_size_str = indent(f"batch_size={self.batch_size}", 4 * " ")
device_str = indent(f"device={self.device}", 4 * " ")
Expand Down Expand Up @@ -1791,6 +1800,10 @@ def __iter__(self) -> Generator:
def flatten_keys(
self, separator: str = ".", inplace: bool = False
) -> TensorDictBase:
if self._being_flattened:
Zooll marked this conversation as resolved.
Show resolved Hide resolved
return self

self._being_flattened = True
to_flatten = []
existing_keys = self.keys(include_nested=True)
for key, meta_value in self.items_meta():
Expand All @@ -1815,6 +1828,7 @@ def flatten_keys(
self.set(separator.join([key, inner_key]), inner_item)
for key in to_flatten:
del self[key]
self._being_flattened = False
return self
else:
tensordict_out = TensorDict(
Expand All @@ -1830,10 +1844,14 @@ def flatten_keys(
inner_tensordict = self.get(key).flatten_keys(
separator=separator, inplace=inplace
)
for inner_key, inner_item in inner_tensordict.items():
tensordict_out.set(separator.join([key, inner_key]), inner_item)
if inner_tensordict is not self.get(key):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this is compatible with lazy tensordicts (e.g. LazyStackedTensorDict), where two calls to get(key) return items whose ids are different (e.g. because they are the results of a call to stack)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an interesting point. I haven't researched 'LazyStackedTensorDict''s behaviour. But if 'get(key)' returns new ids every time, how can we detect a self loop?

for inner_key, inner_item in inner_tensordict.items():
tensordict_out.set(
separator.join([key, inner_key]), inner_item
)
else:
tensordict_out.set(key, value)
self._being_flattened = False
return tensordict_out

def unflatten_keys(
Expand Down Expand Up @@ -4623,6 +4641,7 @@ def __init__(
device: Optional[torch.device] = None,
batch_size: Optional[Sequence[int]] = None,
):
super().__init__()
if not isinstance(source, TensorDictBase):
raise TypeError(
f"Expected source to be a TensorDictBase instance, but got {type(source)} instead."
Expand Down