Skip to content

Commit

Permalink
Support instantiation of _TensorDictKeysView from SubTensorDict
Browse files Browse the repository at this point in the history
  • Loading branch information
tcbegley committed Feb 9, 2023
1 parent 7a05d2c commit 34a0275
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,13 @@ def __len__(self):
i += 1
return i

# TODO fix method for SubTensorDict case
def _items(self, tensordict=None):
if tensordict is None:
tensordict = self.tensordict
if isinstance(tensordict, TensorDict):
return tensordict._tensordict.items()
elif isinstance(tensordict, SubTensorDict):
return tensordict._source._tensordict.items()
elif isinstance(tensordict, LazyStackedTensorDict):
return _iter_items_lazystack(tensordict)
elif isinstance(tensordict, KeyedJaggedTensor):
Expand All @@ -242,6 +243,9 @@ def _items(self, tensordict=None):
# or _CustomOpTensorDict, so as we iterate through the contents we need to
# be careful to not rely on tensordict._tensordict existing.
return ((key, tensordict.get(key)) for key in tensordict._source.keys())
raise NotImplementedError(
f"_TensorDictKeysView doesn't support {tensordict.__class__}"
)

def _keys(self):
return self.tensordict._tensordict.keys()
Expand Down

0 comments on commit 34a0275

Please sign in to comment.