From 05762ecf816326211e7a6a8977c8c5078846cd33 Mon Sep 17 00:00:00 2001
From: Vladislav Agafonov
Date: Fri, 30 Dec 2022 15:19:20 +0000
Subject: [PATCH 1/6] fix [BUG] Auto-nested tensordict bugs #106

---
 tensordict/tensordict.py | 25 ++++++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index 39c3f6018..4bf0d23b5 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -137,6 +137,7 @@ def __init__(
         self.tensordict = tensordict
         self.include_nested = include_nested
         self.leaves_only = leaves_only
+        self.loop_set = set()
 
     def __iter__(self):
         if not self.include_nested:
@@ -153,12 +154,13 @@ def __iter__(self):
 
     def _iter_helper(self, tensordict, prefix=None):
         items_iter = self._items(tensordict)
-
+        self.loop_set.add(id(tensordict))
         for key, value in items_iter:
             full_key = self._combine_keys(prefix, key)
             if (
                 isinstance(value, (TensorDictBase, KeyedJaggedTensor))
                 and self.include_nested
+                and id(value) not in self.loop_set
             ):
                 subkeys = tuple(
                     self._iter_helper(
@@ -169,6 +171,7 @@ def _iter_helper(self, tensordict, prefix=None):
                 yield from subkeys
             if not (isinstance(value, TensorDictBase) and self.leaves_only):
                 yield full_key
+        self.loop_set.remove(id(tensordict))
 
     def _combine_keys(self, prefix, key):
         if prefix is not None:
@@ -265,6 +268,8 @@ def __setstate__(self, state: Dict[str, Any]) -> Dict[str, Any]:
 
     def __init__(self):
         self._dict_meta = KeyDependentDefaultDict(self._make_meta)
+        self._being_called = False
+        self._being_flattened = False
 
     @abc.abstractmethod
     def _make_meta(self, key: str) -> MetaTensor:
@@ -1717,7 +1722,11 @@ def permute(
         )
 
     def __repr__(self) -> str:
+        if self._being_called:
+            return "Auto-nested"
+        self._being_called = True
         fields = _td_fields(self)
+        self._being_called = False
         field_str = indent(f"fields={{{fields}}}", 4 * " ")
         batch_size_str = indent(f"batch_size={self.batch_size}", 4 * " ")
         device_str = indent(f"device={self.device}", 4 * " ")
@@ -1791,6 +1800,10 @@ def __iter__(self) -> Generator:
     def flatten_keys(
         self, separator: str = ".", inplace: bool = False
    ) -> TensorDictBase:
+        if self._being_flattened:
+            return self
+
+        self._being_flattened = True
         to_flatten = []
         existing_keys = self.keys(include_nested=True)
         for key, meta_value in self.items_meta():
@@ -1815,6 +1828,7 @@
                     self.set(separator.join([key, inner_key]), inner_item)
             for key in to_flatten:
                 del self[key]
+            self._being_flattened = False
             return self
         else:
             tensordict_out = TensorDict(
@@ -1830,10 +1844,14 @@
                 inner_tensordict = self.get(key).flatten_keys(
                     separator=separator, inplace=inplace
                 )
-                for inner_key, inner_item in inner_tensordict.items():
-                    tensordict_out.set(separator.join([key, inner_key]), inner_item)
+                if inner_tensordict is not self.get(key):
+                    for inner_key, inner_item in inner_tensordict.items():
+                        tensordict_out.set(
+                            separator.join([key, inner_key]), inner_item
+                        )
             else:
                 tensordict_out.set(key, value)
+        self._being_flattened = False
         return tensordict_out
 
     def unflatten_keys(
@@ -4623,6 +4641,7 @@ def __init__(
         device: Optional[torch.device] = None,
         batch_size: Optional[Sequence[int]] = None,
     ):
+        super().__init__()
         if not isinstance(source, TensorDictBase):
             raise TypeError(
                 f"Expected source to be a TensorDictBase instance, but got {type(source)} instead."
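The two guard patterns that patch 1 introduces are easier to study in isolation. The sketch below reproduces them on plain Python objects rather than on TensorDict; the Node class and every name in it are illustrative assumptions, not tensordict API. A set of id()s records the containers on the current descent path, so a self-reference such as td["self"] = td is yielded as a leaf key instead of recursing forever, while a boolean flag lets __repr__ print the "Auto-nested" marker when it re-enters itself (patch 2 below hardens that flag with try/finally):

    # Minimal sketch, assuming plain dict-like containers (not tensordict code).
    class Node:
        def __init__(self, **entries):
            self.entries = entries
            self._being_repr = False  # re-entrancy flag, as in __repr__ above

        def keys(self, _visiting=None, _prefix=()):
            """Yield nested key tuples; ids in `_visiting` break cycles."""
            if _visiting is None:
                _visiting = set()
            _visiting.add(id(self))  # mark this container as "on the current path"
            for key, value in self.entries.items():
                full_key = _prefix + (key,)
                if isinstance(value, Node) and id(value) not in _visiting:
                    yield from value.keys(_visiting, full_key)
                yield full_key
            _visiting.remove(id(self))  # allow the same node under another branch

        def __repr__(self):
            if self._being_repr:  # this node is already being printed higher up
                return "Auto-nested"
            self._being_repr = True
            body = ", ".join(f"{k}: {v!r}" for k, v in self.entries.items())
            self._being_repr = False
            return f"Node({body})"

    n = Node(a=1)
    n.entries["self"] = n
    print(list(n.keys()))  # [('a',), ('self',)]
    print(n)               # Node(a: 1, self: Auto-nested)

Removing the id on the way out matters: a container that appears twice under different, non-cyclic branches is still expanded both times, which is exactly what loop_set.remove(id(tensordict)) does above.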
From 9377e554a15bb5add19ce0a6da79938df37f56ed Mon Sep 17 00:00:00 2001
From: Vladislav Agafonov
Date: Wed, 4 Jan 2023 18:30:17 +0000
Subject: [PATCH 2/6] try/finally was added, renamed _being_called

---
 tensordict/tensordict.py | 106 ++++++++++++++++++++-------------------
 1 file changed, 55 insertions(+), 51 deletions(-)

diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index 4bf0d23b5..eb1c70f71 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -268,7 +268,7 @@ def __setstate__(self, state: Dict[str, Any]) -> Dict[str, Any]:
 
     def __init__(self):
         self._dict_meta = KeyDependentDefaultDict(self._make_meta)
-        self._being_called = False
+        self._being_repr = False
         self._being_flattened = False
 
     @abc.abstractmethod
@@ -1722,11 +1722,13 @@ def permute(
         )
 
     def __repr__(self) -> str:
-        if self._being_called:
+        if self._being_repr:
             return "Auto-nested"
-        self._being_called = True
-        fields = _td_fields(self)
-        self._being_called = False
+        try:
+            self._being_repr = True
+            fields = _td_fields(self)
+        finally:
+            self._being_repr = False
         field_str = indent(f"fields={{{fields}}}", 4 * " ")
         batch_size_str = indent(f"batch_size={self.batch_size}", 4 * " ")
         device_str = indent(f"device={self.device}", 4 * " ")
@@ -1802,57 +1804,59 @@ def flatten_keys(
     ) -> TensorDictBase:
         if self._being_flattened:
             return self
+        try: 
+            self._being_flattened = True
+            to_flatten = []
+            existing_keys = self.keys(include_nested=True)
+            for key, meta_value in self.items_meta():
+                key_split = tuple(key.split(separator))
+                if meta_value.is_tensordict():
+                    to_flatten.append(key)
+                elif (
+                    separator in key
+                    and key_split in existing_keys
+                    and not self._get_meta(key_split).is_tensordict()
+                ):
+                    raise KeyError(
+                        f"Flattening keys in tensordict collides with existing key '{key}'"
+                    )
 
-        self._being_flattened = True
-        to_flatten = []
-        existing_keys = self.keys(include_nested=True)
-        for key, meta_value in self.items_meta():
-            key_split = tuple(key.split(separator))
-            if meta_value.is_tensordict():
-                to_flatten.append(key)
-            elif (
-                separator in key
-                and key_split in existing_keys
-                and not self._get_meta(key_split).is_tensordict()
-            ):
-                raise KeyError(
-                    f"Flattening keys in tensordict collides with existing key '{key}'"
-                )
-
-        if inplace:
-            for key in to_flatten:
-                inner_tensordict = self.get(key).flatten_keys(
-                    separator=separator, inplace=inplace
-                )
-                for inner_key, inner_item in inner_tensordict.items():
-                    self.set(separator.join([key, inner_key]), inner_item)
-            for key in to_flatten:
-                del self[key]
-            self._being_flattened = False
-            return self
-        else:
-            tensordict_out = TensorDict(
-                {},
-                batch_size=self.batch_size,
-                device=self.device,
-                _run_checks=False,
-                _is_shared=self.is_shared(),
-                _is_memmap=self.is_memmap(),
+            if inplace:
+                for key in to_flatten:
+                    inner_tensordict = self.get(key).flatten_keys(
+                        separator=separator, inplace=inplace
+                    )
+                    for inner_key, inner_item in inner_tensordict.items():
+                        self.set(separator.join([key, inner_key]), inner_item)
+                for key in to_flatten:
+                    del self[key]
+                self._being_flattened = False
+                return self
+            else:
+                tensordict_out = TensorDict(
+                    {},
+                    batch_size=self.batch_size,
+                    device=self.device,
+                    _run_checks=False,
+                    _is_shared=self.is_shared(),
+                    _is_memmap=self.is_memmap(),
+                )
+                for key, value in self.items():
+                    if key in to_flatten:
+                        inner_tensordict = self.get(key).flatten_keys(
+                            separator=separator, inplace=inplace
+                        )
+                        if inner_tensordict is not self.get(key):
+                            for inner_key, inner_item in inner_tensordict.items():
+                                tensordict_out.set(
+                                    separator.join([key, inner_key]), inner_item
+                                )
+                    else:
+                        tensordict_out.set(key, value)
+                self._being_flattened = False
+                return tensordict_out
+        finally:
             self._being_flattened = False
-        return tensordict_out
 
     def unflatten_keys(
         self, separator: str = ".", inplace: bool = False

From acdbaf390172aee088cda288b74637793e8eb990 Mon Sep 17 00:00:00 2001
From: Vladislav Agafonov
Date: Wed, 4 Jan 2023 18:33:17 +0000
Subject: [PATCH 3/6] lint fix

---
 tensordict/tensordict.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index eb1c70f71..2ab0911a6 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -1804,7 +1804,7 @@ def flatten_keys(
     ) -> TensorDictBase:
         if self._being_flattened:
             return self
-        try: 
+        try:
            self._being_flattened = True
            to_flatten = []
            existing_keys = self.keys(include_nested=True)

From e40ad463a482fb688ca52349a389b77b5bf4df00 Mon Sep 17 00:00:00 2001
From: Vladislav Agafonov
Date: Thu, 5 Jan 2023 20:20:05 +0000
Subject: [PATCH 4/6] added test for repr and keys methods

---
 test/_utils_internal.py |  5 ++++
 test/test_tensordict.py | 54 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+)

diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index cd7e5e631..5adb7ee04 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -52,6 +52,11 @@ def nested_td(self, device):
             device=device,
         )
 
+    def auto_nested_td(self, device):
+        td = self.td(device)
+        td["self"] = td
+        return td
+
     def stacked_td(self, device):
         td1 = TensorDict(
             source={
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 7ef724f39..e98642fc8 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -1968,6 +1968,11 @@ def stacked_td(self, device, dtype):
 
         return stack_td([td1, td2], 2)
 
+    def auto_nested_td(self, device, dtype):
+        td = self.td(device, dtype)
+        td["self"] = td
+        return td
+
     def memmap_td(self, device, dtype):
         return self.td(device, dtype).memmap_()
 
@@ -2073,6 +2078,22 @@ def test_repr_stacked(self, device, dtype):
     is_shared={is_shared})"""
         assert repr(stacked_td) == expected
 
+    def test_repr_auto_nested(self, device, dtype):
+        auto_nested_td = self.auto_nested_td(device, dtype)
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_class = "Tensor"
+        expected = f"""TensorDict(
+    fields={{
+        a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype}),
+        self: Auto-nested}},
+    batch_size=torch.Size([4, 3, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared})"""
+        assert repr(auto_nested_td) == expected
+
     @pytest.mark.parametrize("index", [None, (slice(None), 0)])
     def test_repr_indexed_tensordict(self, device, dtype, index):
         tensordict = self.td(device, dtype)[index]
@@ -3436,6 +3457,39 @@ def test_shared_inheritance():
     assert td0.is_shared()
 
 
+@pytest.mark.parametrize("include_nested", [True, False])
+def test_keys_auto_nested(include_nested):
+    td1 = TensorDict(
+        {
+            "a": torch.ones(3, 4, 5),
+            "b": torch.zeros(3, 4, 5, dtype=torch.bool),
+        },
+        batch_size=[3, 4],
+    )
+    td2 = TensorDict(
+        {
+            "a": torch.ones(3, 4, 5),
+            "b": torch.zeros(3, 4, 5, dtype=torch.bool),
+        },
+        batch_size=[3, 4],
+    )
+    td1["nested"] = td2
td1["self"] = td1 + + keys_set = set(td1.keys(include_nested=include_nested)) # Ok + if include_nested is True: + assert keys_set == { + "a", + "b", + "self", + "nested", + ("nested", "a"), + ("nested", "b"), + } + else: + assert keys_set == {"a", "b", "self", "nested"} + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From 1d4571cb079c325b7985c2f1b5ff784d359df621 Mon Sep 17 00:00:00 2001 From: Vladislav Agafonov Date: Fri, 6 Jan 2023 11:28:18 +0000 Subject: [PATCH 5/6] added auto_nested_td into test, many tests are broken --- test/test_tensordict.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index e98642fc8..3bd276bd5 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -624,6 +624,7 @@ def test_convert_ellipsis_to_idx_invalid(ellipsis_index, expectation): "td_reset_bs", "nested_td", "permute_td", + "auto_nested_td", ], ) @pytest.mark.parametrize("device", get_available_devices()) From 397b99fa91b7f89d158d13422c04f26e3655bc7e Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 18 Jan 2023 15:46:44 +0000 Subject: [PATCH 6/6] init --- tensordict/tensordict.py | 211 +++++++++++++++++++++++++++++++-------- 1 file changed, 168 insertions(+), 43 deletions(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 2ab0911a6..d07727e1f 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -1004,24 +1004,61 @@ def __eq__(self, other: object) -> TensorDictBase: tensors of the same shape as the original tensors. """ - if not isinstance(other, (TensorDictBase, dict, float, int)): - return False - if not isinstance(other, TensorDictBase) and isinstance(other, dict): - other = make_tensordict(**other, batch_size=self.batch_size) - if not isinstance(other, TensorDictBase): - return TensorDict( - {key: value == other for key, value in self.items()}, - self.batch_size, + + def __eq__( + tensordict, other, current_key: Tuple = None, being_computed: Dict = None + ): + """A version of __eq__ that supports auto-nesting.""" + if not isinstance(other, (TensorDictBase, dict, float, int)): + return False + if not isinstance(other, TensorDictBase) and isinstance(other, dict): + other = make_tensordict(**other, batch_size=tensordict.batch_size) + if not isinstance(other, TensorDictBase): + return TensorDict( + {key: value == other for key, value in tensordict.items()}, + tensordict.batch_size, + device=tensordict.device, + ) + keys1 = set(tensordict.keys()) + keys2 = set(other.keys()) + if len(keys1.difference(keys2)) or len(keys1) != len(keys2): + raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") + out_dict = {} + + out_dict = {} + if current_key is None: + current_key = () + if being_computed is None: + being_computed = {} + being_computed[current_key] = id(tensordict) + for key, value in tensordict.items(): + if isinstance(value, TensorDictBase): + nested_key = current_key + (key,) + if id(value) in being_computed.values(): + being_computed[nested_key] = id(value) + continue + new_value = __eq__( + value, + other[key], + current_key=nested_key, + being_computed=being_computed, + ) + else: + new_value = value == other[key] + out_dict[key] = new_value + + out = TensorDict( + out_dict, device=self.device, + batch_size=self.batch_size, ) - keys1 = set(self.keys()) - keys2 = set(other.keys()) - if len(keys1.difference(keys2)) or len(keys1) != len(keys2): - raise KeyError(f"keys in 
tensordicts mismatch, got {keys1} and {keys2}") - d = {} - for (key, item1) in self.items(): - d[key] = item1 == other.get(key) - return TensorDict(batch_size=self.batch_size, source=d, device=self.device) + for other_nested_key, other_value in being_computed.items(): + if other_nested_key != current_key: + if other_value == id(tensordict): + out[other_nested_key] = out + return out + + return __eq__(tensordict=self, other=other) @abc.abstractmethod def del_(self, key: str) -> TensorDictBase: @@ -1186,16 +1223,42 @@ def to_tensordict(self): a new TensorDict object containing the same values. """ - return TensorDict( - { - key: value.clone() - if not isinstance(value, TensorDictBase) - else value.to_tensordict() - for key, value in self.items() - }, - device=self.device, - batch_size=self.batch_size, - ) + + def to_tensordict( + tensordict, current_key: Tuple = None, being_computed: Dict = None + ): + """A version of to_tensordict that supports auto-nesting.""" + out_dict = {} + if current_key is None: + current_key = () + if being_computed is None: + being_computed = {} + being_computed[current_key] = id(tensordict) + for key, value in tensordict.items(): + if isinstance(value, TensorDictBase): + nested_key = current_key + (key,) + if id(value) in being_computed.values(): + being_computed[nested_key] = id(value) + continue + new_value = to_tensordict( + value, current_key=nested_key, being_computed=being_computed + ) + else: + new_value = value.clone() + out_dict[key] = new_value + out = TensorDict( + out_dict, + device=self.device, + batch_size=self.batch_size, + _run_checks=False, + ) + for other_nested_key, other_value in being_computed.items(): + if other_nested_key != current_key: + if other_value == id(tensordict): + out[other_nested_key] = out + return out + + return to_tensordict(self) def zero_(self) -> TensorDictBase: """Zeros all tensors in the tensordict in-place.""" @@ -1267,14 +1330,44 @@ def clone(self, recurse: bool = True) -> TensorDictBase: """ - return TensorDict( - source={key: _clone_value(value, recurse) for key, value in self.items()}, - batch_size=self.batch_size, - device=self.device, - _run_checks=False, - _is_shared=self.is_shared() if not recurse else False, - _is_memmap=self.is_memmap() if not recurse else False, - ) + def clone(tensordict, current_key: Tuple = None, being_computed: Dict = None): + """A version of to_tensordict that supports auto-nesting.""" + out_dict = {} + if current_key is None: + current_key = () + if being_computed is None: + being_computed = {} + being_computed[current_key] = id(tensordict) + for key, value in tensordict.items(): + if isinstance(value, TensorDictBase): + nested_key = current_key + (key,) + if id(value) in being_computed.values(): + being_computed[nested_key] = id(value) + continue + new_value = clone( + value, current_key=nested_key, being_computed=being_computed + ) + else: + if not recurse: + new_value = value + else: + new_value = value.clone() + out_dict[key] = new_value + out = TensorDict( + out_dict, + device=tensordict.device, + batch_size=tensordict.batch_size, + _run_checks=False, + _is_shared=tensordict.is_shared() if not recurse else False, + _is_memmap=tensordict.is_memmap() if not recurse else False, + ) + for other_nested_key, other_value in being_computed.items(): + if other_nested_key != current_key: + if other_value == id(tensordict): + out[other_nested_key] = out + return out + + return clone(self) @classmethod def __torch_function__( @@ -1752,15 +1845,47 @@ def all(self, dim: int = None) -> Union[bool, 
TensorDictBase]: "dim must be greater than -tensordict.batch_dims and smaller " "than tensordict.batchdims" ) - if dim is not None: - if dim < 0: - dim = self.batch_dims + dim - return TensorDict( - source={key: value.all(dim=dim) for key, value in self.items()}, - batch_size=[b for i, b in enumerate(self.batch_size) if i != dim], - device=self.device, - ) - return all(value.all() for value in self.values()) + elif dim is not None and dim < 0: + dim = self.batch_dims + dim + + def _all(tensordict, current_key: Tuple = None, being_computed: Dict = None): + """A version of to_tensordict that supports auto-nesting.""" + out_dict = {} + if current_key is None: + current_key = () + if being_computed is None: + being_computed = {} + being_computed[current_key] = id(tensordict) + for key, value in tensordict.items(): + if isinstance(value, TensorDictBase): + nested_key = current_key + (key,) + if id(value) in being_computed.values(): + being_computed[nested_key] = id(value) + continue + new_value = _all( + value, current_key=nested_key, being_computed=being_computed + ) + else: + new_value = value.all(dim=dim) if dim is not None else value.all() + out_dict[key] = new_value + if dim is None: + # no need to do anything with being_computed, as a True in it will be kept + out = all(value for value in out_dict.values()) + else: + out = TensorDict( + out_dict, + batch_size=[ + b for i, b in enumerate(tensordict.batch_size) if i != dim + ], + device=tensordict.device, + ) + for other_nested_key, other_value in being_computed.items(): + if other_nested_key != current_key: + if other_value == id(tensordict): + out[other_nested_key] = out + return out + + return _all(self) def any(self, dim: int = None) -> Union[bool, TensorDictBase]: """Checks if any value is True/non-null in the tensordict.
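The recursion scheme that patch 6 repeats across __eq__, to_tensordict, clone and all — record each container's id() in being_computed on the way down, skip values whose id is already present, then re-attach the skipped keys with out[other_nested_key] = out on the way back up — is closely related to the memo dict used by copy.deepcopy. The following is a minimal sketch of that memo variant on plain dicts; the names are illustrative, not tensordict code:

    # Minimal sketch, assuming plain dicts; compare with `being_computed` above.
    def clone_nested(d, _memo=None):
        """Deep-copy a dict of dicts that may contain itself.

        `_memo` maps id(original) -> copy. A container is registered *before*
        its children are copied, so a child that points back at an ancestor
        receives the (still under construction) copy instead of triggering
        infinite recursion -- the same trick copy.deepcopy uses.
        """
        if _memo is None:
            _memo = {}
        if id(d) in _memo:
            return _memo[id(d)]
        out = {}
        _memo[id(d)] = out
        for key, value in d.items():
            out[key] = clone_nested(value, _memo) if isinstance(value, dict) else value
        return out

    d = {"a": 1}
    d["self"] = d
    c = clone_nested(d)
    assert c["self"] is c and c is not d

The main difference is when the back-reference is resolved: the memo hands out the under-construction copy eagerly, while the patch defers the fix-up until the output TensorDict has been built and then rewrites the recorded key-paths to point at it.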