Skip to content

Commit

Permalink
[Refactor] Fix property handling in TC
Browse files Browse the repository at this point in the history
ghstack-source-id: 8be779a7a85fdf45000181a9ea0f830822f19e37
Pull Request resolved: #1178
  • Loading branch information
vmoens committed Jan 10, 2025
1 parent 74cae09 commit 7a558de
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 66 deletions.
119 changes: 53 additions & 66 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def __subclasscheck__(self, subclass):
"_propagate_unlock",
"_reduce_get_metadata",
"_values_list",
"batch_dims",
"batch_size",
"bytes",
"cat_tensors",
"data_ptr",
Expand All @@ -168,6 +170,7 @@ def __subclasscheck__(self, subclass):
"is_cuda",
"is_empty",
"is_floating_point",
"is_locked",
"is_memmap",
"is_meta",
"is_shared",
Expand All @@ -176,20 +179,23 @@ def __subclasscheck__(self, subclass):
"keys",
"make_memmap",
"make_memmap_from_tensor",
"names",
"ndim",
"ndimension",
"numel",
"numpy",
"param_count",
"pop",
"recv",
"reduce",
"requires_grad",
"saved_path",
"send",
"shape",
"size",
"sorted_keys",
"to_struct_array",
"values",
# "ndim",
]

# Methods to be executed from tensordict, any ref to self means 'self._tensordict'
Expand Down Expand Up @@ -910,8 +916,16 @@ def __torch_function__(
for method_name in _FALLBACK_METHOD_FROM_TD_FORCE:
setattr(cls, method_name, _wrap_td_method(method_name))
for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP:
if not hasattr(cls, method_name):
setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True))
if not hasattr(cls, method_name) and method_name not in expected_keys:
is_property = isinstance(
getattr(TensorDictBase, method_name, None), property
)
setattr(
cls,
method_name,
_wrap_td_method(method_name, no_wrap=True, is_property=is_property),
)

for method_name in _FALLBACK_METHOD_FROM_TD_COPY:
if not hasattr(cls, method_name):
setattr(
Expand Down Expand Up @@ -949,20 +963,6 @@ def __torch_function__(
cls.to_tensordict = _to_tensordict
if not hasattr(cls, "device") and "device" not in expected_keys:
cls.device = property(_device, _device_setter)
if not hasattr(cls, "batch_size") and "batch_size" not in expected_keys:
cls.batch_size = property(_batch_size, _batch_size_setter)
if not hasattr(cls, "batch_dims") and "batch_dims" not in expected_keys:
cls.batch_dims = property(_batch_dims)
if not hasattr(cls, "requires_grad") and "requires_grad" not in expected_keys:
cls.requires_grad = property(_requires_grad)
if not hasattr(cls, "is_locked") and "is_locked" not in expected_keys:
cls.is_locked = property(_is_locked)
if not hasattr(cls, "ndim") and "ndim" not in expected_keys:
cls.ndim = property(_batch_dims)
if not hasattr(cls, "shape") and "shape" not in expected_keys:
cls.shape = property(_batch_size, _batch_size_setter)
if not hasattr(cls, "names") and "names" not in expected_keys:
cls.names = property(_names, _names_setter)
if not _is_non_tensor and not hasattr(cls, "data") and "data" not in expected_keys:
cls.data = property(_data, _data_setter)
if not hasattr(cls, "grad") and "grad" not in expected_keys:
Expand Down Expand Up @@ -1541,7 +1541,9 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417
return wrapper


def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False):
def _wrap_td_method(
funcname, *, copy_non_tensor=False, no_wrap=False, is_property=False
):
def deliver_result(self, result, kwargs):
if result is None:
return
Expand All @@ -1559,13 +1561,33 @@ def deliver_result(self, result, kwargs):
return self._from_tensordict(result, non_tensordict, safe=False)
return result

def wrapped_func(self, *args, **kwargs):
if not is_property:

def wrapped_func(self, *args, **kwargs):
if not is_compiling():
td = super(type(self), self).__getattribute__("_tensordict")
else:
td = self._tensordict

result = getattr(td, funcname)(*args, **kwargs)
if no_wrap:
return result

if result is td:
return self

if isinstance(result, tuple):
return tuple(deliver_result(self, r, kwargs) for r in result)
return deliver_result(self, result, kwargs)

return wrapped_func

def wrapped_func(self):
if not is_compiling():
td = super(type(self), self).__getattribute__("_tensordict")
else:
td = self._tensordict

result = getattr(td, funcname)(*args, **kwargs)
result = getattr(td, funcname)

if no_wrap:
return result
Expand All @@ -1574,10 +1596,17 @@ def wrapped_func(self, *args, **kwargs):
return self

if isinstance(result, tuple):
return tuple(deliver_result(self, r, kwargs) for r in result)
return deliver_result(self, result, kwargs)
return tuple(deliver_result(self, r, {}) for r in result)
return deliver_result(self, result, {})

return wrapped_func
def wrapped_func_setter(self, value):
if not is_compiling():
td = super(type(self), self).__getattribute__("_tensordict")
else:
td = self._tensordict
return setattr(td, funcname, value)

return property(wrapped_func, wrapped_func_setter)


def _wrap_method(self, attr, func, nowarn=False):
Expand Down Expand Up @@ -2244,48 +2273,6 @@ def _get_at(self, key: NestedKey, idx, default: Any = NO_DEFAULT):
return default


def _batch_size(self) -> torch.Size:
"""Retrieves the batch size for the tensor class.
Returns:
batch size (torch.Size)
"""
return self._tensordict.batch_size


def _batch_dims(self) -> torch.Size:
return self._tensordict.batch_dims


def _requires_grad(self) -> torch.Size:
return self._tensordict.requires_grad


def _is_locked(self) -> torch.Size:
return self._tensordict.is_locked


def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417
"""Set the value of batch_size.
Args:
new_size (torch.Size): new_batch size to be set
"""
self._tensordict._batch_size_setter(new_size)


def _names(self) -> torch.Size:
"""Retrieves the dim names for the tensor class.
Returns:
names (list of str)
"""
return self._tensordict.names


def _data(self):
# We allow data to be a field of the class too
if "data" in self.__dataclass_fields__:
Expand Down
3 changes: 3 additions & 0 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,9 @@ def test_batch_size(self):
myc.batch_size = torch.Size([2])

assert myc.batch_size == torch.Size([2])
myc.batch_size = [2]
assert isinstance(myc.batch_size, torch.Size)

assert myc.X.shape == torch.Size([2, 3, 4])

def test_cat(self):
Expand Down

0 comments on commit 7a558de

Please sign in to comment.