Skip to content

Commit

Permalink
Support instantiation of _TensorDictKeysView from SubTensorDict
Browse files Browse the repository at this point in the history
  • Loading branch information
tcbegley committed Feb 9, 2023
1 parent 7a05d2c commit 441b521
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ def _items(self, tensordict=None):
tensordict = self.tensordict
if isinstance(tensordict, TensorDict):
return tensordict._tensordict.items()
elif isinstance(tensordict, SubTensorDict):
return tensordict._source._tensordict.items()
elif isinstance(tensordict, LazyStackedTensorDict):
return _iter_items_lazystack(tensordict)
elif isinstance(tensordict, KeyedJaggedTensor):
Expand All @@ -242,6 +244,7 @@ def _items(self, tensordict=None):
# or _CustomOpTensorDict, so as we iterate through the contents we need to
# be careful to not rely on tensordict._tensordict existing.
return ((key, tensordict.get(key)) for key in tensordict._source.keys())
raise NotImplementedError(f"_TensorDictKeysView doesn't support {tensordict.__class__}")

def _keys(self):
return self.tensordict._tensordict.keys()
Expand Down

0 comments on commit 441b521

Please sign in to comment.