diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 91bc04a54..08c35a361 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -227,12 +227,13 @@ def __len__(self): i += 1 return i - # TODO fix method for SubTensorDict case def _items(self, tensordict=None): if tensordict is None: tensordict = self.tensordict if isinstance(tensordict, TensorDict): return tensordict._tensordict.items() + elif isinstance(tensordict, SubTensorDict): + return tensordict._source._tensordict.items() elif isinstance(tensordict, LazyStackedTensorDict): return _iter_items_lazystack(tensordict) elif isinstance(tensordict, KeyedJaggedTensor): @@ -242,6 +243,9 @@ def _items(self, tensordict=None): # or _CustomOpTensorDict, so as we iterate through the contents we need to # be careful to not rely on tensordict._tensordict existing. return ((key, tensordict.get(key)) for key in tensordict._source.keys()) + raise NotImplementedError( + f"_TensorDictKeysView doesn't support {tensordict.__class__}" + ) def _keys(self): return self.tensordict._tensordict.keys()