diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 570799f4d..c9a1dcae8 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -50,7 +50,6 @@ _get_shape_from_args, _getitem_batch_size, _is_number, - _parse_to, _renamed_inplace_method, _shape, _td_fields, @@ -1332,7 +1331,7 @@ def _to( result = self if device is not None and dtype is None and device == self.device: - return result + return result, False if non_blocking in (None, True): sub_non_blocking = True @@ -3008,7 +3007,7 @@ def _to( result = self if device is not None and dtype is None and device == self.device: - return result + return result, False td, must_sync = self._source._to( device=device, diff --git a/tensordict/_td.py b/tensordict/_td.py index 17fc2d9ae..6fa08c981 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -61,7 +61,6 @@ _LOCK_ERROR, _NON_STR_KEY_ERR, _NON_STR_KEY_TUPLE_ERR, - _parse_to, _prune_selected_keys, _set_item, _set_max_batch_size, @@ -2513,7 +2512,7 @@ def _to( result = self if device is not None and dtype is None and device == self.device: - return result + return result, False must_sync = False # this may seem awkward but non_blocking=None acts as True and this makes only one check diff --git a/tensordict/base.py b/tensordict/base.py index 6c9772b69..9d3b10741 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -49,7 +49,6 @@ _is_tensorclass, _KEY_ERROR, _parse_to, - _prefix_last_key, _proc_init, _prune_selected_keys, _set_max_batch_size, diff --git a/tensordict/persistent.py b/tensordict/persistent.py index b37f11771..070332254 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -970,7 +970,7 @@ def _to( ) -> Tuple[T, bool]: result = self if device is not None and dtype is None and device == self.device: - return result + return result, False if dtype is not None: return self.to_tensordict()._to( device=device, diff --git a/tensordict/utils.py b/tensordict/utils.py index fb31aeb11..074454309 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2284,7 +2284,6 @@ def _prefix_last_key(key, prefix): return key[:-1] + (_prefix_last_key(key[-1], prefix),) - NESTED_TENSOR_ERR = ( "The PyTorch version isn't compatible with " "nested tensors. Please upgrade to a more recent "