From cd6fa15813d98655696b9467374775e04bced8df Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 10 Jan 2025 12:56:09 +0000
Subject: [PATCH] [Refactor] Fix property handling in TC

ghstack-source-id: ad564f84401bb619d0a6163c352d3de43ea24437
Pull Request resolved: https://github.com/pytorch/tensordict/pull/1178
---
 tensordict/tensorclass.py | 110 +++++++++++++++-----------------------
 test/test_tensorclass.py  |   3 ++
 2 files changed, 47 insertions(+), 66 deletions(-)

diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 04ada3198..a9ce2ccab 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -152,6 +152,8 @@ def __subclasscheck__(self, subclass):
     "_propagate_unlock",
     "_reduce_get_metadata",
     "_values_list",
+    "batch_dims",
+    "batch_size",
     "bytes",
     "cat_tensors",
     "data_ptr",
@@ -168,6 +170,7 @@
     "is_cuda",
     "is_empty",
     "is_floating_point",
+    "is_locked",
     "is_memmap",
     "is_meta",
     "is_shared",
@@ -176,6 +179,8 @@
     "keys",
     "make_memmap",
     "make_memmap_from_tensor",
+    "names",
+    "ndim",
     "ndimension",
     "numel",
     "numpy",
@@ -183,13 +188,14 @@
     "pop",
     "recv",
     "reduce",
+    "requires_grad",
     "saved_path",
     "send",
+    "shape",
     "size",
     "sorted_keys",
     "to_struct_array",
     "values",
-    # "ndim",
 ]
 
 # Methods to be executed from tensordict, any ref to self means 'self._tensordict'
@@ -910,8 +916,10 @@ def __torch_function__(
     for method_name in _FALLBACK_METHOD_FROM_TD_FORCE:
         setattr(cls, method_name, _wrap_td_method(method_name))
     for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP:
-        if not hasattr(cls, method_name):
-            setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True))
+        if not hasattr(cls, method_name) and method_name not in expected_keys:
+            is_property = isinstance(getattr(TensorDictBase, method_name, None), property)
+            setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True, is_property=is_property))
+
     for method_name in _FALLBACK_METHOD_FROM_TD_COPY:
         if not hasattr(cls, method_name):
             setattr(
@@ -949,20 +957,6 @@ def __torch_function__(
         cls.to_tensordict = _to_tensordict
     if not hasattr(cls, "device") and "device" not in expected_keys:
         cls.device = property(_device, _device_setter)
-    if not hasattr(cls, "batch_size") and "batch_size" not in expected_keys:
-        cls.batch_size = property(_batch_size, _batch_size_setter)
-    if not hasattr(cls, "batch_dims") and "batch_dims" not in expected_keys:
-        cls.batch_dims = property(_batch_dims)
-    if not hasattr(cls, "requires_grad") and "requires_grad" not in expected_keys:
-        cls.requires_grad = property(_requires_grad)
-    if not hasattr(cls, "is_locked") and "is_locked" not in expected_keys:
-        cls.is_locked = property(_is_locked)
-    if not hasattr(cls, "ndim") and "ndim" not in expected_keys:
-        cls.ndim = property(_batch_dims)
-    if not hasattr(cls, "shape") and "shape" not in expected_keys:
-        cls.shape = property(_batch_size, _batch_size_setter)
-    if not hasattr(cls, "names") and "names" not in expected_keys:
-        cls.names = property(_names, _names_setter)
     if not _is_non_tensor and not hasattr(cls, "data") and "data" not in expected_keys:
         cls.data = property(_data, _data_setter)
     if not hasattr(cls, "grad") and "grad" not in expected_keys:
@@ -1541,7 +1535,7 @@ def wrapper(self, key: str, value: Any) -> None:  # noqa: D417
     return wrapper
 
 
-def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False):
+def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False, is_property=False):
     def deliver_result(self, result, kwargs):
         if result is None:
             return
@@ -1559,13 +1553,32 @@ def deliver_result(self, result, kwargs):
             return self._from_tensordict(result, non_tensordict, safe=False)
         return result
 
-    def wrapped_func(self, *args, **kwargs):
+    if not is_property:
+        def wrapped_func(self, *args, **kwargs):
+            if not is_compiling():
+                td = super(type(self), self).__getattribute__("_tensordict")
+            else:
+                td = self._tensordict
+
+            result = getattr(td, funcname)(*args, **kwargs)
+            if no_wrap:
+                return result
+
+            if result is td:
+                return self
+
+            if isinstance(result, tuple):
+                return tuple(deliver_result(self, r, kwargs) for r in result)
+            return deliver_result(self, result, kwargs)
+
+        return wrapped_func
+
+    def wrapped_func(self):
         if not is_compiling():
             td = super(type(self), self).__getattribute__("_tensordict")
         else:
             td = self._tensordict
-
-        result = getattr(td, funcname)(*args, **kwargs)
+        result = getattr(td, funcname)
         if no_wrap:
             return result
 
@@ -1574,10 +1587,17 @@ def wrapped_func(self, *args, **kwargs):
             return self
 
         if isinstance(result, tuple):
-            return tuple(deliver_result(self, r, kwargs) for r in result)
-        return deliver_result(self, result, kwargs)
+            return tuple(deliver_result(self, r, {}) for r in result)
+        return deliver_result(self, result, {})
 
-    return wrapped_func
+    def wrapped_func_setter(self, value):
+        if not is_compiling():
+            td = super(type(self), self).__getattribute__("_tensordict")
+        else:
+            td = self._tensordict
+        return setattr(td, funcname, value)
+
+    return property(wrapped_func, wrapped_func_setter)
 
 
 def _wrap_method(self, attr, func, nowarn=False):
@@ -2244,48 +2264,6 @@ def _get_at(self, key: NestedKey, idx, default: Any = NO_DEFAULT):
         return default
 
 
-def _batch_size(self) -> torch.Size:
-    """Retrieves the batch size for the tensor class.
-
-    Returns:
-        batch size (torch.Size)
-
-    """
-    return self._tensordict.batch_size
-
-
-def _batch_dims(self) -> torch.Size:
-    return self._tensordict.batch_dims
-
-
-def _requires_grad(self) -> torch.Size:
-    return self._tensordict.requires_grad
-
-
-def _is_locked(self) -> torch.Size:
-    return self._tensordict.is_locked
-
-
-def _batch_size_setter(self, new_size: torch.Size) -> None:  # noqa: D417
-    """Set the value of batch_size.
-
-    Args:
-        new_size (torch.Size): new batch size to be set
-
-    """
-    self._tensordict._batch_size_setter(new_size)
-
-
-def _names(self) -> torch.Size:
-    """Retrieves the dim names for the tensor class.
-
-    Returns:
-        names (list of str)
-
-    """
-    return self._tensordict.names
-
-
 def _data(self):
     # We allow data to be a field of the class too
     if "data" in self.__dataclass_fields__:
diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py
index 01a5aec37..413bd00b2 100644
--- a/test/test_tensorclass.py
+++ b/test/test_tensorclass.py
@@ -424,6 +424,9 @@ def test_batch_size(self):
         myc.batch_size = torch.Size([2])
 
         assert myc.batch_size == torch.Size([2])
+        myc.batch_size = [2]
+        assert isinstance(myc.batch_size, torch.Size)
+
         assert myc.X.shape == torch.Size([2, 3, 4])
 
     def test_cat(self):
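
Usage note (not part of the patch): a minimal sketch of the behavior this
refactor enables, assuming the tensordict package with this patch applied;
the MyData class below is illustrative only. Because batch_size, shape,
names, etc. are now exposed through wrapped TensorDict properties that
carry a setter, assigning a plain list is forwarded to the underlying
TensorDict, which validates it and coerces it to a torch.Size, as the new
test exercises.

    import torch
    from tensordict import tensorclass

    @tensorclass
    class MyData:
        X: torch.Tensor

    data = MyData(X=torch.zeros(2, 3, 4), batch_size=[2, 3])

    # Shrinking the batch size to a prefix of the leaf shapes is allowed;
    # the wrapped property setter coerces the list to torch.Size.
    data.batch_size = [2]
    assert isinstance(data.batch_size, torch.Size)
    assert data.X.shape == torch.Size([2, 3, 4])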