From 5c34868816806f465872bac2c742a7676c60559b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 3 Jul 2024 15:53:41 +0100 Subject: [PATCH] [Refactor] Make all leaves in tensorclass part of `_tensordict`, except for NonTensorData (#841) --- tensordict/_td.py | 54 ++++++----- tensordict/tensorclass.py | 196 +++++++++++++++++++++----------------- tensordict/utils.py | 3 +- test/test_tensorclass.py | 68 ++++++++----- test/test_tensordict.py | 18 ++-- 5 files changed, 199 insertions(+), 140 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index d1f5aacff..e1550198a 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -794,15 +794,16 @@ def __setitem__( ) from err keys = set(self.keys()) - if any(key not in keys for key in value.keys()): - subtd = self._get_sub_tensordict(index) - for key, item in value.items(): - if key in keys: + subtd = None + for value_key, item in value.items(): + if value_key in keys: self._set_at_str( - key, item, index, validated=False, non_blocking=False + value_key, item, index, validated=False, non_blocking=False ) else: - subtd.set(key, item, inplace=True, non_blocking=False) + if subtd is None: + subtd = self._get_sub_tensordict(index) + subtd.set(value_key, item, inplace=True, non_blocking=False) else: for key in self.keys(): self.set_at_(key, value, index) @@ -2083,12 +2084,35 @@ def _set_tuple( ) return self + _SHARED_INPLACE_ERROR = ( + "You're attempting to update a leaf in-place with a shared " + "tensordict, but the new value does not match the previous. " + "If you're using NonTensorData, see the class documentation " + "to see how to properly pre-allocate memory in shared contexts." + ) + def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool): if not validated: value = self._validate_value(value, check_shape=False) validated = True tensor_in = self._get_str(key, NO_DEFAULT) + if is_non_tensor(value) and not (self._is_shared or self._is_memmap): + dest = self._get_str(key, NO_DEFAULT) + is_diff = dest[idx].tolist() != value.tolist() + if is_diff: + dest_val = dest.maybe_to_stack() + dest_val[idx] = value + if dest_val is not dest: + self._set_str( + key, + dest_val, + validated=True, + inplace=False, + ignore_lock=True, + ) + return + if isinstance(idx, tuple) and len(idx) and isinstance(idx[0], tuple): warn( "Multiple indexing can lead to unexpected behaviours when " @@ -2103,12 +2127,7 @@ def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool): ) if tensor_in is not tensor_out: if self._is_shared or self._is_memmap: - raise RuntimeError( - "You're attempting to update a leaf in-place with a shared " - "tensordict, but the new value does not match the previous. " - "If you're using NonTensorData, see the class documentation " - "to see how to properly pre-allocate memory in shared contexts." - ) + raise RuntimeError(self._SHARED_INPLACE_ERROR) # this happens only when a NonTensorData becomes a NonTensorStack # so it is legitimate (there is no in-place modification of a tensor # that was expected to happen but didn't). @@ -2301,16 +2320,7 @@ def _memmap_( raise RuntimeError( "memmap and shared memory are mutually exclusive features." ) - dest = ( - self - if inplace - else TensorDict( - {}, - batch_size=self.batch_size, - names=self.names if self._has_names() else None, - device=torch.device("cpu"), - ) - ) + dest = self if inplace else self.empty(device=torch.device("cpu")) # We must set these attributes before memmapping because we need the metadata # to match the tensordict content. 
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index a303b8232..a656b201e 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -97,14 +97,20 @@ def __subclasscheck__(self, subclass): "_get_str", "_get_sub_tensordict", "_get_tuple", + "_has_names", + "_multithread_rebuild", # rebuild checks if self is a non tensor "gather", "is_memmap", "is_shared", + "load_", + "memmap", + "memmap_", + "memmap_like", + "memmap_refresh_", "ndimension", "numel", "replace", - "_has_names", - "_multithread_rebuild", # rebuild checks if self is a non tensor + "save", ] # Methods to be executed from tensordict, any ref to self means 'self._tensordict' _FALLBACK_METHOD_FROM_TD = [ @@ -186,15 +192,15 @@ def __subclasscheck__(self, subclass): "expand_as", "expm1", "expm1_", + "filter_non_tensor_data", "flatten", "floor", "floor_", "frac", "frac_", - "is_empty", - "is_memmap", - "is_shared", - "is_shared", + "is_empty", # no wrap output + "is_memmap", # no wrap output + "is_shared", # no wrap output "isfinite", "isnan", "isreal", @@ -222,7 +228,7 @@ def __subclasscheck__(self, subclass): "mul", "mul_", "named_apply", - "ndimension", + "ndimension", # no wrap output "neg", "neg_", "norm", @@ -260,8 +266,9 @@ def __subclasscheck__(self, subclass): "unflatten", "unlock_", "unsqueeze", - "values", + "values", # no wrap output "view", + "_get_names_idx", # no wrap output "where", "zero_", ] @@ -470,20 +477,14 @@ def __torch_function__( _wrap_td_method(method_name, copy_non_tensor=True), ) - if not hasattr(cls, "filter_non_tensor_data"): - cls.filter_non_tensor_data = _filter_non_tensor_data cls.__enter__ = __enter__ cls.__exit__ = __exit__ # Memmap - if not hasattr(cls, "memmap_like"): - cls.memmap_like = TensorDictBase.memmap_like - if not hasattr(cls, "memmap_"): - cls.memmap_ = TensorDictBase.memmap_ - if not hasattr(cls, "memmap"): - cls.memmap = TensorDictBase.memmap if not hasattr(cls, "load_memmap"): cls.load_memmap = TensorDictBase.load_memmap + if not hasattr(cls, "load"): + cls.load = TensorDictBase.load if not hasattr(cls, "_load_memmap"): cls._load_memmap = classmethod(_load_memmap) if not hasattr(cls, "from_dict"): @@ -524,10 +525,6 @@ def __torch_function__( return cls -def _filter_non_tensor_data(self): - return super(type(self), self).__getattribute__("_tensordict") - - def _arg_to_tensordict(arg): # if arg is a tensorclass or sequence of tensorclasses, extract the underlying # tensordicts and return those instead @@ -821,6 +818,7 @@ def save_metadata(cls=cls, _non_tensordict=_non_tensordict, prefix=prefix): td = self._tensordict.empty() td._is_memmap = True td._is_locked = True + td._memmap_prefix = prefix if inplace: self.__dict__["_tensordict"] = td if not inplace: @@ -898,13 +896,17 @@ def _getattr(self, item: str) -> Any: if _tensordict is not None: out = _tensordict._get_str(item, default=None) if out is not None: + if is_non_tensor(out): + return out.data if hasattr(out, "data") else out.tolist() return out out = getattr(_tensordict, item, NO_DEFAULT) if out is not NO_DEFAULT: if not callable(out): + if is_non_tensor(out): + return out.data if hasattr(out, "data") else out.tolist() return out return _wrap_method(self, item, out) - raise AttributeError + raise AttributeError(item) SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names") @@ -1181,15 +1183,6 @@ def _setitem(self, item: NestedKey, value: Any) -> None: # noqa: D417 del self._non_tensordict[key] self._tensordict[item] = value._tensordict - elif isinstance(value, TensorDictBase): # 
it is one of accepted "broadcast" types - # attempt broadcast on all tensordata and nested tensorclasses - self._tensordict[item] = value.filter_non_tensor_data() - self._non_tensordict.update( - { - key: val.data - for key, val in value.items(is_leaf=is_non_tensor, leaves_only=True) - } - ) else: # int, float etc. self._tensordict[item] = value @@ -1224,7 +1217,8 @@ def _len(self) -> int: def _to_dict(self) -> dict: td_dict = self._tensordict.to_dict() - td_dict.update(self._non_tensordict) + if self._non_tensordict: + td_dict.update(self._non_tensordict) return td_dict @@ -1238,12 +1232,13 @@ def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None): ) non_tensor = {} - # Note: this won't deal with sub-tensordicts which may or may not be tensorclasses. - # We don't want to enforce them to be tensorclasses so we can't do much about it... - for key, value in list(td.items()): - if is_non_tensor(value): - non_tensor[key] = value.data - del td[key] + if issubclass(cls, NonTensorData): + # Note: this won't deal with sub-tensordicts which may or may not be tensorclasses. + # We don't want to enforce them to be tensorclasses so we can't do much about it... + for key, value in list(td.items()): + if is_non_tensor(value): + non_tensor[key] = value.data + del td[key] return cls.from_tensordict(tensordict=td, non_tensordict=non_tensor) @@ -1303,6 +1298,8 @@ def _to_tensordict(self) -> TensorDict: """ td = self._tensordict.to_tensordict() for key, val in self._non_tensordict.items(): + # if val is None: + # continue td.set_non_tensor(key, val) return td @@ -1354,8 +1351,14 @@ def _set( ) def set_tensor( - key=key, value=value, inplace=inplace, non_blocking=non_blocking + key=key, + value=value, + inplace=inplace, + non_blocking=non_blocking, + non_tensor=False, ): + if non_tensor: + value = NonTensorData(value) # Avoiding key clash, honoring the user input to assign tensor type data to the key if key in self._non_tensordict.keys(): if inplace: @@ -1403,22 +1406,29 @@ def _is_castable(datatype): ) return set_tensor(value=cast_val) elif value is not None and target_cls is not _AnyType: - value = _cast_funcs[target_cls](value) + cast_val = _cast_funcs[target_cls](value) + return set_tensor(value=cast_val, non_tensor=True) elif target_cls is _AnyType and _is_castable(type(value)): return set_tensor() - else: - if isinstance(value, tuple(tensordict_lib.base._ACCEPTED_CLASSES)): - return set_tensor() + elif isinstance(value, tuple(tensordict_lib.base._ACCEPTED_CLASSES)): + return set_tensor() - # Avoiding key clash, honoring the user input to assign non-tensor data to the key - if key in self._tensordict.keys(): - if inplace: - raise RuntimeError( - f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {type(value)}." - ) - self._tensordict.del_(key) - # Saving all non-tensor attributes - self._non_tensordict[key] = value + if self._is_non_tensor or value is None: + # Avoiding key clash, honoring the user input to assign non-tensor data to the key + if key in self._tensordict.keys(): + if inplace: + raise RuntimeError( + f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {type(value)}." + ) + self._tensordict.del_(key) + self._non_tensordict[key] = value + else: + if key in self._tensordict.keys(): + if inplace: + raise RuntimeError( + f"Cannot update an existing entry of type {type(self._tensordict.get(key))} with a value of type {type(value)}." 
+ ) + set_tensor(value=value, non_tensor=True) return self if isinstance(key, tuple) and len(key): @@ -1692,12 +1702,12 @@ def _eq(self, other: object) -> bool: return False if is_tensorclass(other): tensor = self._tensordict == other._tensordict - elif _is_tensor_collection(type(other)): - # other can be a tensordict reconstruction of self, in which case we discard - # the non-tensor data - tensor = self.filter_non_tensor_data() == other.filter_non_tensor_data() else: - tensor = self._tensordict == other + tensor = self._tensordict == ( + other.exclude(*self._non_tensordict.keys()) + if _is_tensor_collection(type(other)) + else other + ) return _from_tensordict_with_none(self, tensor) @@ -1753,12 +1763,12 @@ def _ne(self, other: object) -> bool: return True if is_tensorclass(other): tensor = self._tensordict != other._tensordict - elif _is_tensor_collection(type(other)): - # other can be a tensordict reconstruction of self, in which case we discard - # the non-tensor data - tensor = self._tensordict != other.exclude(*self._non_tensordict.keys()) else: - tensor = self._tensordict != other + tensor = self._tensordict != ( + other.exclude(*self._non_tensordict.keys()) + if _is_tensor_collection(type(other)) + else other + ) return _from_tensordict_with_none(self, tensor) @@ -1781,12 +1791,12 @@ def _or(self, other: object) -> bool: return False if is_tensorclass(other): tensor = self._tensordict | other._tensordict - elif _is_tensor_collection(type(other)): - # other can be a tensordict reconstruction of self, in which case we discard - # the non-tensor data - tensor = self._tensordict | other.exclude(*self._non_tensordict.keys()) else: - tensor = self._tensordict | other + tensor = self._tensordict | ( + other.exclude(*self._non_tensordict.keys()) + if _is_tensor_collection(type(other)) + else other + ) return _from_tensordict_with_none(self, tensor) @@ -1809,12 +1819,12 @@ def _xor(self, other: object) -> bool: return False if is_tensorclass(other): tensor = self._tensordict ^ other._tensordict - elif _is_tensor_collection(type(other)): - # other can be a tensordict reconstruction of self, in which case we discard - # the non-tensor data - tensor = self._tensordict ^ other.exclude(*self._non_tensordict.keys()) else: - tensor = self._tensordict ^ other + tensor = self._tensordict ^ ( + other.exclude(*self._non_tensordict.keys()) + if _is_tensor_collection(type(other)) + else other + ) return _from_tensordict_with_none(self, tensor) @@ -2150,13 +2160,18 @@ class NonTensorData: _is_non_tensor: bool = True def __post_init__(self): - if is_non_tensor(self.data): - data = getattr(self.data, "data", None) - if data is None: - data = self.data.tolist() - del self._tensordict["data"] - self._non_tensordict["data"] = data - assert self._tensordict.is_empty(), self._tensordict + _tensordict = self.__dict__["_tensordict"] + _non_tensordict = self.__dict__["_non_tensordict"] + data = _non_tensordict.get("data", NO_DEFAULT) + if data is NO_DEFAULT: + data = _tensordict._get_str("data", default=NO_DEFAULT) + data_inner = getattr(data, "data", None) + if data_inner is None: + # Support for stacks + data_inner = data.tolist() + del _tensordict["data"] + _non_tensordict["data"] = data_inner + assert _tensordict.is_empty(), self._tensordict def __repr__(self): data_str = str(self.data) @@ -2309,7 +2324,7 @@ def maybe_to_stack(self): return self for i in reversed(self.batch_size): datalist = [datalist] * i - return NonTensorStack._from_list(datalist, device=self.device) + return 
NonTensorStack._from_list(datalist, device=self.device, ndim=self.ndim) def update_( self, @@ -2522,7 +2537,11 @@ def _memmap_( memmaped: bool = False, share_non_tensor: bool = False, ): - if self._tensordict._is_memmap: + # For efficiency, we can avoid doing this saving + # if the data is already there. + if self._tensordict._is_memmap and str( + getattr(self._tensordict, "_memmap_prefix", None) + ) == str(prefix): return self _metadata = {} @@ -2694,6 +2713,7 @@ def save_metadata(prefix=prefix, self=self): # The only thing remaining to do is share the data between processes results = [] for i, td in enumerate(self.tensordicts): + td: NonTensorData results.append( td._memmap_( prefix=(prefix / str(i)) if prefix is not None else None, @@ -2732,12 +2752,16 @@ def _load_memmap( return super()._load_memmap(prefix=prefix, metadata=metadata, **kwargs) @classmethod - def _from_list(cls, datalist: List, device: torch.device): - if all(isinstance(item, list) for item in datalist) and all( - len(item) == len(datalist[0]) for item in datalist + def _from_list(cls, datalist: List, device: torch.device, ndim: int | None = None): + if ( + all(isinstance(item, list) for item in datalist) + and all(len(item) == len(datalist[0]) for item in datalist) + and (ndim is None or ndim > 1) ): + ndim = ndim - 1 if ndim is not None else None return NonTensorStack( - *(cls._from_list(item, device=device) for item in datalist), stack_dim=0 + *(cls._from_list(item, device=device, ndim=ndim) for item in datalist), + stack_dim=0, ) return NonTensorStack( *( @@ -2798,7 +2822,9 @@ def _update( datalist = input_dict_or_td.data for d in reversed(self.batch_size): datalist = [datalist] * d - reconstructed = self._from_list(datalist, device=self.device) + reconstructed = self._from_list( + datalist, device=self.device, ndim=self.ndim + ) return self.update( reconstructed, clone=clone, diff --git a/tensordict/utils.py b/tensordict/utils.py index 554dbecd3..8b7a4f2b0 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1485,7 +1485,8 @@ def assert_allclose_td( continue assert_allclose_td(input1, input2, rtol=rtol, atol=atol) continue - + elif not isinstance(input1, torch.Tensor): + continue mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum() mse = mse.div(input1.numel()).sqrt().item() diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index b010b2970..a230a1256 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -19,6 +19,7 @@ import numpy as np import pytest +import tensordict.utils import torch try: @@ -167,6 +168,7 @@ def test_attributes(self): { "X": X, "y": y, + "z": z, }, batch_size=[3, 4], ) @@ -462,6 +464,7 @@ class MyClass2: ), batch_size=[3], ) + assert not a._non_tensordict b = MyClass2( torch.zeros(3), "z0", @@ -473,7 +476,11 @@ class MyClass2: ), batch_size=[3], ) - c = TensorDict({"x": torch.zeros(3), "y": {"x": torch.ones(3)}}, batch_size=[3]) + assert not b._non_tensordict + c = TensorDict( + {"x": torch.zeros(3), "z": "z0", "y": {"x": torch.ones(3), "z": "z1"}}, + batch_size=[3], + ) assert (a == a.clone()).all() assert (a != 1.0).any() @@ -810,9 +817,9 @@ class MyDataNested: data2 = MyDataNested(X=X, y=data_nest2, z=z, batch_size=batch_size) assert (data == data2).all() assert (data == data2).X.all() - assert (data == data2).z is None + assert (data == data2).z.all() assert (data == data2).y.X.all() - assert (data == data2).y.z is None + assert (data == data2).y.z.all() @pytest.mark.parametrize("any_to_td", [True, False]) def 
test_nested_heterogeneous(self, any_to_td): @@ -874,9 +881,9 @@ class MyDataNested: data2 = MyDataNested(X=X + 1, y=data_nest2, z=z, batch_size=batch_size) assert (data != data2).any() assert (data != data2).X.all() - assert (data != data2).z is None + assert (data != data2).z.all() assert not (data != data2).y.X.any() - assert (data != data2).y.z is None + assert not (data != data2).y.z.any() def test_permute( self, @@ -1065,10 +1072,9 @@ class MyDataParent: data.set("v", "test") assert data.v == "test" - assert "v" not in data._tensordict.keys() - assert "v" in data._non_tensordict.keys() + assert "v" in data._tensordict.keys() - with pytest.raises(RuntimeError, match="Cannot update an existing"): + with pytest.raises(ValueError, match="Failed to update 'v'"): data.set("v", vorig, inplace=True) # ensure optional fields are writable @@ -1146,8 +1152,8 @@ class MyDataParent: data.v = "test" assert data.v == "test" - assert "v" not in data._tensordict.keys() - assert "v" in data._non_tensordict.keys() + assert "v" in data._tensordict.keys() + assert "v" not in data._non_tensordict.keys() # ensure optional fields are writable data.k = torch.zeros(3, 4, 5) @@ -1199,11 +1205,9 @@ def test_setitem( # Negative testcase for non-tensor data z = "test_bluff" data2 = MyData(X=x, y=y, z=z, batch_size=batch_size) - with pytest.warns( - UserWarning, - match="Meta data at 'z' may or may not be equal, this may result in undefined behaviours", - ): - data[1] = data2[1] + data[1] = data2[1] + assert (data[1] == data2[1]).all() + assert data[1].z == ["test_bluff"] * data[1].numel() # Validating nested test cases @tensorclass @@ -1227,11 +1231,8 @@ class MyDataNested: # Negative Scenario data3 = MyDataNested(X=X2, y=data_nest2, z=["e", "f"], batch_size=batch_size) - with pytest.warns( - UserWarning, - match="Meta data at 'z' may or may not be equal, this may result in undefined behaviours", - ): - data[:2] = data3[:2] + data[:2] = data3[:2] + assert data[:2].z == data3[:2]._get_str("z", None).tolist() @pytest.mark.parametrize( "broadcast_type", @@ -1513,7 +1514,12 @@ class MyDataNested: if lazy_legacy() or lazy: assert isinstance(stacked_tc._tensordict, LazyStackedTensorDict) assert isinstance(stacked_tc.y._tensordict, LazyStackedTensorDict) - assert stacked_tc.z == stacked_tc.y.z == z + zlist = z + if lazy: + for d in range(stacked_tc.ndim - 1, -1, -1): + zlist = [zlist] * stacked_tc.batch_size[d] + assert stacked_tc.z == stacked_tc.y.z + assert stacked_tc.z == zlist # Testing negative scenarios y = torch.zeros(3, 4, 5, dtype=torch.bool) @@ -2039,7 +2045,7 @@ def test_autocast_simple(self): ) assert isinstance(obj.tensor, torch.Tensor) assert isinstance(obj.integer, int) - assert isinstance(obj.string, str) + assert isinstance(obj.string, str), type(obj.string) assert isinstance(obj.floating, float) assert isinstance(obj.numpy_array, np.ndarray) assert isinstance(obj.anything, torch.Tensor) @@ -2256,11 +2262,27 @@ def assert_is_data(x): data_bis = torch.vmap(assert_is_data)(data) assert isinstance(data_bis, VmappableClass) - assert "non_tensor" not in data_bis._tensordict + assert "non_tensor" in data_bis._tensordict + assert "non_tensor" not in data_bis._non_tensordict assert (data_bis == data).all() assert data_bis.non_tensor == "a string!" 
+class TestSerialization: + def test_save_load(self, tmpdir): + myc = MyData( + X=torch.rand(2, 3, 4), + y=torch.rand(2, 3, 4, 5), + z="test_tensorclass", + batch_size=[2, 3], + ) + myc.save(tmpdir) + tensordict.utils.print_directory_tree(tmpdir) + myc_load = TensorDict.load(tmpdir) + assert myc_load.z == "test_tensorclass" + assert (myc == myc_load).all() + + class TestPointWise: def test_pointwise(self): @tensorclass diff --git a/test/test_tensordict.py b/test/test_tensordict.py index ad3e0cac9..303583d6c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -39,7 +39,7 @@ from tensordict._lazy import _CustomOpTensorDict from tensordict._td import _SubTensorDict, is_tensor_collection from tensordict._torch_func import _stack as stack_td -from tensordict.base import TensorDictBase +from tensordict.base import _is_leaf_nontensor, TensorDictBase from tensordict.functional import dense_stack_tds, pad, pad_sequence from tensordict.memmap import MemoryMappedTensor @@ -2745,10 +2745,10 @@ def get_old_val(newval, oldval): def test_apply_filter(self, td_name, device, inplace): td = getattr(self, td_name)(device) assert td.apply(lambda x: None, filter_empty=False) is not None - if td_name != "td_with_non_tensor": - assert td.apply(lambda x: None, filter_empty=True) is None - else: - assert td.apply(lambda x: None, filter_empty=True) is not None + assert ( + td.apply(lambda x: None, filter_empty=True, is_leaf=_is_leaf_nontensor) + is None + ) @pytest.mark.parametrize("inplace", [False, True]) @pytest.mark.parametrize( @@ -6128,7 +6128,7 @@ def memmap_td(self, device, dtype): def nested_td(self, device, dtype): if device is not None: device_not_none = device - elif torch.has_cuda and torch.cuda.device_count(): + elif torch.cuda.is_available() and torch.cuda.device_count(): device_not_none = torch.device("cuda:0") else: device_not_none = torch.device("cpu") @@ -6152,7 +6152,7 @@ class MyClass: if device is not None: device_not_none = device - elif torch.has_cuda and torch.cuda.device_count(): + elif torch.cuda.is_available() and torch.cuda.device_count(): device_not_none = torch.device("cuda:0") else: device_not_none = torch.device("cpu") @@ -6182,7 +6182,7 @@ def share_memory_td(self, device, dtype): def stacked_td(self, device, dtype): if device is not None: device_not_none = device - elif torch.has_cuda and torch.cuda.device_count(): + elif torch.cuda.is_available() and torch.cuda.device_count(): device_not_none = torch.device("cuda:0") else: device_not_none = torch.device("cpu") @@ -6208,7 +6208,7 @@ def stacked_td(self, device, dtype): def td(self, device, dtype): if device is not None: device_not_none = device - elif torch.has_cuda and torch.cuda.device_count(): + elif torch.cuda.is_available() and torch.cuda.device_count(): device_not_none = torch.device("cuda:0") else: device_not_none = torch.device("cpu")
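
A minimal usage sketch of the behaviour exercised by the new assertions in test_tensorclass.py (the class and field names below are illustrative, not part of the patch): after this refactor, non-tensor leaves of a regular @tensorclass are expected to be stored in `_tensordict` as `NonTensorData`, while plain attribute access still returns the unwrapped value; only `NonTensorData` itself keeps its payload in `_non_tensordict`.

    import torch
    from tensordict import tensorclass

    @tensorclass
    class MyData:
        X: torch.Tensor
        z: str

    data = MyData(X=torch.zeros(3, 4), z="hello", batch_size=[3])

    # The string leaf now lives in the underlying tensordict (wrapped as NonTensorData) ...
    assert "z" in data._tensordict.keys()
    # ... and is no longer kept in the side dict of non-tensor attributes.
    assert "z" not in data._non_tensordict
    # Attribute access still unwraps the value transparently.
    assert data.z == "hello"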