[Refactor] Fix property handling in TC #1178

Merged: 2 commits, Jan 10, 2025
119 changes: 53 additions & 66 deletions tensordict/tensorclass.py
@@ -152,6 +152,8 @@ def __subclasscheck__(self, subclass):
"_propagate_unlock",
"_reduce_get_metadata",
"_values_list",
"batch_dims",
"batch_size",
"bytes",
"cat_tensors",
"data_ptr",
@@ -168,6 +170,7 @@ def __subclasscheck__(self, subclass):
"is_cuda",
"is_empty",
"is_floating_point",
"is_locked",
"is_memmap",
"is_meta",
"is_shared",
@@ -176,20 +179,23 @@ def __subclasscheck__(self, subclass):
"keys",
"make_memmap",
"make_memmap_from_tensor",
"names",
"ndim",
"ndimension",
"numel",
"numpy",
"param_count",
"pop",
"recv",
"reduce",
"requires_grad",
"saved_path",
"send",
"shape",
"size",
"sorted_keys",
"to_struct_array",
"values",
# "ndim",
]

# Methods to be executed from tensordict, any ref to self means 'self._tensordict'
@@ -910,8 +916,16 @@ def __torch_function__(
     for method_name in _FALLBACK_METHOD_FROM_TD_FORCE:
         setattr(cls, method_name, _wrap_td_method(method_name))
     for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP:
-        if not hasattr(cls, method_name):
-            setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True))
+        if not hasattr(cls, method_name) and method_name not in expected_keys:
+            is_property = isinstance(
+                getattr(TensorDictBase, method_name, None), property
+            )
+            setattr(
+                cls,
+                method_name,
+                _wrap_td_method(method_name, no_wrap=True, is_property=is_property),
+            )
+
     for method_name in _FALLBACK_METHOD_FROM_TD_COPY:
         if not hasattr(cls, method_name):
             setattr(
@@ -949,20 +963,6 @@ def __torch_function__(
     cls.to_tensordict = _to_tensordict
     if not hasattr(cls, "device") and "device" not in expected_keys:
         cls.device = property(_device, _device_setter)
-    if not hasattr(cls, "batch_size") and "batch_size" not in expected_keys:
-        cls.batch_size = property(_batch_size, _batch_size_setter)
-    if not hasattr(cls, "batch_dims") and "batch_dims" not in expected_keys:
-        cls.batch_dims = property(_batch_dims)
-    if not hasattr(cls, "requires_grad") and "requires_grad" not in expected_keys:
-        cls.requires_grad = property(_requires_grad)
-    if not hasattr(cls, "is_locked") and "is_locked" not in expected_keys:
-        cls.is_locked = property(_is_locked)
-    if not hasattr(cls, "ndim") and "ndim" not in expected_keys:
-        cls.ndim = property(_batch_dims)
-    if not hasattr(cls, "shape") and "shape" not in expected_keys:
-        cls.shape = property(_batch_size, _batch_size_setter)
-    if not hasattr(cls, "names") and "names" not in expected_keys:
-        cls.names = property(_names, _names_setter)
     if not _is_non_tensor and not hasattr(cls, "data") and "data" not in expected_keys:
         cls.data = property(_data, _data_setter)
     if not hasattr(cls, "grad") and "grad" not in expected_keys:
@@ -1541,7 +1541,9 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417
     return wrapper


-def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False):
+def _wrap_td_method(
+    funcname, *, copy_non_tensor=False, no_wrap=False, is_property=False
+):
     def deliver_result(self, result, kwargs):
         if result is None:
             return
@@ -1559,13 +1561,33 @@ def deliver_result(self, result, kwargs):
             return self._from_tensordict(result, non_tensordict, safe=False)
         return result

-    def wrapped_func(self, *args, **kwargs):
+    if not is_property:
+
+        def wrapped_func(self, *args, **kwargs):
+            if not is_compiling():
+                td = super(type(self), self).__getattribute__("_tensordict")
+            else:
+                td = self._tensordict
+
+            result = getattr(td, funcname)(*args, **kwargs)
+            if no_wrap:
+                return result
+
+            if result is td:
+                return self
+
+            if isinstance(result, tuple):
+                return tuple(deliver_result(self, r, kwargs) for r in result)
+            return deliver_result(self, result, kwargs)
+
+        return wrapped_func
+
+    def wrapped_func(self):
         if not is_compiling():
             td = super(type(self), self).__getattribute__("_tensordict")
         else:
             td = self._tensordict

-        result = getattr(td, funcname)(*args, **kwargs)
+        result = getattr(td, funcname)

         if no_wrap:
             return result
@@ -1574,10 +1596,17 @@ def wrapped_func(self, *args, **kwargs):
             return self

         if isinstance(result, tuple):
-            return tuple(deliver_result(self, r, kwargs) for r in result)
-        return deliver_result(self, result, kwargs)
+            return tuple(deliver_result(self, r, {}) for r in result)
+        return deliver_result(self, result, {})

-    return wrapped_func
+    def wrapped_func_setter(self, value):
+        if not is_compiling():
+            td = super(type(self), self).__getattribute__("_tensordict")
+        else:
+            td = self._tensordict
+        return setattr(td, funcname, value)
+
+    return property(wrapped_func, wrapped_func_setter)


 def _wrap_method(self, attr, func, nowarn=False):
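
For readers skimming the diff: the is_property branch returns a property whose getter and setter both delegate to the underlying TensorDict. A stripped-down sketch of that idea, with compile-mode handling and result re-wrapping omitted (MiniTC and make_td_property are hypothetical names, not library code):

import torch
from tensordict import TensorDict

def make_td_property(funcname):
    # miniature analogue of wrapped_func / wrapped_func_setter above
    def getter(self):
        return getattr(self._tensordict, funcname)

    def setter(self, value):
        setattr(self._tensordict, funcname, value)

    return property(getter, setter)

class MiniTC:
    def __init__(self, td):
        self._tensordict = td

    batch_size = make_td_property("batch_size")

m = MiniTC(TensorDict({}, batch_size=[2]))
assert m.batch_size == torch.Size([2])       # get delegates to the TensorDict
m.batch_size = [2]                           # set delegates too
assert isinstance(m.batch_size, torch.Size)  # TensorDict coerces to torch.Size
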
@@ -2244,48 +2273,6 @@ def _get_at(self, key: NestedKey, idx, default: Any = NO_DEFAULT):
         return default


-def _batch_size(self) -> torch.Size:
-    """Retrieves the batch size for the tensor class.
-
-    Returns:
-        batch size (torch.Size)
-
-    """
-    return self._tensordict.batch_size
-
-
-def _batch_dims(self) -> torch.Size:
-    return self._tensordict.batch_dims
-
-
-def _requires_grad(self) -> torch.Size:
-    return self._tensordict.requires_grad
-
-
-def _is_locked(self) -> torch.Size:
-    return self._tensordict.is_locked
-
-
-def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417
-    """Set the value of batch_size.
-
-    Args:
-        new_size (torch.Size): new_batch size to be set
-
-    """
-    self._tensordict._batch_size_setter(new_size)
-
-
-def _names(self) -> torch.Size:
-    """Retrieves the dim names for the tensor class.
-
-    Returns:
-        names (list of str)
-
-    """
-    return self._tensordict.names
-
-
 def _data(self):
     # We allow data to be a field of the class too
     if "data" in self.__dataclass_fields__:
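
Taken together with the hunks above, the user-visible contract should be unchanged: the removed helpers used to back these attributes, and the generic wrapper now has to honor the same behavior. A quick end-to-end sketch with the public @tensorclass API (Data is a hypothetical example class):

import torch
from tensordict import tensorclass

@tensorclass
class Data:
    X: torch.Tensor

d = Data(X=torch.zeros(2, 3), batch_size=[2])
assert d.batch_size == torch.Size([2])  # formerly via _batch_size
assert d.ndim == 1 and d.shape == torch.Size([2])
assert d.requires_grad is False
d.batch_size = torch.Size([2])          # setter path, formerly _batch_size_setter
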
3 changes: 3 additions & 0 deletions test/test_tensorclass.py
@@ -424,6 +424,9 @@ def test_batch_size(self):
         myc.batch_size = torch.Size([2])

         assert myc.batch_size == torch.Size([2])
+        myc.batch_size = [2]
+        assert isinstance(myc.batch_size, torch.Size)
+
         assert myc.X.shape == torch.Size([2, 3, 4])

     def test_cat(self):
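
The two new assertions pin down list-to-Size coercion through the property setter, which now relies on the TensorDict batch_size setter rather than the removed _batch_size_setter helper. The same behavior can be checked on a bare TensorDict (sketch, assuming the usual coercion in the setter):

import torch
from tensordict import TensorDict

td = TensorDict({}, batch_size=[2])
td.batch_size = [2]                           # plain list accepted
assert isinstance(td.batch_size, torch.Size)  # coerced by the setter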