diff --git a/tensordict/utils.py b/tensordict/utils.py index 6a3184d83..5813e5b7f 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1373,8 +1373,6 @@ def _parse_to(*args, **kwargs): *args, **kwargs ) else: - device = None - dtype = None non_blocking = kwargs.get("non_blocking", False) convert_to_format = kwargs.get("convert_to_format") if len(args) > 0: @@ -1390,7 +1388,7 @@ def _parse_to(*args, **kwargs): device = torch.device(device) if device and device.type == "cuda" and device.index is None: - device = torch.device("cuda:0") + device = torch.device(f"cuda:{torch.cuda.current_device()}") if other is not None: if device is not None and device != other.device: