Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 18, 2024
1 parent cda0ec4 commit bcbf05e
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 8 deletions.
5 changes: 2 additions & 3 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
_get_shape_from_args,
_getitem_batch_size,
_is_number,
_parse_to,
_renamed_inplace_method,
_shape,
_td_fields,
Expand Down Expand Up @@ -1332,7 +1331,7 @@ def _to(
result = self

if device is not None and dtype is None and device == self.device:
return result
return result, False

if non_blocking in (None, True):
sub_non_blocking = True
Expand Down Expand Up @@ -3008,7 +3007,7 @@ def _to(
result = self

if device is not None and dtype is None and device == self.device:
return result
return result, False

td, must_sync = self._source._to(
device=device,
Expand Down
3 changes: 1 addition & 2 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
_LOCK_ERROR,
_NON_STR_KEY_ERR,
_NON_STR_KEY_TUPLE_ERR,
_parse_to,
_prune_selected_keys,
_set_item,
_set_max_batch_size,
Expand Down Expand Up @@ -2513,7 +2512,7 @@ def _to(
result = self

if device is not None and dtype is None and device == self.device:
return result
return result, False

must_sync = False
# this may seem awkward but non_blocking=None acts as True and this makes only one check
Expand Down
1 change: 0 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
_is_tensorclass,
_KEY_ERROR,
_parse_to,
_prefix_last_key,
_proc_init,
_prune_selected_keys,
_set_max_batch_size,
Expand Down
2 changes: 1 addition & 1 deletion tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ def _to(
) -> Tuple[T, bool]:
result = self
if device is not None and dtype is None and device == self.device:
return result
return result, False
if dtype is not None:
return self.to_tensordict()._to(
device=device,
Expand Down
1 change: 0 additions & 1 deletion tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2284,7 +2284,6 @@ def _prefix_last_key(key, prefix):
return key[:-1] + (_prefix_last_key(key[-1], prefix),)



NESTED_TENSOR_ERR = (
"The PyTorch version isn't compatible with "
"nested tensors. Please upgrade to a more recent "
Expand Down

0 comments on commit bcbf05e

Please sign in to comment.